Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

迭代器设计 #255

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions internal/iterator/iterator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2021 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package iterator

// Iterator 是迭代器接口,所有的容器类型都可以提供自己的迭代器,
// 只需要实现当前接口即可(Go 中没有继承,实现接口的全部方法即满足该接口)
type Iterator[T any] interface {
dxyinme marked this conversation as resolved.
Show resolved Hide resolved
// 迭代器移动到下一个节点,如果成功的话就返回true
// 如果没有下一个节点,则迭代器所指向的位置会变为非法,一般为nil,并且返回false
Next() bool

// 获取迭代器当前所指向的节点的信息
Get() T

// 获取error
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

获取 error。当你在执行 Get 或者 Delete 之后,如果有必要,可以通过该方法来检测是否运作正常。

Err() error

// 判断是否有后继节点
HasNext() bool

// 判断当前节点是否合法
flycash marked this conversation as resolved.
Show resolved Hide resolved
Valid() bool

// 删除当前节点
Delete()
}
153 changes: 135 additions & 18 deletions internal/tree/red_black_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ package tree

import (
"errors"
"fmt"

"github.com/ecodeclub/ekit"
"github.com/ecodeclub/ekit/internal/iterator"
"github.com/ecodeclub/ekit/tuple/pair"
)

type color bool
Expand All @@ -31,6 +34,9 @@ var (
ErrRBTreeSameRBNode = errors.New("ekit: RBTree不能添加重复节点Key")
ErrRBTreeNotRBNode = errors.New("ekit: RBTree不存在节点Key")
// errRBTreeCantRepaceNil = errors.New("ekit: RBTree不能将节点替换为nil")
ErrRBTreeIteratorNoNext = errors.New("ekit: RBTree Iterator没有后继节点")
ErrRBTreeIteratorNodeNil = errors.New("ekit: RBTree Iterator指向的节点为nil")
ErrRBTreeIteratorInvalid = errors.New("ekit: RBTree Iterator不合法")
)

type RBTree[K any, V any] struct {
Expand All @@ -53,6 +59,26 @@ type rbNode[K any, V any] struct {
left, right, parent *rbNode[K, V]
}

func (node *rbNode[K, V]) getNext() *rbNode[K, V] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要定义在这里。你想,如果我们是广度优先遍历,那么 next 的概念是完全不一样的。你应该把这个挪过去 Iterator 的实现里面。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我没理解错的话,这是一个中序遍历的吧?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实是一个中序遍历,但是我有一个疑问就是对于一个红黑树来说,我们需要广度优先遍历吗,或者说我们需要除了中序遍历 (Begin 从小到大) 和反向中序遍历 (RBegin 从大到小)以外的遍历方式么

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你不需要纠结要不要,只是说我们需要留出来这个扩展点。所以你现在并不能直接说我这里不可能用别的遍历方式。而且你从设计理念上来说,next 就是归属你 iterator 的概念,一个红黑树是没有这个概念的。

if node == nil {
return nil
} else if node.right != nil {
p := node.right
for p.left != nil {
p = p.left
}
return p
} else {
p := node.parent
ch := node
for p != nil && ch == p.right {
ch = p
p = p.parent
}
return p
}
}

func (node *rbNode[K, V]) setNode(v V) {
if node == nil {
return
Expand All @@ -79,6 +105,19 @@ func newRBNode[K any, V any](key K, value V) *rbNode[K, V] {
}
}

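// beginNode 返回中序遍历的起始节点(即整棵树最左侧的节点),返回的是节点而不是迭代器。
// 注意:这里没有检查 rb.root 是否为 nil,如果空树的 rb.root 为 nil,访问 curr.left 会 panic。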
func (rb *RBTree[K, V]) beginNode() *rbNode[K, V] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同 getNext,beginNode 也是取决于你如何遍历的。比如说要是我准备实现一个猥琐的从后往前的遍历的 iterator,那么你这个 beginNode 就不对了。因此应该挪过去 Iterator 的实现上

curr := rb.root
for curr.left != nil {
curr = curr.left
}
return curr
}

func (rb *RBTree[K, V]) Begin() iterator.Iterator[pair.Pair[K, V]] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同样将 Begin 改个名字,指明你的遍历顺序,在注释里面也解释一下。

return newRBTreeIterator(rb, rb.beginNode())
}

// Add 增加节点
func (rb *RBTree[K, V]) Add(key K, value V) error {
return rb.addNode(newRBNode(key, value))
Expand All @@ -95,6 +134,14 @@ func (rb *RBTree[K, V]) Delete(key K) (V, bool) {
return v, false
}

// FindIt 查找 key 对应的节点,并返回指向该节点的迭代器;key 不存在时返回 nil 和 ErrRBTreeNotRBNode
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个方法不需要。因为正常在使用 Iterator 的时候,查找是这么用的:

for itr.Next() {
    val := itr.Get()
    if val == xxx {
     // 这就是判定
    }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这样的话 iterator无法支持查找欸,比如我想使用iterator实现一个功能就是查找到比 key 大 和 比 key 小的 10 个值,本来可以做到O(logN)的现在只能做到O(N)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Iterator 本身的语义就是迭代。Find 类的方法可以在 RBTree 上实现,但是没有必要在 Iterator 这里。

func (rb *RBTree[K, V]) FindIt(key K) (iterator.Iterator[pair.Pair[K, V]], error) {
if node := rb.findNode(key); node != nil {
return newRBTreeIterator(rb, node), nil
}
return nil, ErrRBTreeNotRBNode
}

// Find 查找节点
func (rb *RBTree[K, V]) Find(key K) (V, error) {
var v V
Expand All @@ -103,6 +150,8 @@ func (rb *RBTree[K, V]) Find(key K) (V, error) {
}
return v, ErrRBTreeNotRBNode
}

// 给对应的Key 设置 Value
func (rb *RBTree[K, V]) Set(key K, value V) error {
if node := rb.findNode(key); node != nil {
node.setNode(value)
Expand Down Expand Up @@ -243,24 +292,7 @@ func (rb *RBTree[K, V]) deleteNode(tgt *rbNode[K, V]) {
// case1: node has a right child, so the minimum node of its right subtree is
// node's successor.
// case2: node has no right child, so the parent of its first ancestor that is
// a left child is node's successor.
func (rb *RBTree[K, V]) findSuccessor(node *rbNode[K, V]) *rbNode[K, V] {
	// rbNode.getNext already implements both cases above (and returns nil for
	// a nil node), so delegate rather than keeping a second copy of the walk
	// followed by an unreachable return.
	return node.getNext()
}

func (rb *RBTree[K, V]) findNode(key K) *rbNode[K, V] {
Expand Down Expand Up @@ -544,3 +576,88 @@ func (node *rbNode[K, V]) getBrother() *rbNode[K, V] {
}
return node.getParent().getLeft()
}

// isValidNode reports whether node still looks linked into a tree.
//
// It returns false for a nil node, and false if any existing link is not
// mirrored by the neighbour: a child whose parent is not node, or a parent
// that has node as neither its left nor its right child. Otherwise it returns
// true only when node has at least one link (left, right or parent).
//
// NOTE(review): a node with no children and no parent is reported as invalid,
// which also covers the sole node of a one-element tree — confirm that this is
// intended by the callers.
func (node *rbNode[K, V]) isValidNode() bool {
	if node == nil {
		return false
	}

	links := 0

	if left := node.getLeft(); left != nil {
		if left.getParent() != node {
			return false
		}
		links++
	}

	if right := node.getRight(); right != nil {
		if right.getParent() != node {
			return false
		}
		links++
	}

	if parent := node.getParent(); parent != nil {
		isChild := parent.getLeft() == node || parent.getRight() == node
		if !isChild {
			return false
		}
		links++
	}

	return links != 0
}

type rbTreeIterator[K any, V any] struct {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改个名字,名字里面包含它是如何遍历的,我的理解是你这个是中序遍历的?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续万一要实现广度优先、前序、后序,名字就能够区分出来。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

把这个 iterator 挪出去作为一个单独的文件,测试也要相应的挪出去。

rbTree *RBTree[K, V]
currNode *rbNode[K, V]
nxtNode *rbNode[K, V]
err error
}

func (iter *rbTreeIterator[K, V]) Next() bool {
fmt.Println(iter.currNode.key, iter.currNode.value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种 DEBUG 信息要记得删了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

iter.err = nil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 err 暂时不要置为 nil,如果要是已经不为 nil 了,你就永远返回 false。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

或者说,我们认为,如果 err 不为 nil,那么就说明这个 iterator 不可用了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK 可以的

iter.currNode = iter.nxtNode
if iter.currNode != nil {
iter.nxtNode = iter.currNode.getNext()
return true
}
iter.err = ErrRBTreeIteratorNoNext
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NoNext 算是正常结束,你这里不用标记一个错误。

return false
}

// HasNext reports whether the iterator has a cached successor node, i.e.
// whether nxtNode is non-nil. It does not move the iterator.
func (iter *rbTreeIterator[K, V]) HasNext() bool {
return iter.nxtNode != nil
}

func (iter *rbTreeIterator[K, V]) Get() (kvPair pair.Pair[K, V]) {
iter.err = nil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同样不能设置为 nil

if iter.currNode == nil {
iter.err = ErrRBTreeIteratorInvalid
return
}
kvPair = pair.NewPair(iter.currNode.key, iter.currNode.value)
return
}

// Err returns the iterator's current err field. Next, Get and Delete each
// reset that field to nil on entry and set it again when they fail, so the
// value reflects the most recent of those calls.
func (iter *rbTreeIterator[K, V]) Err() error {
return iter.err
}

// Valid reports whether the iterator currently points at a node, i.e. whether
// currNode is non-nil. It does not check that the node is still linked into
// the tree.
func (iter *rbTreeIterator[K, V]) Valid() bool {
return iter.currNode != nil
}

// Delete removes the node the iterator currently points at from the
// underlying tree.
//
// Any previously recorded error is cleared. If the current node fails
// isValidNode (this includes a nil current node), the tree is left untouched
// and ErrRBTreeIteratorInvalid is recorded for retrieval through Err.
// currNode and nxtNode are not modified here.
func (iter *rbTreeIterator[K, V]) Delete() {
	target := iter.currNode
	if target.isValidNode() {
		iter.err = nil
		iter.rbTree.deleteNode(target)
		return
	}
	iter.err = ErrRBTreeIteratorInvalid
}

// newRBTreeIterator builds an iterator over rbTree that starts at rbNode.
// The successor of rbNode is computed up front with getNext and cached in
// nxtNode; getNext returns nil for a nil node, so a nil start node yields an
// iterator with no current and no next node.
func newRBTreeIterator[K any, V any](rbTree *RBTree[K, V], rbNode *rbNode[K, V]) iterator.Iterator[pair.Pair[K, V]] {
	return &rbTreeIterator[K, V]{
		rbTree:   rbTree,
		currNode: rbNode,
		nxtNode:  rbNode.getNext(),
	}
}
147 changes: 147 additions & 0 deletions internal/tree/red_black_tree_iterator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright 2021 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tree

import (
"math/rand"
"testing"

"github.com/ecodeclub/ekit"
"github.com/stretchr/testify/assert"
)

// TestIteratorToVisitFullRBTree inserts the keys 0..n-1 in shuffled order and
// checks that iterating from Begin visits every key exactly once, in ascending
// order, without the iterator reporting an error.
func TestIteratorToVisitFullRBTree(t *testing.T) {
	t.Parallel()
	n := 10000
	arr := generateArray(n)
	rbTree := NewRBTree[int, int](ekit.ComparatorRealNumber[int])
	for _, v := range arr {
		assert.Nil(t, rbTree.Add(v, v))
	}

	// The keys are exactly 0..n-1, so the id-th visited key must equal id.
	// No scratch slice is kept: it was never read, and indexing it would panic
	// instead of failing the count assertion if the iterator over-produced.
	id := 0
	for iter := rbTree.Begin(); iter.Valid(); iter.Next() {
		pa, err := iter.Get(), iter.Err()
		assert.Nil(t, err)
		assert.Equal(t, id, pa.Key)
		id++
	}
	assert.Equal(t, n, id)
}

// TestIteratorFind covers FindIt: looking up an existing key, a missing key,
// a key that is deleted through the returned iterator, and a key that is
// added after an initial miss.
func TestIteratorFind(t *testing.T) {
	// seed lists the key/value pairs every sub-test inserts, in insertion
	// order. It is only read, so sharing it between parallel sub-tests is safe.
	seed := [][2]int{{1, 101}, {-100, 102}, {100, 103}}

	t.Run("查找存在的节点", func(t *testing.T) {
		t.Parallel()
		rbTree := NewRBTree[int, int](ekit.ComparatorRealNumber[int])
		for _, kv := range seed {
			assert.Nil(t, rbTree.Add(kv[0], kv[1]))
		}

		found, findErr := rbTree.FindIt(-100)
		assert.Nil(t, findErr)
		assert.Equal(t, 102, found.Get().Value)
	})

	t.Run("查找不存在的节点", func(t *testing.T) {
		t.Parallel()
		rbTree := NewRBTree[int, int](ekit.ComparatorRealNumber[int])
		for _, kv := range seed {
			assert.Nil(t, rbTree.Add(kv[0], kv[1]))
		}

		found, findErr := rbTree.FindIt(2)
		assert.Equal(t, ErrRBTreeNotRBNode, findErr)
		assert.Nil(t, found)
	})

	t.Run("查找存在的节点,删除后不存在", func(t *testing.T) {
		t.Parallel()
		rbTree := NewRBTree[int, int](ekit.ComparatorRealNumber[int])
		for _, kv := range seed {
			assert.Nil(t, rbTree.Add(kv[0], kv[1]))
		}

		found, findErr := rbTree.FindIt(-100)
		assert.Nil(t, findErr)
		assert.Equal(t, 102, found.Get().Value)

		// Remove the node through the iterator, then look it up again.
		found.Delete()
		assert.Nil(t, found.Err())

		found, findErr = rbTree.FindIt(-100)
		assert.Equal(t, ErrRBTreeNotRBNode, findErr)
		assert.Nil(t, found)
	})

	t.Run("查找不存在的节点,增加后存在", func(t *testing.T) {
		t.Parallel()
		rbTree := NewRBTree[int, int](ekit.ComparatorRealNumber[int])
		for _, kv := range seed {
			assert.Nil(t, rbTree.Add(kv[0], kv[1]))
		}

		found, findErr := rbTree.FindIt(2)
		assert.Equal(t, ErrRBTreeNotRBNode, findErr)
		assert.Nil(t, found)

		// Once the key is added, the same lookup must succeed.
		assert.Nil(t, rbTree.Add(2, 104))
		found, findErr = rbTree.FindIt(2)
		assert.Nil(t, findErr)
		assert.Equal(t, 104, found.Get().Value)
	})
}

// TestIteratorDelete covers Delete on an iterator: deleting the same node
// twice must report ErrRBTreeIteratorInvalid on the second call, and deleting
// a node in the middle of a traversal must not disturb the remaining visits.
func TestIteratorDelete(t *testing.T) {
	t.Run("重复删除某个节点", func(t *testing.T) {
		t.Parallel()
		rbTree := NewRBTree[int, int](ekit.ComparatorRealNumber[int])
		for _, kv := range [][2]int{{1, 101}, {-100, 102}, {100, 103}} {
			assert.Nil(t, rbTree.Add(kv[0], kv[1]))
		}

		target, findErr := rbTree.FindIt(-100)
		assert.Nil(t, findErr)

		// First delete succeeds, second one hits an already removed node.
		target.Delete()
		assert.Equal(t, nil, target.Err())
		target.Delete()
		assert.Equal(t, ErrRBTreeIteratorInvalid, target.Err())
	})

	t.Run("删除节点后正常遍历", func(t *testing.T) {
		t.Parallel()
		rbTree := NewRBTree[int, int](ekit.ComparatorRealNumber[int])
		for _, kv := range [][2]int{{1, 101}, {-100, 102}, {100, 103}, {101, 104}, {102, 105}} {
			assert.Nil(t, rbTree.Add(kv[0], kv[1]))
		}

		visited := make([]int, 0)
		for cursor := rbTree.Begin(); cursor.Valid(); cursor.Next() {
			key := cursor.Get().Key
			if key != 100 {
				visited = append(visited, key)
				continue
			}
			// Key 100 is removed mid-traversal and must not be recorded.
			cursor.Delete()
			assert.Nil(t, cursor.Err())
		}
		assert.EqualValues(t, []int{-100, 1, 101, 102}, visited)
	})
}

// generateArray returns the integers 0..n-1 in pseudo-random order, using the
// default math/rand source. n must be non-negative.
func generateArray(n int) []int {
	// rand.Perm produces a random permutation of [0, n) directly, replacing
	// the manual fill-then-shuffle.
	return rand.Perm(n)
}
Loading