From f83efa96d01719a525905767716bce40062f54c4 Mon Sep 17 00:00:00 2001 From: uzziahlin <120019273+uzziahlin@users.noreply.github.com> Date: Fri, 9 Dec 2022 17:09:28 +0800 Subject: [PATCH 1/9] =?UTF-8?q?queue:=20=E5=9F=BA=E4=BA=8Esemaphore?= =?UTF-8?q?=E7=9A=84=E5=B9=B6=E5=8F=91=E9=98=BB=E5=A1=9E=E9=98=9F=E5=88=97?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20(#129)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * queue: 基于semaphore的并发阻塞队列实现 -主要改造了concurrent_array_blocking_queue的实现,将原来基于cond的实现改造成基于semaphore的实现 Co-authored-by: uzziahlin --- .CHANGELOG.md | 1 + go.mod | 2 +- go.sum | 4 +- queue/concurrent_array_blocking_queue.go | 107 +++++++++++++++-------- 4 files changed, 73 insertions(+), 41 deletions(-) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 126713b1..616a5d8e 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -1,4 +1,5 @@ # 开发中 +- [queue: 基于semaphore的并发阻塞队列实现](https://github.com/gotomicro/ekit/pull/129) # v0.0.5 - [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101) diff --git a/go.mod b/go.mod index 074ddb77..655d0665 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.18 require ( github.com/mattn/go-sqlite3 v1.14.15 github.com/stretchr/testify v1.7.1 - golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde + golang.org/x/sync v0.1.0 ) require ( diff --git a/go.sum b/go.sum index 62e805fe..f45795a7 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde h1:ejfdSekXMDxDLbRrJMwUk6KnSLZ2McaUCVcIKM+N6jc= -golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= diff --git a/queue/concurrent_array_blocking_queue.go b/queue/concurrent_array_blocking_queue.go index 3b4e62da..2ed80882 100644 --- a/queue/concurrent_array_blocking_queue.go +++ b/queue/concurrent_array_blocking_queue.go @@ -17,6 +17,8 @@ package queue import ( "context" "sync" + + "golang.org/x/sync/semaphore" ) // ConcurrentArrayBlockingQueue 有界并发阻塞队列 @@ -31,8 +33,8 @@ type ConcurrentArrayBlockingQueue[T any] struct { // 包含多少个元素 count int - notEmpty *cond - notFull *cond + enqueueCap *semaphore.Weighted + dequeueCap *semaphore.Weighted // zero 不能作为返回值返回,防止用户篡改 zero T @@ -43,73 +45,102 @@ type ConcurrentArrayBlockingQueue[T any] struct { // capacity 必须为正数 func NewConcurrentArrayBlockingQueue[T any](capacity int) *ConcurrentArrayBlockingQueue[T] { mutex := &sync.RWMutex{} + + semaForEnqueue := semaphore.NewWeighted(int64(capacity)) + semaForDequeue := semaphore.NewWeighted(int64(capacity)) + + // error暂时不处理,因为目前没办法处理,只能考虑panic掉 + // 相当于将信号量置空 + _ = semaForDequeue.Acquire(context.TODO(), int64(capacity)) + res := &ConcurrentArrayBlockingQueue[T]{ - data: make([]T, capacity), - mutex: mutex, - notEmpty: newCond(mutex), - notFull: newCond(mutex), + data: make([]T, capacity), + mutex: mutex, + enqueueCap: semaForEnqueue, + dequeueCap: semaForDequeue, } return res } // Enqueue 入队 -// 注意:目前我们已经通过broadcast实现了超时控制 +// 通过sema来控制容量、超时、阻塞问题 func (c *ConcurrentArrayBlockingQueue[T]) Enqueue(ctx context.Context, t T) error { - if ctx.Err() != nil { - return ctx.Err() + + // 能拿到,说明队列还有空位,可以入队,拿不到则阻塞 + err := c.enqueueCap.Acquire(ctx, 1) + + if err != nil { + return err } + c.mutex.Lock() - for c.count == len(c.data) { - signal := c.notFull.signalCh() - select { - case <-ctx.Done(): - return ctx.Err() - case <-signal: - // 收到信号要重新加锁 - c.mutex.Lock() - } + defer c.mutex.Unlock() + + // 拿到锁,先判断是否超时,防止在抢锁时已经超时 + if ctx.Err() != nil { + + // 超时应该主动归还信号量,避免容量泄露 + c.enqueueCap.Release(1) + + return ctx.Err() } + c.data[c.tail] = t c.tail++ c.count++ + // c.tail 已经是最后一个了,重置下标 if c.tail == cap(c.data) { c.tail = 0 } - // 这里会释放锁 - c.notEmpty.broadcast() + + // 往出队的sema放入一个元素,出队的goroutine可以拿到并出队 + c.dequeueCap.Release(1) + return nil + } // Dequeue 出队 -// 注意:目前我们已经通过broadcast实现了超时控制 +// 通过sema来控制容量、超时、阻塞问题 func (c *ConcurrentArrayBlockingQueue[T]) Dequeue(ctx context.Context) (T, error) { - if ctx.Err() != nil { - var t T - return t, ctx.Err() + + // 能拿到,说明队列有元素可以取,可以出队,拿不到则阻塞 + err := c.dequeueCap.Acquire(ctx, 1) + + var res T + + if err != nil { + return res, err } + c.mutex.Lock() - for c.count == 0 { - signal := c.notEmpty.signalCh() - select { - case <-ctx.Done(): - var t T - return t, ctx.Err() - case <-signal: - c.mutex.Lock() - } + defer c.mutex.Unlock() + + // 拿到锁,先判断是否超时,防止在抢锁时已经超时 + if ctx.Err() != nil { + + // 超时应该主动归还信号量,有元素消费不到 + c.dequeueCap.Release(1) + + return res, ctx.Err() } - val := c.data[c.head] + + res = c.data[c.head] // 为了释放内存,GC c.data[c.head] = c.zero - c.count-- + c.head++ - // 重置下标 + c.count-- if c.head == cap(c.data) { c.head = 0 } - c.notFull.broadcast() - return val, nil + + // 往入队的sema放入一个元素,入队的goroutine可以拿到并入队 + c.enqueueCap.Release(1) + + return res, nil + } func (c *ConcurrentArrayBlockingQueue[T]) Len() int { From 2cc2690e476cf3e084788b58e2a27b8a53cf19c5 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 24 Dec 2022 20:30:22 +0800 Subject: [PATCH 2/9] =?UTF-8?q?mapx:=20=E6=B7=BB=E5=8A=A0=20Keys=20?= =?UTF-8?q?=E6=96=B9=E6=B3=95=20(#134)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Updates #134 --- .CHANGELOG.md | 1 + mapx/map.go | 25 +++++++++++++++++++ mapx/map_test.go | 62 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) create mode 100644 mapx/map.go create mode 100644 mapx/map_test.go diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 616a5d8e..e62c809a 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -1,5 +1,6 @@ # 开发中 - [queue: 基于semaphore的并发阻塞队列实现](https://github.com/gotomicro/ekit/pull/129) +- [mapx: 添加 Keys 方法](https://github.com/gotomicro/ekit/pull/134) # v0.0.5 - [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101) diff --git a/mapx/map.go b/mapx/map.go new file mode 100644 index 00000000..4abad187 --- /dev/null +++ b/mapx/map.go @@ -0,0 +1,25 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +// Keys 返回 map 里面的所有的 key。 +// 需要注意:这些 key 的顺序是随机。 +func Keys[K comparable, V any](m map[K]V) []K { + res := make([]K, 0, len(m)) + for k := range m { + res = append(res, k) + } + return res +} diff --git a/mapx/map_test.go b/mapx/map_test.go new file mode 100644 index 00000000..f79a6f34 --- /dev/null +++ b/mapx/map_test.go @@ -0,0 +1,62 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestKeys(t *testing.T) { + testCases := []struct { + name string + input map[int]int + wantRes []int + }{ + { + name: "nil", + input: nil, + wantRes: []int{}, + }, + { + name: "empty", + input: map[int]int{}, + wantRes: []int{}, + }, + { + name: "single", + input: map[int]int{ + 1: 11, + }, + wantRes: []int{1}, + }, + { + name: "multiple", + input: map[int]int{ + 1: 11, + 2: 12, + }, + wantRes: []int{1, 2}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := Keys[int, int](tc.input) + assert.ElementsMatch(t, tc.wantRes, res) + }) + } +} From 088412ce7dd95a848cb201f9ee27938657e6b9ca Mon Sep 17 00:00:00 2001 From: juniaoshaonian <73632785+juniaoshaonian@users.noreply.github.com> Date: Sun, 1 Jan 2023 13:59:06 +0800 Subject: [PATCH 3/9] =?UTF-8?q?hashmap=E7=9A=84=E7=AE=80=E5=8D=95=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=20(#132)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ming Deng Co-authored-by: Ming Deng --- .CHANGELOG.md | 1 + mapx/hashmap.go | 95 +++++++++++++++++++ mapx/hashmap_test.go | 211 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 307 insertions(+) create mode 100644 mapx/hashmap.go create mode 100644 mapx/hashmap_test.go diff --git a/.CHANGELOG.md b/.CHANGELOG.md index e62c809a..8de25e08 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -1,5 +1,6 @@ # 开发中 - [queue: 基于semaphore的并发阻塞队列实现](https://github.com/gotomicro/ekit/pull/129) +- [mapx: hashmap实现](https://github.com/gotomicro/ekit/pull/132) - [mapx: 添加 Keys 方法](https://github.com/gotomicro/ekit/pull/134) # v0.0.5 diff --git a/mapx/hashmap.go b/mapx/hashmap.go new file mode 100644 index 00000000..da458d15 --- /dev/null +++ b/mapx/hashmap.go @@ -0,0 +1,95 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +import "github.com/gotomicro/ekit/syncx" + +type node[T Hashable, ValType any] struct { + key Hashable + value ValType + next *node[T, ValType] +} + +func (m *HashMap[T, ValType]) newNode(key Hashable, val ValType) *node[T, ValType] { + newNode := m.nodePool.Get() + newNode.value = val + newNode.key = key + return newNode +} + +type Hashable interface { + Code() uint64 + Equals(key any) bool +} + +type HashMap[T Hashable, ValType any] struct { + hashmap map[uint64]*node[T, ValType] + nodePool *syncx.Pool[*node[T, ValType]] +} + +func (m *HashMap[T, ValType]) Put(key T, val ValType) error { + hash := key.Code() + root, ok := m.hashmap[hash] + if !ok { + hash = key.Code() + new_node := m.newNode(key, val) + m.hashmap[hash] = new_node + return nil + } + pre := root + for root != nil { + if root.key.Equals(key) { + root.value = val + return nil + } + pre = root + root = root.next + } + new_node := m.newNode(key, val) + pre.next = new_node + return nil +} + +func (m *HashMap[T, ValType]) Get(key T) (ValType, bool) { + hash := key.Code() + root, ok := m.hashmap[hash] + var val ValType + if !ok { + return val, false + } + for root != nil { + if root.key.Equals(key) { + return root.value, true + } + root = root.next + } + return val, false +} + +func NewHashMap[T Hashable, ValType any](size int) *HashMap[T, ValType] { + return &HashMap[T, ValType]{ + nodePool: syncx.NewPool[*node[T, ValType]](func() *node[T, ValType] { + return &node[T, ValType]{} + }), + hashmap: make(map[uint64]*node[T, ValType], size), + } +} + +type mapi[T any, ValType any] interface { + Put(key T, val ValType) error + Get(key T) (ValType, bool) +} + +var _ mapi[Hashable, any] = (*HashMap[Hashable, any])(nil) diff --git a/mapx/hashmap_test.go b/mapx/hashmap_test.go new file mode 100644 index 00000000..25993ce8 --- /dev/null +++ b/mapx/hashmap_test.go @@ -0,0 +1,211 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHashMap(t *testing.T) { + testKV := []struct { + key testData + val int + }{ + { + key: testData{ + id: 1, + }, + val: 1, + }, + { + key: testData{ + id: 2, + }, + val: 2, + }, + { + key: testData{ + id: 3, + }, + val: 3, + }, + { + key: testData{ + id: 11, + }, + val: 11, + }, + { + key: testData{ + id: 1, + }, + val: 101, + }, + } + myhashmap := NewHashMap[testData, int](10) + for _, kv := range testKV { + err := myhashmap.Put(kv.key, kv.val) + if err != nil { + panic(err) + } + } + + wantHashMap := NewHashMap[testData, int](10) + wantHashMap.hashmap = map[uint64]*node[testData, int]{ + 1: &node[testData, int]{ + key: testData{id: 1}, + value: 101, + next: &node[testData, int]{ + key: testData{id: 11}, + value: 11, + }, + }, + 2: wantHashMap.newNode(newTestData(2), 2), + 3: wantHashMap.newNode(newTestData(3), 3), + } + + assert.Equal(t, wantHashMap.hashmap, myhashmap.hashmap) + getTestCases := []struct { + name string + key testData + wantVal any + isFound bool + }{ + { + name: "get normal val", + key: testData{ + id: 1, + }, + wantVal: 101, + isFound: true, + }, + { + name: "hash conflicts", + key: testData{ + id: 11, + }, + wantVal: 11, + isFound: true, + }, + { + name: "hash not Found", + key: testData{ + id: 8, + }, + isFound: false, + }, + { + name: "val not Found", + key: testData{ + id: 21, + }, + isFound: false, + }, + } + for _, tc := range getTestCases { + t.Run(tc.name, func(t *testing.T) { + val, ok := myhashmap.Get(tc.key) + assert.Equal(t, tc.isFound, ok) + if !ok { + return + } + assert.Equal(t, tc.wantVal, val) + }) + } + +} + +type testData struct { + id int +} + +func (t testData) Code() uint64 { + hash := t.id % 10 + return uint64(hash) +} + +func (t testData) Equals(key any) bool { + val, ok := key.(testData) + if !ok { + return false + } + if t.id != val.id { + return false + } + return true +} + +func newTestData(id int) testData { + return testData{ + id: id, + } +} + +type hashInt uint64 + +func (h hashInt) Code() uint64 { + return uint64(h) +} + +func (h hashInt) Equals(key any) bool { + switch keyVal := key.(type) { + case hashInt: + return keyVal == h + default: + return false + } +} + +func newHashInt(i int) hashInt { + return hashInt(i) +} + +// goos: linux +// goarch: amd64 +// pkg: github.com/gotomicro/ekit/mapx +// cpu: Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz +// BenchmarkMyHashMap/hashmap_put-8 4985634 374.1 ns/op 53 B/op 1 allocs/op +// BenchmarkMyHashMap/map_put-8 5465565 235.5 ns/op 49 B/op 0 allocs/op +// BenchmarkMyHashMap/hashmap_get-8 7080890 143.9 ns/op 5 B/op 0 allocs/op +// BenchmarkMyHashMap/map_get-8 18534306 86.94 ns/op 0 B/op 0 allocs/op + +func BenchmarkMyHashMap(b *testing.B) { + hashmap := NewHashMap[hashInt, int](10) + m := make(map[uint64]int, 10) + b.Run("hashmap_put", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = hashmap.Put(newHashInt(i), i) + } + }) + b.Run("map_put", func(b *testing.B) { + for i := 0; i < b.N; i++ { + m[uint64(i)] = i + } + }) + b.Run("hashmap_get", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = hashmap.Get(newHashInt(i)) + } + }) + + b.Run("map_get", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = m[uint64(i)] + } + }) + +} From b91c08f8d6f1926ae4762c36dcda61ff25446306 Mon Sep 17 00:00:00 2001 From: yzicheng <786797661@qq.com> Date: Tue, 3 Jan 2023 14:49:02 +0800 Subject: [PATCH 4/9] =?UTF-8?q?mapx:=20Values=20=E5=92=8C=20KeysValues=20?= =?UTF-8?q?=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mapx/map.go | 22 ++++++++++++ mapx/map_test.go | 87 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+) diff --git a/mapx/map.go b/mapx/map.go index 4abad187..94bbbf19 100644 --- a/mapx/map.go +++ b/mapx/map.go @@ -23,3 +23,25 @@ func Keys[K comparable, V any](m map[K]V) []K { } return res } + +// Values 返回 map 里面的所有的 value。 +// 需要注意:这些 value 的顺序是随机。 +func Values[K comparable, V any](m map[K]V) []V { + res := make([]V, 0, len(m)) + for k := range m { + res = append(res, m[k]) + } + return res +} + +// KeysValues 返回 map 里面的所有的 key,value。 +// 需要注意:这些 (key,value) 的顺序是随机,相对顺序是一致的。 +func KeysValues[K comparable, V any](m map[K]V) ([]K, []V) { + keys := make([]K, 0, len(m)) + values := make([]V, 0, len(m)) + for k := range m { + keys = append(keys, k) + values = append(values, m[k]) + } + return keys, values +} diff --git a/mapx/map_test.go b/mapx/map_test.go index f79a6f34..d467e089 100644 --- a/mapx/map_test.go +++ b/mapx/map_test.go @@ -60,3 +60,90 @@ func TestKeys(t *testing.T) { }) } } +func TestValues(t *testing.T) { + testCases := []struct { + name string + input map[int]int + wantRes []int + }{ + { + name: "nil", + input: nil, + wantRes: []int{}, + }, + { + name: "empty", + input: map[int]int{}, + wantRes: []int{}, + }, + { + name: "single", + input: map[int]int{ + 1: 11, + }, + wantRes: []int{11}, + }, + { + name: "multiple", + input: map[int]int{ + 1: 11, + 2: 12, + }, + wantRes: []int{11, 12}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := Values[int, int](tc.input) + assert.ElementsMatch(t, tc.wantRes, res) + }) + } +} + +func TestKeysValues(t *testing.T) { + testCases := []struct { + name string + input map[int]int + wantKeys []int + wantValues []int + }{ + { + name: "nil", + input: nil, + wantKeys: []int{}, + wantValues: []int{}, + }, + { + name: "empty", + input: map[int]int{}, + wantKeys: []int{}, + wantValues: []int{}, + }, + { + name: "single", + input: map[int]int{ + 1: 11, + }, + wantKeys: []int{1}, + wantValues: []int{11}, + }, + { + name: "multiple", + input: map[int]int{ + 1: 11, + 2: 12, + }, + wantKeys: []int{1, 2}, + wantValues: []int{11, 12}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + keys, values := KeysValues[int, int](tc.input) + assert.ElementsMatch(t, tc.wantKeys, keys) + assert.ElementsMatch(t, tc.wantValues, values) + }) + } +} From 215438f4b69d5922ee062e14e0ed6b97a84664d1 Mon Sep 17 00:00:00 2001 From: juniaoshaonian <73632785+juniaoshaonian@users.noreply.github.com> Date: Sun, 8 Jan 2023 22:49:21 +0800 Subject: [PATCH 5/9] =?UTF-8?q?hashmap=20delete=E5=8A=9F=E8=83=BD=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=20(#138)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 1 + mapx/hashmap.go | 39 ++++++++ mapx/hashmap_test.go | 214 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 253 insertions(+), 1 deletion(-) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 8de25e08..7be3272a 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -2,6 +2,7 @@ - [queue: 基于semaphore的并发阻塞队列实现](https://github.com/gotomicro/ekit/pull/129) - [mapx: hashmap实现](https://github.com/gotomicro/ekit/pull/132) - [mapx: 添加 Keys 方法](https://github.com/gotomicro/ekit/pull/134) +- [mapx: hashmap添加刪除功能](https://github.com/gotomicro/ekit/pull/138) # v0.0.5 - [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101) diff --git a/mapx/hashmap.go b/mapx/hashmap.go index da458d15..e6fa3799 100644 --- a/mapx/hashmap.go +++ b/mapx/hashmap.go @@ -93,3 +93,42 @@ type mapi[T any, ValType any] interface { } var _ mapi[Hashable, any] = (*HashMap[Hashable, any])(nil) + +// Delete 第一个返回值为删除key的值,第二个是hashmap是否真的有这个key +func (m *HashMap[T, ValType]) Delete(key T) (ValType, bool) { + root, ok := m.hashmap[key.Code()] + if !ok { + var t ValType + return t, false + } + pre := root + num := 0 + for root != nil { + if root.key.Equals(key) { + if num == 0 && root.next == nil { + delete(m.hashmap, key.Code()) + } else if num == 0 && root.next != nil { + m.hashmap[key.Code()] = root.next + } else { + pre.next = root.next + } + val := root.value + root.formatting() + m.nodePool.Put(root) + return val, true + } + num++ + pre = root + root = root.next + } + var t ValType + return t, false +} + +func (n *node[T, ValType]) formatting() { + var val ValType + var t T + n.key = t + n.value = val + n.next = nil +} diff --git a/mapx/hashmap_test.go b/mapx/hashmap_test.go index 25993ce8..2a86cdeb 100644 --- a/mapx/hashmap_test.go +++ b/mapx/hashmap_test.go @@ -17,6 +17,8 @@ package mapx import ( "testing" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) @@ -60,7 +62,7 @@ func TestHashMap(t *testing.T) { for _, kv := range testKV { err := myhashmap.Put(kv.key, kv.val) if err != nil { - panic(err) + require.NoError(t, err) } } @@ -128,6 +130,216 @@ func TestHashMap(t *testing.T) { } } +func TestHashMap_Delete(t *testing.T) { + testcases := []struct { + name string + key testData + setHashMap func() map[uint64]*node[testData, int] + wantHashMap func() map[uint64]*node[testData, int] + wantVal int + isFound bool + }{ + { + name: "hash not found", + setHashMap: func() map[uint64]*node[testData, int] { + return map[uint64]*node[testData, int]{} + }, + isFound: false, + key: testData{ + id: 1, + }, + }, + { + name: "key not found", + setHashMap: func() map[uint64]*node[testData, int] { + return map[uint64]*node[testData, int]{ + 1: &node[testData, int]{ + key: testData{ + id: 1, + }, + value: 1, + }, + } + }, + isFound: false, + key: testData{ + id: 11, + }, + }, + { + name: "many link elements delete first", + setHashMap: func() map[uint64]*node[testData, int] { + return map[uint64]*node[testData, int]{ + 1: &node[testData, int]{ + key: testData{ + id: 1, + }, + value: 1, + next: &node[testData, int]{ + key: testData{ + id: 11, + }, + value: 11, + next: &node[testData, int]{ + key: testData{ + id: 21, + }, + value: 21, + }, + }, + }, + } + }, + isFound: true, + key: testData{ + id: 1, + }, + wantVal: 1, + wantHashMap: func() map[uint64]*node[testData, int] { + return map[uint64]*node[testData, int]{ + 1: &node[testData, int]{ + key: testData{ + id: 11, + }, + value: 11, + next: &node[testData, int]{ + key: testData{ + id: 21, + }, + value: 21, + }, + }, + } + }, + }, + { + name: "delete only one link element", + key: testData{ + id: 1, + }, + setHashMap: func() map[uint64]*node[testData, int] { + return map[uint64]*node[testData, int]{ + 1: &node[testData, int]{ + key: testData{ + id: 1, + }, + value: 1, + }, + } + }, + wantHashMap: func() map[uint64]*node[testData, int] { + return map[uint64]*node[testData, int]{} + }, + wantVal: 1, + isFound: true, + }, + { + name: "many link elements delete middle", + key: testData{ + id: 11, + }, + setHashMap: func() map[uint64]*node[testData, int] { + return map[uint64]*node[testData, int]{ + 1: &node[testData, int]{ + key: testData{ + id: 1, + }, + value: 1, + next: &node[testData, int]{ + key: testData{ + id: 11, + }, + value: 11, + next: &node[testData, int]{ + key: testData{ + id: 21, + }, + value: 21, + }, + }, + }, + } + }, + wantHashMap: func() map[uint64]*node[testData, int] { + return map[uint64]*node[testData, int]{ + 1: &node[testData, int]{ + key: testData{ + id: 1, + }, + value: 1, + next: &node[testData, int]{ + key: testData{ + id: 21, + }, + value: 21, + }, + }, + } + }, + isFound: true, + wantVal: 11, + }, + { + name: "many link elements delete end", + key: testData{ + id: 21, + }, + setHashMap: func() map[uint64]*node[testData, int] { + return map[uint64]*node[testData, int]{ + 1: &node[testData, int]{ + key: testData{ + id: 1, + }, + value: 1, + next: &node[testData, int]{ + key: testData{ + id: 11, + }, + value: 11, + next: &node[testData, int]{ + key: testData{ + id: 21, + }, + value: 21, + }, + }, + }, + } + }, + wantHashMap: func() map[uint64]*node[testData, int] { + return map[uint64]*node[testData, int]{ + 1: &node[testData, int]{ + key: testData{ + id: 1, + }, + value: 1, + next: &node[testData, int]{ + key: testData{ + id: 11, + }, + value: 11, + }, + }, + } + }, + isFound: true, + wantVal: 21, + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + h := NewHashMap[testData, int](10) + h.hashmap = tc.setHashMap() + val, ok := h.Delete(tc.key) + assert.Equal(t, tc.isFound, ok) + if !ok { + return + } + assert.Equal(t, tc.wantVal, val) + assert.Equal(t, tc.wantHashMap(), h.hashmap) + }) + } +} type testData struct { id int From 39c238f47ef5ba2e990ee19cf2ed73ea82cca078 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A6=E4=BD=B3=E6=A0=8B?= <353470469@qq.com> Date: Tue, 10 Jan 2023 11:38:02 +0800 Subject: [PATCH 6/9] =?UTF-8?q?ekit:=20AnyValue=20=E5=A2=9E=E5=8A=A0bool?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E6=94=AF=E6=8C=81=20(#135)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 1 + value.go | 62 +++++++++++++++++++++++++------------ value_test.go | 86 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 20 deletions(-) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 7be3272a..089d180d 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -2,6 +2,7 @@ - [queue: 基于semaphore的并发阻塞队列实现](https://github.com/gotomicro/ekit/pull/129) - [mapx: hashmap实现](https://github.com/gotomicro/ekit/pull/132) - [mapx: 添加 Keys 方法](https://github.com/gotomicro/ekit/pull/134) +- [ekit: 修改代码风格,增加bool类型支持](https://github.com/gotomicro/ekit/pull/135) - [mapx: hashmap添加刪除功能](https://github.com/gotomicro/ekit/pull/138) # v0.0.5 diff --git a/value.go b/value.go index e21c6c96..1022e73f 100644 --- a/value.go +++ b/value.go @@ -20,6 +20,7 @@ import ( "github.com/gotomicro/ekit/internal/errs" ) +// AnyValue 类型转换结构定义 type AnyValue struct { Val any Err error @@ -38,8 +39,8 @@ func (av AnyValue) Int() (int, error) { } // IntOrDefault 返回 int 数据,或者默认值 -func (a AnyValue) IntOrDefault(def int) int { - val, err := a.Int() +func (av AnyValue) IntOrDefault(def int) int { + val, err := av.Int() if err != nil { return def } @@ -59,8 +60,8 @@ func (av AnyValue) Uint() (uint, error) { } // UintOrDefault 返回 uint 数据,或者默认值 -func (a AnyValue) UintOrDefault(def uint) uint { - val, err := a.Uint() +func (av AnyValue) UintOrDefault(def uint) uint { + val, err := av.Uint() if err != nil { return def } @@ -80,8 +81,8 @@ func (av AnyValue) Int32() (int32, error) { } // Int32OrDefault 返回 int32 数据,或者默认值 -func (a AnyValue) Int32OrDefault(def int32) int32 { - val, err := a.Int32() +func (av AnyValue) Int32OrDefault(def int32) int32 { + val, err := av.Int32() if err != nil { return def } @@ -101,8 +102,8 @@ func (av AnyValue) Uint32() (uint32, error) { } // Uint32OrDefault 返回 uint32 数据,或者默认值 -func (a AnyValue) Uint32OrDefault(def uint32) uint32 { - val, err := a.Uint32() +func (av AnyValue) Uint32OrDefault(def uint32) uint32 { + val, err := av.Uint32() if err != nil { return def } @@ -122,8 +123,8 @@ func (av AnyValue) Int64() (int64, error) { } // Int64OrDefault 返回 int64 数据,或者默认值 -func (a AnyValue) Int64OrDefault(def int64) int64 { - val, err := a.Int64() +func (av AnyValue) Int64OrDefault(def int64) int64 { + val, err := av.Int64() if err != nil { return def } @@ -143,8 +144,8 @@ func (av AnyValue) Uint64() (uint64, error) { } // Uint64OrDefault 返回 uint64 数据,或者默认值 -func (a AnyValue) Uint64OrDefault(def uint64) uint64 { - val, err := a.Uint64() +func (av AnyValue) Uint64OrDefault(def uint64) uint64 { + val, err := av.Uint64() if err != nil { return def } @@ -164,8 +165,8 @@ func (av AnyValue) Float32() (float32, error) { } // Float32OrDefault 返回 float32 数据,或者默认值 -func (a AnyValue) Float32OrDefault(def float32) float32 { - val, err := a.Float32() +func (av AnyValue) Float32OrDefault(def float32) float32 { + val, err := av.Float32() if err != nil { return def } @@ -185,8 +186,8 @@ func (av AnyValue) Float64() (float64, error) { } // Float64OrDefault 返回 float64 数据,或者默认值 -func (a AnyValue) Float64OrDefault(def float64) float64 { - val, err := a.Float64() +func (av AnyValue) Float64OrDefault(def float64) float64 { + val, err := av.Float64() if err != nil { return def } @@ -206,8 +207,8 @@ func (av AnyValue) String() (string, error) { } // StringOrDefault 返回 string 数据,或者默认值 -func (a AnyValue) StringOrDefault(def string) string { - val, err := a.String() +func (av AnyValue) StringOrDefault(def string) string { + val, err := av.String() if err != nil { return def } @@ -227,8 +228,29 @@ func (av AnyValue) Bytes() ([]byte, error) { } // BytesOrDefault 返回 []byte 数据,或者默认值 -func (a AnyValue) BytesOrDefault(def []byte) []byte { - val, err := a.Bytes() +func (av AnyValue) BytesOrDefault(def []byte) []byte { + val, err := av.Bytes() + if err != nil { + return def + } + return val +} + +// Bool 返回 bool 数据 +func (av AnyValue) Bool() (bool, error) { + if av.Err != nil { + return false, av.Err + } + val, ok := av.Val.(bool) + if !ok { + return false, errs.NewErrInvalidType("bool", reflect.TypeOf(av.Val).String()) + } + return val, nil +} + +// BoolOrDefault 返回 bool 数据,或者默认值 +func (av AnyValue) BoolOrDefault(def bool) bool { + val, err := av.Bool() if err != nil { return def } diff --git a/value_test.go b/value_test.go index b7788220..e1bad249 100644 --- a/value_test.go +++ b/value_test.go @@ -884,3 +884,89 @@ func TestAnyValue_BytesOrDefault(t *testing.T) { }) } } + +func TestAnyValue_Bool(t *testing.T) { + tests := []struct { + name string + val AnyValue + want bool + err error + }{ + { + name: "normal case:", + val: AnyValue{ + Val: true, + }, + want: true, + err: nil, + }, + { + name: "error case:", + val: AnyValue{ + Err: errors.New("error"), + }, + err: errors.New("error"), + }, + { + name: "type error case:", + val: AnyValue{ + Val: 1, + }, + err: errs.NewErrInvalidType("bool", reflect.TypeOf(1).String()), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + got, err := av.Bool() + assert.Equal(t, err, tt.err) + assert.Equal(t, got, tt.want) + }) + } +} + +func TestAnyValue_BoolOrDefault(t *testing.T) { + tests := []struct { + name string + val AnyValue + def bool + want bool + }{ + { + name: "normal case:", + val: AnyValue{ + Val: true, + }, + want: true, + }, + { + name: "default case:", + val: AnyValue{ + Val: true, + Err: errors.New("error"), + }, + def: false, + want: false, + }, + { + name: "type error case:", + val: AnyValue{ + Val: 1, + }, + def: true, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + av := AnyValue{ + Val: tt.val.Val, + Err: tt.val.Err, + } + assert.Equal(t, av.BoolOrDefault(tt.def), tt.want) + }) + } +} From 276e47b18921489890eed4c9a71cf21ad724a24f Mon Sep 17 00:00:00 2001 From: heroyf <13788978852@163.com> Date: Tue, 10 Jan 2023 15:05:25 +0800 Subject: [PATCH 7/9] =?UTF-8?q?mapx:=20HashMap=20=E5=A2=9E=E5=8A=A0=20Keys?= =?UTF-8?q?=20=E5=92=8C=20Values=20=E6=96=B9=E6=B3=95=20(#141)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 1 + mapx/hashmap.go | 36 ++++++++++-- mapx/hashmap_test.go | 130 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 159 insertions(+), 8 deletions(-) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 089d180d..95fd24ee 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -4,6 +4,7 @@ - [mapx: 添加 Keys 方法](https://github.com/gotomicro/ekit/pull/134) - [ekit: 修改代码风格,增加bool类型支持](https://github.com/gotomicro/ekit/pull/135) - [mapx: hashmap添加刪除功能](https://github.com/gotomicro/ekit/pull/138) +- [mapx: HashMap 增加 Keys 和 Values 方法](https://github.com/gotomicro/ekit/pull/141) # v0.0.5 - [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101) diff --git a/mapx/hashmap.go b/mapx/hashmap.go index e6fa3799..556af16b 100644 --- a/mapx/hashmap.go +++ b/mapx/hashmap.go @@ -44,8 +44,8 @@ func (m *HashMap[T, ValType]) Put(key T, val ValType) error { root, ok := m.hashmap[hash] if !ok { hash = key.Code() - new_node := m.newNode(key, val) - m.hashmap[hash] = new_node + newNode := m.newNode(key, val) + m.hashmap[hash] = newNode return nil } pre := root @@ -57,8 +57,8 @@ func (m *HashMap[T, ValType]) Put(key T, val ValType) error { pre = root root = root.next } - new_node := m.newNode(key, val) - pre.next = new_node + newNode := m.newNode(key, val) + pre.next = newNode return nil } @@ -78,6 +78,34 @@ func (m *HashMap[T, ValType]) Get(key T) (ValType, bool) { return val, false } +// Keys 返回 Hashmap 里面的所有的 key。 +// 注意:key 的顺序是随机的。 +func (m *HashMap[T, ValType]) Keys() []Hashable { + res := make([]Hashable, 0) + for _, bucketNode := range m.hashmap { + curNode := bucketNode + for curNode != nil { + res = append(res, curNode.key) + curNode = curNode.next + } + } + return res +} + +// Values 返回 Hashmap 里面的所有的 value。 +// 注意:value 的顺序是随机的。 +func (m *HashMap[T, ValType]) Values() []ValType { + res := make([]ValType, 0) + for _, bucketNode := range m.hashmap { + curNode := bucketNode + for curNode != nil { + res = append(res, curNode.value) + curNode = curNode.next + } + } + return res +} + func NewHashMap[T Hashable, ValType any](size int) *HashMap[T, ValType] { return &HashMap[T, ValType]{ nodePool: syncx.NewPool[*node[T, ValType]](func() *node[T, ValType] { diff --git a/mapx/hashmap_test.go b/mapx/hashmap_test.go index 2a86cdeb..26cd80f5 100644 --- a/mapx/hashmap_test.go +++ b/mapx/hashmap_test.go @@ -341,6 +341,131 @@ func TestHashMap_Delete(t *testing.T) { } } +func TestHashMap_Keys_Values(t *testing.T) { + testCases := []struct { + name string + genHashMap func() *HashMap[testData, int] + wantKeys []Hashable + wantValues []int + }{ + { + name: "empty", + genHashMap: func() *HashMap[testData, int] { + return &HashMap[testData, int]{} + }, + wantKeys: []Hashable{}, + wantValues: []int{}, + }, + { + name: "size is zero empty", + genHashMap: func() *HashMap[testData, int] { + return NewHashMap[testData, int](0) + }, + wantKeys: []Hashable{}, + wantValues: []int{}, + }, + { + name: "single key", + genHashMap: func() *HashMap[testData, int] { + testHashMap := NewHashMap[testData, int](10) + err := testHashMap.Put(newTestData(1), 1) + require.NoError(t, err) + return testHashMap + }, + wantKeys: []Hashable{newTestData(1)}, + wantValues: []int{1}, + }, + { + name: "multiple keys", + genHashMap: func() *HashMap[testData, int] { + testHashMap := NewHashMap[testData, int](10) + for _, val := range []int{1, 2} { + err := testHashMap.Put(newTestData(val), val) + require.NoError(t, err) + } + return testHashMap + }, + wantKeys: []Hashable{newTestData(1), newTestData(2)}, + wantValues: []int{1, 2}, + }, + { + name: "same key", + genHashMap: func() *HashMap[testData, int] { + testHashMap := NewHashMap[testData, int](10) + err := testHashMap.Put(newTestData(1), 1) + require.NoError(t, err) + // 验证id相同,覆盖的场景 + err = testHashMap.Put(newTestData(1), 11) + require.NoError(t, err) + return testHashMap + }, + wantKeys: []Hashable{newTestData(1)}, + wantValues: []int{11}, + }, + { + name: "multi with same key", + genHashMap: func() *HashMap[testData, int] { + testHashMap := NewHashMap[testData, int](10) + for _, val := range []int{1, 2} { + // val为1、2 + err := testHashMap.Put(newTestData(val), val*10) + require.NoError(t, err) + } + err := testHashMap.Put(newTestData(1), 11) + require.NoError(t, err) + return testHashMap + }, + wantKeys: []Hashable{newTestData(1), newTestData(2)}, + wantValues: []int{11, 20}, + }, + { + name: "single key collision", + genHashMap: func() *HashMap[testData, int] { + testHashMap := NewHashMap[testData, int](10) + err := testHashMap.Put(newTestData(1), 11) + require.NoError(t, err) + // 验证id不同,但是code一致,进入同一个bucket中,会取出bucket中所有的value + err = testHashMap.Put(newTestData(11), 111) + require.NoError(t, err) + err = testHashMap.Put(newTestData(111), 1111) + require.NoError(t, err) + return testHashMap + }, + wantKeys: []Hashable{newTestData(1), newTestData(11), newTestData(111)}, + wantValues: []int{11, 1111, 111}, + }, + { + name: "multiple keys collision", + genHashMap: func() *HashMap[testData, int] { + testHashMap := NewHashMap[testData, int](10) + for _, val := range []int{1, 2} { + err := testHashMap.Put(newTestData(val), val) + require.NoError(t, err) + err = testHashMap.Put(newTestData(val*10+val), val*10) + require.NoError(t, err) + } + return testHashMap + }, + wantKeys: []Hashable{newTestData(1), newTestData(11), newTestData(22), newTestData(2)}, + wantValues: []int{1, 10, 2, 20}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualKeys := tc.genHashMap().Keys() + actualValues := tc.genHashMap().Values() + // 断言key的数量一致 + assert.Equal(t, len(tc.wantKeys), len(actualKeys)) + // 断言value的数量一致 + assert.Equal(t, len(tc.wantValues), len(actualValues)) + // 断言keys的元素一致 + assert.ElementsMatch(t, tc.wantKeys, actualKeys) + // 断言value的元素一致 + assert.ElementsMatch(t, tc.wantValues, actualValues) + }) + } +} + type testData struct { id int } @@ -355,10 +480,7 @@ func (t testData) Equals(key any) bool { if !ok { return false } - if t.id != val.id { - return false - } - return true + return t.id == val.id } func newTestData(id int) testData { From 8160bae2ada5cfef556505767f55dd5ea32ff38a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8ZC?= <120188372+yzicheng@users.noreply.github.com> Date: Tue, 17 Jan 2023 19:46:27 +0800 Subject: [PATCH 8/9] =?UTF-8?q?=E5=9F=BA=E4=BA=8E=E7=BA=A2=E9=BB=91?= =?UTF-8?q?=E6=A0=91=E7=9A=84=20TreeMap=20=E5=AE=9E=E7=8E=B0=20(#142)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 1 + internal/tree/red_black_tree.go | 511 ++++++++++ internal/tree/red_black_tree_test.go | 1400 ++++++++++++++++++++++++++ mapx/treemap.go | 85 ++ mapx/treemap_test.go | 383 +++++++ 5 files changed, 2380 insertions(+) create mode 100644 internal/tree/red_black_tree.go create mode 100644 internal/tree/red_black_tree_test.go create mode 100644 mapx/treemap.go create mode 100644 mapx/treemap_test.go diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 95fd24ee..0f66fa12 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -5,6 +5,7 @@ - [ekit: 修改代码风格,增加bool类型支持](https://github.com/gotomicro/ekit/pull/135) - [mapx: hashmap添加刪除功能](https://github.com/gotomicro/ekit/pull/138) - [mapx: HashMap 增加 Keys 和 Values 方法](https://github.com/gotomicro/ekit/pull/141) +- [mapx: TreeMap](https://github.com/gotomicro/ekit/pull/142) # v0.0.5 - [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101) diff --git a/internal/tree/red_black_tree.go b/internal/tree/red_black_tree.go new file mode 100644 index 00000000..45517f74 --- /dev/null +++ b/internal/tree/red_black_tree.go @@ -0,0 +1,511 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tree + +import ( + "errors" + + "github.com/gotomicro/ekit" +) + +type color bool + +const ( + Red color = false + Black color = true +) + +var ( + ErrRBTreeSameRBNode = errors.New("ekit: RBTree不能添加重复节点Key") + ErrRBTreeNotRBNode = errors.New("ekit: RBTree不存在节点Key") + // errRBTreeCantRepaceNil = errors.New("ekit: RBTree不能将节点替换为nil") +) + +type RBTree[K any, V any] struct { + root *rbNode[K, V] + compare ekit.Comparator[K] + size int +} + +func (rb *RBTree[K, V]) Size() int { + if rb == nil { + return 0 + } + return rb.size +} + +type rbNode[K any, V any] struct { + color color + key K + value V + left, right, parent *rbNode[K, V] +} + +func (node *rbNode[K, V]) setNode(v V) { + if node == nil { + return + } + node.value = v +} + +// NewRBTree 构建红黑树 +func NewRBTree[K any, V any](compare ekit.Comparator[K]) *RBTree[K, V] { + return &RBTree[K, V]{ + compare: compare, + root: nil, + } +} + +func newRBNode[K any, V any](key K, value V) *rbNode[K, V] { + return &rbNode[K, V]{ + key: key, + value: value, + color: Red, + left: nil, + right: nil, + parent: nil, + } +} + +// Add 增加节点 +func (rb *RBTree[K, V]) Add(key K, value V) error { + return rb.addNode(newRBNode(key, value)) +} + +// Delete 删除节点 +func (rb *RBTree[K, V]) Delete(key K) { + if node := rb.findNode(key); node != nil { + rb.deleteNode(node) + } +} + +// Find 查找节点 +func (rb *RBTree[K, V]) Find(key K) (V, error) { + var v V + if node := rb.findNode(key); node != nil { + return node.value, nil + } + return v, ErrRBTreeNotRBNode +} +func (rb *RBTree[K, V]) Set(key K, value V) error { + if node := rb.findNode(key); node != nil { + node.setNode(value) + return nil + } + return ErrRBTreeNotRBNode +} + +// addNode 插入新节点 +func (rb *RBTree[K, V]) addNode(node *rbNode[K, V]) error { + var fixNode *rbNode[K, V] + if rb.root == nil { + rb.root = newRBNode[K, V](node.key, node.value) + fixNode = rb.root + } else { + t := rb.root + cmp := 0 + parent := &rbNode[K, V]{} + for t != nil { + parent = t + cmp = rb.compare(node.key, t.key) + if cmp < 0 { + t = t.left + } else if cmp > 0 { + t = t.right + } else if cmp == 0 { + return ErrRBTreeSameRBNode + } + } + fixNode = &rbNode[K, V]{ + key: node.key, + parent: parent, + value: node.value, + color: Red, + } + if cmp < 0 { + parent.left = fixNode + } else { + parent.right = fixNode + } + } + rb.size++ + rb.fixAfterAdd(fixNode) + return nil +} + +// deleteNode 红黑树删除方法 +// 删除分两步,第一步取出后继节点,第二部着色旋转 +// 取后继节点 +// case1:node左右非空子节点,通过getSuccessor获取后继节点 +// case2:node左右只有一个非空子节点 +// case3:node左右均为空节点 +// 着色旋转 +// case1:当删除节点非空且为黑色时,会违反红黑树任何路径黑节点个数相同的约束,所以需要重新平衡 +// case2:当删除红色节点时,不会破坏任何约束,所以不需要平衡 +func (rb *RBTree[K, V]) deleteNode(node *rbNode[K, V]) { + // node左右非空,取后继节点 + if node.left != nil && node.right != nil { + s := rb.findSuccessor(node) + node.key = s.key + node.value = s.value + node = s + } + var replacement *rbNode[K, V] + // node节点只有一个非空子节点 + if node.left != nil { + replacement = node.left + } else { + replacement = node.right + } + if replacement != nil { + replacement.parent = node.parent + if node.parent == nil { + rb.root = replacement + } else if node == node.parent.left { + node.parent.left = replacement + } else { + node.parent.right = replacement + } + node.left = nil + node.right = nil + node.parent = nil + if node.getColor() { + rb.fixAfterDelete(replacement) + } + } else if node.parent == nil { + // 如果node节点无父节点,说明node为root节点 + rb.root = nil + } else { + // node子节点均为空 + if node.getColor() { + rb.fixAfterDelete(node) + } + if node.parent != nil { + if node == node.parent.left { + node.parent.left = nil + } else if node == node.parent.right { + node.parent.right = nil + } + node.parent = nil + } + } + rb.size-- +} + +// findSuccessor 寻找后继节点 +// case1: node节点存在右子节点,则右子树的最小节点是node的后继节点 +// case2: node节点不存在右子节点,则其第一个为左节点的祖先的父节点为node的后继节点 +func (rb *RBTree[K, V]) findSuccessor(node *rbNode[K, V]) *rbNode[K, V] { + if node == nil { + return nil + } else if node.right != nil { + p := node.right + for p.left != nil { + p = p.left + } + return p + } else { + p := node.parent + ch := node + for p != nil && ch == p.right { + ch = p + p = p.parent + } + return p + } + +} + +func (rb *RBTree[K, V]) findNode(key K) *rbNode[K, V] { + node := rb.root + for node != nil { + cmp := rb.compare(key, node.key) + if cmp < 0 { + node = node.left + } else if cmp > 0 { + node = node.right + } else { + return node + } + } + return nil +} + +// fixAfterAdd 插入时着色旋转 +// 如果是空节点、root节点、父节点是黑无需构建 +// 可分为3种情况 +// fixUncleRed 叔叔节点是红色右节点 +// fixAddLeftBlack 叔叔节点是黑色右节点 +// fixAddRightBlack 叔叔节点是黑色左节点 +func (rb *RBTree[K, V]) fixAfterAdd(x *rbNode[K, V]) { + x.color = Red + for x != nil && x != rb.root && x.getParent().getColor() == Red { + uncle := x.getUncle() + if uncle.getColor() == Red { + x = rb.fixUncleRed(x, uncle) + continue + } + if x.getParent() == x.getGrandParent().getLeft() { + x = rb.fixAddLeftBlack(x) + continue + } + x = rb.fixAddRightBlack(x) + } + rb.root.setColor(Black) +} + +// fixAddLeftRed 叔叔节点是红色右节点,由于不能存在连续红色节点,此时祖父节点x.getParent().getParent()必为黑。另x为红所以叔父节点需要变黑,祖父变红,此时红黑树完成 +// +// b(b) b(r) +// / \ / \ +// a(r) y(r) -> a(b) y(b) +// / \ / \ / \ / \ +// x(r) nil nil nil x (r) nil nil nil +// / \ / \ +// nil nil nil nil +func (rb *RBTree[K, V]) fixUncleRed(x *rbNode[K, V], y *rbNode[K, V]) *rbNode[K, V] { + x.getParent().setColor(Black) + y.setColor(Black) + x.getGrandParent().setColor(Red) + x = x.getGrandParent() + return x +} + +// fixAddLeftBlack 叔叔节点是黑色右节点.x节点是父节点左节点,执行左旋,此时x节点变为原x节点的父节点a,也就是左子节点。的接着将x的父节点和爷爷节点的颜色对换。然后对爷爷节点进行右旋转,此时红黑树完成 +// 如果x为左节点则跳过左旋操作 +// +// b(b) b(b) b(r) +// / \ / \ / \ +// a(r) y(b) -> a(r) y(b) -> a(b) y(b) +// / \ / \ / \ / \ / \ / \ +// nil x (r) nil nil x(r) nil nil nil x(r) nil nil nil +// / \ / \ / \ +// nil nil nil nil nil nil +func (rb *RBTree[K, V]) fixAddLeftBlack(x *rbNode[K, V]) *rbNode[K, V] { + if x == x.getParent().getRight() { + x = x.getParent() + rb.rotateLeft(x) + } + x.getParent().setColor(Black) + x.getGrandParent().setColor(Red) + rb.rotateRight(x.getGrandParent()) + return x +} + +// fixAddRightBlack 叔叔节点是黑色左节点.x节点是父节点右节点,执行右旋,此时x节点变为原x节点的父节点a,也就是右子节点。接着将x的父节点和爷爷节点的颜色对换。然后对爷爷节点进行右旋转,此时红黑树完成 +// 如果x为右节点则跳过右旋操作 +// +// b(b) b(b) b(r) +// / \ / \ / \ +// y(b) a(r) -> y(b) a(r) -> y(b) a(b) +// / \ / \ / \ / \ / \ / \ +// nil nil x(r) nil nil nil nil x(r) nil nil nil x(r) +// / \ / \ / \ +// nil nil nil nil nil nil +func (rb *RBTree[K, V]) fixAddRightBlack(x *rbNode[K, V]) *rbNode[K, V] { + if x == x.getParent().getLeft() { + x = x.getParent() + rb.rotateRight(x) + } + x.getParent().setColor(Black) + x.getGrandParent().setColor(Red) + rb.rotateLeft(x.getGrandParent()) + return x +} + +// fixAfterDelete 删除时着色旋转 +// 根据x是节点位置分为fixAfterDeleteLeft,fixAfterDeleteRight两种情况 +func (rb *RBTree[K, V]) fixAfterDelete(x *rbNode[K, V]) { + for x != rb.root && x.getColor() == Black { + if x == x.parent.getLeft() { + x = rb.fixAfterDeleteLeft(x) + } else { + x = rb.fixAfterDeleteRight(x) + } + } + x.setColor(Black) +} + +// fixAfterDeleteLeft 处理x为左子节点时的平衡处理 +func (rb *RBTree[K, V]) fixAfterDeleteLeft(x *rbNode[K, V]) *rbNode[K, V] { + sib := x.getParent().getRight() + if sib.getColor() == Red { + sib.setColor(Black) + sib.getParent().setColor(Red) + rb.rotateLeft(x.getParent()) + sib = x.getParent().getRight() + } + if sib.getLeft().getColor() == Black && sib.getRight().getColor() == Black { + sib.setColor(Red) + x = x.getParent() + } else { + if sib.getRight().getColor() == Black { + sib.getLeft().setColor(Black) + sib.setColor(Red) + rb.rotateRight(sib) + sib = x.getParent().getRight() + } + sib.setColor(x.getParent().getColor()) + x.getParent().setColor(Black) + sib.getRight().setColor(Black) + rb.rotateLeft(x.getParent()) + x = rb.root + } + return x +} + +// fixAfterDeleteRight 处理x为右子节点时的平衡处理 +func (rb *RBTree[K, V]) fixAfterDeleteRight(x *rbNode[K, V]) *rbNode[K, V] { + sib := x.getParent().getLeft() + if sib.getColor() == Red { + sib.setColor(Black) + x.getParent().setColor(Red) + rb.rotateRight(x.getParent()) + sib = x.getBrother() + } + if sib.getRight().getColor() == Black && sib.getLeft().getColor() == Black { + sib.setColor(Red) + x = x.getParent() + } else { + if sib.getLeft().getColor() == Black { + sib.getRight().setColor(Black) + sib.setColor(Red) + rb.rotateLeft(sib) + sib = x.getParent().getLeft() + } + sib.setColor(x.getParent().getColor()) + x.getParent().setColor(Black) + sib.getLeft().setColor(Black) + rb.rotateRight(x.getParent()) + x = rb.root + } + return x +} + +// rotateLeft 左旋转 +// +// b a +// / \ / \ +// c a -> b y +// / \ / \ +// x y c x + +func (rb *RBTree[K, V]) rotateLeft(node *rbNode[K, V]) { + if node == nil || node.getRight() == nil { + return + } + r := node.right + node.right = r.left + if r.left != nil { + r.left.parent = node + } + r.parent = node.parent + if node.parent == nil { + rb.root = r + } else if node.parent.left == node { + node.parent.left = r + } else { + node.parent.right = r + } + r.left = node + node.parent = r + +} + +// rotateRight 右旋转 +// +// b c +// / \ / \ +// c a -> x b +// / \ / \ +// x y y a +func (rb *RBTree[K, V]) rotateRight(node *rbNode[K, V]) { + if node == nil || node.getLeft() == nil { + return + } + l := node.left + node.left = l.right + if l.right != nil { + l.right.parent = node + } + l.parent = node.parent + if node.parent == nil { + rb.root = l + } else if node.parent.right == node { + node.parent.right = l + } else { + node.parent.left = l + } + l.right = node + node.parent = l + +} + +func (node *rbNode[K, V]) getColor() color { + if node == nil { + return Black + } + return node.color +} + +func (node *rbNode[K, V]) setColor(color color) { + if node == nil { + return + } + node.color = color +} + +func (node *rbNode[K, V]) getParent() *rbNode[K, V] { + if node == nil { + return nil + } + return node.parent +} + +func (node *rbNode[K, V]) getLeft() *rbNode[K, V] { + if node == nil { + return nil + } + return node.left +} + +func (node *rbNode[K, V]) getRight() *rbNode[K, V] { + if node == nil { + return nil + } + return node.right +} + +func (node *rbNode[K, V]) getUncle() *rbNode[K, V] { + if node == nil { + return nil + } + return node.getParent().getBrother() +} +func (node *rbNode[K, V]) getGrandParent() *rbNode[K, V] { + if node == nil { + return nil + } + return node.getParent().getParent() +} +func (node *rbNode[K, V]) getBrother() *rbNode[K, V] { + if node == nil { + return nil + } + if node == node.getParent().getLeft() { + return node.getParent().getRight() + } + return node.getParent().getLeft() +} diff --git a/internal/tree/red_black_tree_test.go b/internal/tree/red_black_tree_test.go new file mode 100644 index 00000000..96419f47 --- /dev/null +++ b/internal/tree/red_black_tree_test.go @@ -0,0 +1,1400 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tree + +import ( + "errors" + "testing" + + "github.com/gotomicro/ekit" + "github.com/stretchr/testify/assert" +) + +func TestNewRBTree(t *testing.T) { + tests := []struct { + name string + compare ekit.Comparator[int] + wantV bool + }{ + { + name: "int", + compare: compare(), + wantV: true, + }, + { + name: "nil", + compare: nil, + wantV: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redBlackTree := NewRBTree[int, string](compare()) + assert.Equal(t, tt.wantV, IsRedBlackTree[int, string](redBlackTree.root)) + }) + } +} + +func compare() ekit.Comparator[int] { + return ekit.ComparatorRealNumber[int] +} + +func TestRBTree_Add(t *testing.T) { + IsRedBlackTreeCase := []struct { + name string + node *rbNode[int, string] + want bool + }{ + { + name: "nil", + node: nil, + want: true, + }, + { + name: "node-nil", + node: nil, + want: true, + }, + { + name: "root", + node: &rbNode[int, string]{left: nil, right: nil, color: Black}, + want: true, + }, + { + name: "root", + node: &rbNode[int, string]{left: nil, right: nil, color: Red}, + want: false, + }, + // root(b) + // / + // a(b) + { + name: "root-oneChild", + node: &rbNode[int, string]{ + left: &rbNode[int, string]{ + right: nil, + left: nil, + color: Red, + }, + right: nil, + color: Black, + }, + want: true, + }, + // root(b) + // / \ + // a(r) b(b) + { + name: "root-twoChild", + node: &rbNode[int, string]{ + left: &rbNode[int, string]{ + right: nil, + left: nil, + color: Red, + }, + right: &rbNode[int, string]{ + right: nil, + left: nil, + color: Black, + }, + color: Black, + }, + want: false, + }, + // root(b) + // / \ + // a(b) b(b) + // / \ / \ + // nil c(r) d(r) nil + // / \ / \ + // nil nil nil nil + { + name: "blackNodeNotSame", + node: &rbNode[int, string]{ + left: &rbNode[int, string]{ + right: &rbNode[int, string]{ + right: nil, + left: nil, + color: Red, + }, + left: nil, + color: Black, + }, + right: &rbNode[int, string]{ + right: nil, + left: &rbNode[int, string]{ + right: nil, + left: nil, + color: Red, + }, + color: Black, + }, + color: Black, + }, + want: true, + }, + { + name: "root-grandson", + node: &rbNode[int, string]{ + parent: nil, + key: 7, + left: &rbNode[int, string]{ + key: 5, + color: Black, + left: &rbNode[int, string]{ + key: 4, + color: Red, + }, + right: &rbNode[int, string]{ + key: 6, + color: Red, + }, + }, + right: &rbNode[int, string]{ + key: 10, + color: Red, + left: &rbNode[int, string]{ + key: 9, + color: Black, + left: &rbNode[int, string]{ + key: 8, + color: Red, + }, + }, + right: &rbNode[int, string]{ + key: 12, + color: Black, + left: &rbNode[int, string]{ + key: 11, + color: Red, + }, + }, + }, + color: Black, + }, + want: true, + }, + } + for _, tt := range IsRedBlackTreeCase { + t.Run(tt.name, func(t *testing.T) { + res := IsRedBlackTree[int](tt.node) + assert.Equal(t, tt.want, res) + + }) + } + tests := []struct { + name string + k []int + want bool + wantErr error + size int + wantKey int + }{ + { + name: "nil", + k: nil, + want: true, + size: 0, + }, + { + name: "one", + k: []int{1}, + want: true, + size: 1, + }, + { + name: "one", + k: []int{1, 2}, + want: true, + size: 2, + wantKey: 1, + }, + { + name: "normal", + k: []int{1, 2, 3, 4}, + want: true, + size: 4, + wantKey: 3, + }, + { + name: "same", + k: []int{0, 0, 1, 2, 2, 3}, + want: true, + size: 0, + wantErr: errors.New("ekit: RBTree不能添加重复节点Key"), + }, + { + name: "disorder", + k: []int{1, 2, 0, 3, 5, 4}, + want: true, + wantErr: nil, + size: 6, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redBlackTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.k); i++ { + err := redBlackTree.Add(tt.k[i], i) + if err != nil { + assert.Equal(t, tt.wantErr, err) + return + } + } + res := IsRedBlackTree[int, int](redBlackTree.root) + assert.Equal(t, tt.want, res) + assert.Equal(t, tt.size, redBlackTree.Size()) + }) + } +} + +func TestRBTree_Delete(t *testing.T) { + tcase := []struct { + name string + delKey int + key []int + want bool + size int + }{ + { + name: "nil", + delKey: 0, + key: nil, + want: true, + size: 0, + }, + { + name: "node-empty", + delKey: 0, + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + want: true, + size: 9, + }, + { + name: "左右非空子节点,删除节点为黑色", + delKey: 11, + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + want: true, + size: 8, + }, + { + name: "左右只有一个非空子节点,删除节点为黑色", + delKey: 11, + key: []int{4, 5, 6, 7, 8, 9, 11, 12}, + want: true, + size: 7, + }, + { + name: "左右均为空节点,删除节点为黑色", + delKey: 12, + key: []int{4, 5, 6, 7, 8, 9, 12}, + want: true, + size: 6, + }, { + name: "左右非空子节点,删除节点为红色", + delKey: 5, + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + want: true, + size: 8, + }, + // 此状态无法构造出正确的红黑树 + // { + // name: "左右只有一个非空子节点,删除节点为红色", + // delKey: 5, + // key: []int{4, 5, 6, 7, 8, 9, 11, 12}, + // want: true, + // }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + rbTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.key); i++ { + err := rbTree.Add(tt.key[i], i) + if err != nil { + panic(err) + } + } + assert.Equal(t, tt.want, IsRedBlackTree[int](rbTree.root)) + rbTree.Delete(tt.delKey) + assert.Equal(t, tt.want, IsRedBlackTree[int](rbTree.root)) + assert.Equal(t, tt.size, rbTree.Size()) + }) + } +} + +func TestRBTree_Find(t *testing.T) { + tcase := []struct { + name string + find int + k []int + wantKey int + wantError error + }{ + { + name: "nil", + find: 0, + k: nil, + wantError: errors.New("未找到0节点"), + }, + { + name: "node-empty", + find: 0, + k: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + wantError: errors.New("未找到0节点"), + }, + { + name: "find", + find: 11, + k: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + wantKey: 11, + }, { + name: "find", + find: 12, + k: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + wantKey: 12, + }, { + name: "find", + find: 7, + k: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + wantKey: 7, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + rbTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.k); i++ { + err := rbTree.Add(tt.k[i], tt.k[i]) + if err != nil { + panic(err) + } + } + assert.Equal(t, true, IsRedBlackTree[int](rbTree.root)) + findNode, err := rbTree.Find(tt.find) + if err != nil { + assert.Equal(t, tt.wantError, errors.New("未找到0节点")) + } else { + assert.Equal(t, tt.find, findNode) + } + }) + } +} + +func TestRBTree_addNode(t *testing.T) { + tests := []struct { + name string + k []int + want bool + wantErr error + }{ + { + name: "nil", + k: nil, + want: true, + }, + { + name: "case1", + k: []int{1, 2, 3, 4}, + want: true, + }, + { + name: "same", + k: []int{0, 0, 1, 2, 2, 3}, + want: true, + wantErr: errors.New("ekit: RBTree不能添加重复节点Key"), + }, + { + name: "disorder", + k: []int{1, 2, 0, 3, 5, 4}, + want: true, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redBlackTree := NewRBTree[int, string](compare()) + for i := 0; i < len(tt.k); i++ { + err := redBlackTree.addNode(&rbNode[int, string]{ + key: tt.k[i], + }) + if err != nil { + assert.Equal(t, tt.wantErr, err) + } + } + res := IsRedBlackTree[int](redBlackTree.root) + assert.Equal(t, tt.want, res) + + }) + } +} + +func TestRBTree_deleteNode(t *testing.T) { + tcase := []struct { + name string + delKey int + key []int + want bool + wantError error + }{ + { + name: "nil", + delKey: 0, + key: nil, + wantError: errors.New("未找到节点0"), + }, + { + name: "node-empty", + delKey: 0, + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + wantError: errors.New("未找到节点0"), + }, + { + name: "本身为右节点,左右非空子节点,删除节点为黑色", + delKey: 11, + key: []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + want: true, + }, + { + name: "本身为右节点,右只有一个非空子节点,删除节点为黑色", + delKey: 11, + key: []int{2, 3, 4, 5, 6, 7, 8, 9, 11, 12}, + want: true, + }, + { + name: "本身为右节点,左右均为空节点,删除节点为黑色", + delKey: 6, + key: []int{2, 3, 4, 5, 6, 7, 8, 9, 11, 12}, + want: true, + }, + { + name: "本身为左节点,左右非空子节点,删除节点为黑色", + delKey: 3, + key: []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + want: true, + }, + { + name: "本身为左节点,左右只有一个非空子节点,删除节点为黑色", + delKey: 3, + key: []int{2, 3, 5, 6, 7, 8, 9, 11, 12}, + want: true, + }, + { + name: "本身为左节点,左右均为空节点,删除节点为黑色", + delKey: 8, + key: []int{2, 3, 5, 6, 7, 8, 9, 11, 12}, + want: true, + }, + { + name: "本身是左节点,只有左边子节点,删除节点为黑色", + delKey: 3, + key: []int{5, 3, 4, 6, 2}, + want: true, + }, + // name: "本身为左节点,左右只有一个非空子节点,删除节点为红色(无法正确构造)" + { + name: "本身为左节点,左右均为空节点,删除节点为红色", + delKey: 2, + key: []int{2, 3, 5, 6, 7, 8, 9, 11, 12}, + want: true, + }, + { + name: "删除root节点", + delKey: 7, + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + want: true, + }, + { + name: "删除root节点", + delKey: 7, + key: []int{7}, + want: true, + }, + { + name: "root", + delKey: 2, + key: []int{2, 1}, + want: true, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + rbTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.key); i++ { + err := rbTree.Add(tt.key[i], i) + if err != nil { + panic(err) + } + } + delNode := rbTree.findNode(tt.delKey) + if delNode == nil { + assert.Equal(t, tt.wantError, errors.New("未找到节点0")) + } else { + rbTree.deleteNode(delNode) + assert.Equal(t, tt.want, IsRedBlackTree[int](rbTree.root)) + } + }) + } +} + +func TestRBTree_findNode(t *testing.T) { + tcase := []struct { + name string + findKey int + key []int + wantKey int + wantError error + }{ + { + name: "nil", + findKey: 0, + key: nil, + wantError: errors.New("未找到0节点"), + }, + { + name: "node-empty", + findKey: 0, + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + wantError: errors.New("未找到0节点"), + }, + { + name: "find", + findKey: 11, + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + wantKey: 11, + }, { + name: "find", + findKey: 12, + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + wantKey: 12, + }, { + name: "find", + findKey: 7, + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + wantKey: 7, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + rbTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.key); i++ { + err := rbTree.Add(tt.key[i], i) + if err != nil { + panic(err) + } + } + assert.Equal(t, true, IsRedBlackTree[int](rbTree.root)) + findNode := rbTree.findNode(tt.findKey) + if findNode == nil { + assert.Equal(t, tt.wantError, errors.New("未找到0节点")) + } else { + assert.Equal(t, tt.wantKey, findNode.key) + } + }) + } +} + +func TestRBTree_rotateLeft(t *testing.T) { + tcase := []struct { + name string + key []int + wantParent int + wantLeftChild int + wantRightChild int + rotaNode int + }{ + { + name: "only-root", + key: []int{1}, + wantParent: 1, + rotaNode: 1, + }, + { + name: "节点有2个子节点,并且本身是右节点", + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + rotaNode: 9, + wantParent: 11, + wantLeftChild: 8, + wantRightChild: 10, + }, + { + name: "节点有2个子节点,并且本身是左节点", + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + rotaNode: 5, + wantParent: 6, + wantLeftChild: 4, + }, + { + name: "节点有1个子节点", + key: []int{1, 2, 3, 4}, + rotaNode: 2, + wantParent: 3, + wantLeftChild: 1, + }, + { + name: "节点没有子节点", + key: []int{1, 2, 3}, + rotaNode: 2, + wantParent: 3, + wantLeftChild: 1, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + rbTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.key); i++ { + err := rbTree.Add(tt.key[i], i) + if err != nil { + panic(err) + } + } + rotaNode := rbTree.findNode(tt.rotaNode) + rbTree.rotateLeft(rotaNode) + if rotaNode.getParent() != nil { + assert.Equal(t, tt.wantParent, rotaNode.getParent().key) + if rotaNode.getLeft() != nil { + assert.Equal(t, tt.wantLeftChild, rotaNode.getLeft().key) + } + if rotaNode.getRight() != nil { + assert.Equal(t, tt.wantRightChild, rotaNode.getRight().key) + } + } + }) + } +} + +func TestRBTree_rotateRight(t *testing.T) { + tcase := []struct { + name string + key []int + wantParent int + wantLeftChild int + wantRightChild int + rotaNode int + }{ + { + name: "only-root", + key: []int{1}, + wantParent: 1, + rotaNode: 1, + }, + { + name: "节点有2个子节点,并且本身是右节点", + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + rotaNode: 9, + wantParent: 8, + wantRightChild: 11, + }, + { + name: "节点有2个子节点,并且本身是左节点", + key: []int{4, 5, 6, 7, 8, 9, 10, 11, 12}, + rotaNode: 5, + wantParent: 4, + wantRightChild: 6, + }, + { + name: "有一个子节点", + key: []int{4, 5, 3, 2}, + rotaNode: 4, + wantParent: 3, + wantRightChild: 5, + }, + { + name: "没有子节点", + key: []int{4, 5, 3}, + rotaNode: 4, + wantParent: 3, + wantRightChild: 5, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + rbTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.key); i++ { + err := rbTree.Add(tt.key[i], i) + if err != nil { + panic(err) + } + } + rotaNode := rbTree.findNode(tt.rotaNode) + rbTree.rotateRight(rotaNode) + if rotaNode.getParent() != nil { + assert.Equal(t, tt.wantParent, rotaNode.getParent().key) + if rotaNode.getLeft() != nil { + assert.Equal(t, tt.wantLeftChild, rotaNode.getLeft().key) + } + if rotaNode.getRight() != nil { + assert.Equal(t, tt.wantRightChild, rotaNode.getRight().key) + } + } + }) + } +} + +func TestRBNode_getColor(t *testing.T) { + tcase := []struct { + name string + node *rbNode[int, int] + wantColor color + }{ + { + name: "nod-nil", + node: nil, + wantColor: true, + }, + { + name: "new node", + node: newRBNode[int, int](1, 1), + wantColor: false, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.wantColor, tt.node.getColor()) + }) + } +} + +func TestRBNode_getLeft(t *testing.T) { + tcase := []struct { + name string + node *rbNode[int, int] + wantNode *rbNode[int, int] + }{ + { + name: "nod-nil", + node: nil, + wantNode: nil, + }, + { + name: "new node", + node: newRBNode[int, int](1, 1), + wantNode: nil, + }, + { + name: "new node have left-child", + node: &rbNode[int, int]{ + key: 2, + left: &rbNode[int, int]{ + key: 1, + }, + }, + wantNode: &rbNode[int, int]{ + key: 1, + }, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.wantNode, tt.node.getLeft()) + }) + } +} + +func TestRBNode_getRight(t *testing.T) { + tcase := []struct { + name string + node *rbNode[int, int] + wantNode *rbNode[int, int] + }{ + { + name: "nod-nil", + node: nil, + wantNode: nil, + }, + { + name: "new node", + node: newRBNode[int, int](1, 1), + wantNode: nil, + }, + { + name: "new node have right-child", + node: &rbNode[int, int]{ + key: 1, + right: &rbNode[int, int]{ + key: 2, + }, + }, + wantNode: &rbNode[int, int]{ + key: 2, + }, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.wantNode, tt.node.getRight()) + }) + } +} + +func TestRBNode_getParent(t *testing.T) { + tcase := []struct { + name string + node *rbNode[int, int] + wantNode *rbNode[int, int] + }{ + { + name: "nod-nil", + node: nil, + wantNode: nil, + }, + { + name: "new node", + node: newRBNode[int, int](1, 1), + wantNode: nil, + }, + { + name: "new node have parent", + node: &rbNode[int, int]{ + key: 2, + parent: &rbNode[int, int]{ + key: 3, + }, + }, + wantNode: &rbNode[int, int]{ + key: 3, + }, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.wantNode, tt.node.getParent()) + }) + } +} + +func TestRBNode_setColor(t *testing.T) { + tcase := []struct { + name string + node *rbNode[int, int] + color color + wantColor color + }{ + { + name: "nod-nil", + node: nil, + color: false, + wantColor: Black, + }, + { + name: "new node", + node: newRBNode[int, int](1, 1), + color: true, + wantColor: Black, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + tt.node.setColor(tt.color) + assert.Equal(t, tt.wantColor, tt.node.getColor()) + }) + } +} + +func TestNewRBNode(t *testing.T) { + tcase := []struct { + name string + key int + value int + wantNode *rbNode[int, int] + }{ + { + name: "new node", + key: 1, + value: 1, + wantNode: &rbNode[int, int]{ + key: 1, + value: 1, + left: nil, + right: nil, + parent: nil, + color: Red, + }, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + node := newRBNode[int, int](tt.key, tt.value) + assert.Equal(t, tt.wantNode, node) + }) + } +} + +func TestRBNode_getBrother(t *testing.T) { + tests := []struct { + name string + k []int + nodeKye int + want int + }{ + { + name: "nil", + k: nil, + }, + { + name: "no-brother", + nodeKye: 1, + k: []int{1}, + }, + { + name: "no-brother", + nodeKye: 1, + k: []int{1, 2}, + }, + { + name: "have brother", + k: []int{1, 2, 3, 4}, + nodeKye: 1, + want: 3, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redBlackTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.k); i++ { + err := redBlackTree.Add(tt.k[i], i) + if err != nil { + panic(err) + } + } + tagNode := redBlackTree.findNode(tt.nodeKye) + brNode := tagNode.getBrother() + if brNode == nil { + return + } + assert.Equal(t, tt.want, brNode.key) + + }) + } +} + +func TestRBNode_getGrandParent(t *testing.T) { + tests := []struct { + name string + k []int + nodeKye int + want int + }{ + { + name: "nil", + k: nil, + }, + { + name: "no-grandpa", + nodeKye: 1, + k: []int{1}, + }, + { + name: "no-grandpa", + nodeKye: 1, + k: []int{1, 2}, + }, + { + name: "have grandpa", + k: []int{1, 2, 3, 4}, + nodeKye: 4, + want: 2, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redBlackTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.k); i++ { + err := redBlackTree.Add(tt.k[i], i) + if err != nil { + panic(err) + } + } + tagNode := redBlackTree.findNode(tt.nodeKye) + brNode := tagNode.getGrandParent() + if brNode == nil { + return + } + assert.Equal(t, tt.want, brNode.key) + + }) + } +} + +func TestRBNode_getUncle(t *testing.T) { + tests := []struct { + name string + k []int + nodeKye int + want int + }{ + { + name: "nil", + k: nil, + }, + { + name: "no-uncle", + nodeKye: 1, + k: []int{1}, + }, + { + name: "no-uncle", + nodeKye: 1, + k: []int{1, 2}, + }, + { + name: "have uncle", + k: []int{1, 2, 3, 4}, + nodeKye: 4, + want: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redBlackTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.k); i++ { + err := redBlackTree.Add(tt.k[i], i) + if err != nil { + panic(err) + } + } + tagNode := redBlackTree.findNode(tt.nodeKye) + brNode := tagNode.getUncle() + if brNode == nil { + return + } + assert.Equal(t, tt.want, brNode.key) + + }) + } +} + +func TestRBNode_set(t *testing.T) { + tcase := []struct { + name string + node *rbNode[int, int] + value int + wantNode *rbNode[int, int] + }{ + { + name: "nil", + node: nil, + value: 1, + wantNode: nil, + }, + { + name: "new node", + node: &rbNode[int, int]{ + key: 1, + value: 0, + left: nil, + right: nil, + parent: nil, + color: Red, + }, + value: 1, + wantNode: &rbNode[int, int]{ + key: 1, + value: 1, + left: nil, + right: nil, + parent: nil, + color: Red, + }, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + tt.node.setNode(tt.value) + assert.Equal(t, tt.wantNode, tt.node) + }) + } +} + +func TestRBTree_findSuccessor(t *testing.T) { + tests := []struct { + name string + k []int + successor int + wantKey int + }{ + { + name: "nil-successor", + k: nil, + successor: 8, + }, + { + name: "have no successor", + k: []int{2}, + successor: 2, + }, + { + name: "have right successor", + k: []int{5, 4, 6, 3, 2}, + successor: 3, + wantKey: 4, + }, + { + name: "have right successor", + k: []int{5, 4, 7, 6, 3, 2}, + successor: 5, + wantKey: 6, + }, + { + name: "have no-right successor", + k: []int{5, 4, 6, 3, 2}, + successor: 4, + wantKey: 5, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redBlackTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.k); i++ { + err := redBlackTree.Add(tt.k[i], i) + if err != nil { + return + } + } + tagNode := redBlackTree.findNode(tt.successor) + successorNode := redBlackTree.findSuccessor(tagNode) + if successorNode == nil { + return + } + assert.Equal(t, tt.wantKey, successorNode.key) + }) + } +} + +func TestRBTree_fixAddLeftBlack(t *testing.T) { + tests := []struct { + name string + k []int + addNode int + want int + }{ + { + name: "nod is right", + k: []int{2, 1, 3}, + addNode: 3, + want: 2, + }, + { + name: "node is left", + k: []int{2, 1}, + addNode: 1, + want: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redBlackTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.k); i++ { + err := redBlackTree.Add(tt.k[i], i) + if err != nil { + return + } + } + node := redBlackTree.findNode(tt.addNode) + x := redBlackTree.fixAddLeftBlack(node) + assert.Equal(t, tt.want, x.key) + }) + + } +} + +func TestRBTree_fixAddRightBlack(t *testing.T) { + tests := []struct { + name string + k []int + addNode int + want int + }{ + { + name: "nod is left", + k: []int{2, 1}, + addNode: 1, + want: 2, + }, + { + name: "node is right", + k: []int{2, 1, 3}, + addNode: 3, + want: 3, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redBlackTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.k); i++ { + err := redBlackTree.Add(tt.k[i], i) + if err != nil { + return + } + } + node := redBlackTree.findNode(tt.addNode) + x := redBlackTree.fixAddRightBlack(node) + assert.Equal(t, tt.want, x.key) + }) + + } +} + +func TestRBTree_fixAfterDeleteLeft(t *testing.T) { + tcase := []struct { + name string + delKey int + key []int + want int + wantError error + }{ + { + name: "兄弟节点是红色并且兄弟左节点左侧是黑色", + delKey: 10, + key: []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + want: 11, + }, + { + name: "兄弟节点是黑色,兄弟节点左测是黑色", + delKey: 2, + key: []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + want: 3, + }, + { + name: "兄弟节点是黑色,兄弟节点左测不是黑色", + delKey: 8, + key: []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + want: 5, + }, + { + name: "兄弟节点是红色,兄弟节点左测是黑色", + delKey: 1, + key: []int{2, 1, 3}, + want: 2, + }, + { + name: "节点左旋之后兄弟节点是红色", + delKey: 21, + key: []int{15, 20, 10, 16, 21, 8, 14, 7}, + want: 15, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + rbTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.key); i++ { + err := rbTree.Add(tt.key[i], i) + if err != nil { + panic(err) + } + + } + delNode := rbTree.findNode(tt.delKey) + if delNode == nil { + assert.Equal(t, tt.wantError, errors.New("未找到节点0")) + } else { + x := rbTree.fixAfterDeleteLeft(delNode) + assert.Equal(t, tt.want, x.key) + } + }) + } +} + +func TestRBTree_fixAfterDeleteRight(t *testing.T) { + tcase := []struct { + name string + delKey int + key []int + want int + wantError error + }{ + { + name: "兄弟节点是红色并且兄弟左节点左侧是黑色", + delKey: 12, + key: []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + want: 11, + }, + { + name: "兄弟节点是黑色,兄弟节点左测是黑色", + delKey: 11, + key: []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 0}, + want: 9, + }, + { + name: "兄弟节点是黑色,兄弟节点左测不是黑色", + delKey: 4, + key: []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 0}, + want: 5, + }, + { + name: "兄弟节点是红色,兄弟节点左测是黑色", + delKey: 3, + key: []int{2, 1, 3}, + want: 2, + }, + } + for _, tt := range tcase { + t.Run(tt.name, func(t *testing.T) { + rbTree := NewRBTree[int, int](compare()) + for i := 0; i < len(tt.key); i++ { + err := rbTree.Add(tt.key[i], i) + if err != nil { + panic(err) + } + + } + delNode := rbTree.findNode(tt.delKey) + if delNode == nil { + assert.Equal(t, tt.wantError, errors.New("未找到节点0")) + } else { + x := rbTree.fixAfterDeleteRight(delNode) + assert.Equal(t, tt.want, x.key) + } + }) + } +} + +// IsRedBlackTree 检测是否满足红黑树 +func IsRedBlackTree[K any, V any](root *rbNode[K, V]) bool { + // 检测节点是否黑色 + if !root.getColor() { + return false + } + // count 取最左树的黑色节点作为对照 + count := 0 + num := 0 + node := root + for node != nil { + if node.getColor() { + count++ + } + node = node.getLeft() + } + return nodeCheck[K](root, count, num) +} + +// nodeCheck 节点检测 +// 1、是否有连续的红色节点 +// 2、每条路径的黑色节点是否一致 +func nodeCheck[K any, V any](node *rbNode[K, V], count int, num int) bool { + if node == nil { + return true + } + if !node.getColor() && !node.parent.getColor() { + return false + } + if node.getColor() { + num++ + } + if node.getLeft() == nil && node.getRight() == nil { + if num != count { + return false + } + } + return nodeCheck(node.left, count, num) && nodeCheck(node.right, count, num) +} diff --git a/mapx/treemap.go b/mapx/treemap.go new file mode 100644 index 00000000..91ed6018 --- /dev/null +++ b/mapx/treemap.go @@ -0,0 +1,85 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +import ( + "errors" + + "github.com/gotomicro/ekit" + "github.com/gotomicro/ekit/internal/tree" +) + +var ( + errTreeMapComparatorIsNull = errors.New("ekit: Comparator不能为nil") +) + +// TreeMap 是基于红黑树实现的Map +type TreeMap[K any, V any] struct { + *tree.RBTree[K, V] +} + +// NewTreeMapWithMap TreeMap构造方法 +// 支持通过传入的map构造生成TreeMap +func NewTreeMapWithMap[K comparable, V any](compare ekit.Comparator[K], m map[K]V) (*TreeMap[K, V], error) { + treeMap, err := NewTreeMap[K, V](compare) + if err != nil { + return treeMap, err + } + putAll(treeMap, m) + return treeMap, nil +} + +// NewTreeMap TreeMap构造方法,创建一个的TreeMap +// 需注意比较器compare不能为nil +func NewTreeMap[K any, V any](compare ekit.Comparator[K]) (*TreeMap[K, V], error) { + if compare == nil { + return nil, errTreeMapComparatorIsNull + } + return &TreeMap[K, V]{ + RBTree: tree.NewRBTree[K, V](compare), + }, nil +} + +// putAll 将map传入TreeMap +// 需注意如果map中的key已存在,value将被替换 +func putAll[K comparable, V any](treeMap *TreeMap[K, V], m map[K]V) { + for k, v := range m { + _ = treeMap.Put(k, v) + } +} + +// Put 在TreeMap插入指定值 +// 需注意如果TreeMap已存在该Key那么原值会被替换 +func (treeMap *TreeMap[K, V]) Put(key K, value V) error { + err := treeMap.Add(key, value) + if err == tree.ErrRBTreeSameRBNode { + return treeMap.Set(key, value) + } + return nil +} + +// Get 在TreeMap找到指定Key的节点,返回Val +// TreeMap未找到指定节点将会返回false +func (treeMap *TreeMap[K, V]) Get(key K) (V, bool) { + v, err := treeMap.Find(key) + return v, err == nil +} + +// Remove TreeMap中删除指定key的节点 +func (treeMap *TreeMap[T, V]) Remove(k T) { + treeMap.Delete(k) +} + +var _ mapi[any, any] = (*TreeMap[any, any])(nil) diff --git a/mapx/treemap_test.go b/mapx/treemap_test.go new file mode 100644 index 00000000..b6fa6a25 --- /dev/null +++ b/mapx/treemap_test.go @@ -0,0 +1,383 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx + +import ( + "errors" + "testing" + + "github.com/gotomicro/ekit" + "github.com/stretchr/testify/assert" +) + +func TestNewTreeMapWithMap(t *testing.T) { + tests := []struct { + name string + m map[int]int + comparable ekit.Comparator[int] + wantKey []int + wantVal []int + wantErr error + }{ + { + name: "nil", + m: nil, + comparable: nil, + wantKey: nil, + wantVal: nil, + wantErr: errors.New("ekit: Comparator不能为nil"), + }, + { + name: "empty", + m: map[int]int{}, + comparable: compare(), + wantKey: nil, + wantVal: nil, + wantErr: nil, + }, + { + name: "single", + m: map[int]int{ + 0: 0, + }, + comparable: compare(), + wantKey: []int{0}, + wantVal: []int{0}, + wantErr: nil, + }, + { + name: "multiple", + m: map[int]int{ + 0: 0, + 1: 1, + 2: 2, + }, + comparable: compare(), + wantKey: []int{0, 1, 2}, + wantVal: []int{0, 1, 2}, + wantErr: nil, + }, + { + name: "disorder", + m: map[int]int{ + 1: 1, + 2: 2, + 0: 0, + 3: 3, + 5: 5, + 4: 4, + }, + comparable: compare(), + wantKey: []int{0, 1, 2, 3, 5, 4}, + wantVal: []int{0, 1, 2, 3, 5, 4}, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + treeMap, err := NewTreeMapWithMap[int, int](tt.comparable, tt.m) + if err != nil { + assert.Equal(t, tt.wantErr, err) + return + } + for k, v := range tt.m { + value, _ := treeMap.Get(k) + assert.Equal(t, true, v == value) + } + + }) + + } +} + +func TestTreeMap_Get(t *testing.T) { + var tests = []struct { + name string + m map[int]int + findKey int + wantVal int + wantBool bool + }{ + { + name: "empty-TreeMap", + m: map[int]int{}, + findKey: 0, + wantVal: 0, + wantBool: false, + }, + { + name: "find", + m: map[int]int{ + 1: 1, + 2: 2, + 0: 0, + 3: 3, + 5: 5, + 4: 4, + }, + findKey: 2, + wantVal: 2, + wantBool: true, + }, + { + name: "not-find", + m: map[int]int{ + 1: 1, + 2: 2, + 0: 0, + 3: 3, + 5: 5, + 4: 4, + }, + findKey: 6, + wantVal: 0, + wantBool: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + treeMap, _ := NewTreeMap[int, int](compare()) + putAll(treeMap, tt.m) + val, b := treeMap.Get(tt.findKey) + assert.Equal(t, tt.wantBool, b) + assert.Equal(t, tt.wantVal, val) + }) + } +} + +func TestTreeMap_Put(t *testing.T) { + + tests := []struct { + name string + k []int + v []string + wantKey []int + wantVal []string + wantErr error + }{ + { + name: "single", + k: []int{0}, + v: []string{"0"}, + wantKey: []int{0}, + wantVal: []string{"0"}, + wantErr: nil, + }, + { + name: "multiple", + k: []int{0, 1, 2}, + v: []string{"0", "1", "2"}, + wantKey: []int{0, 1, 2}, + wantVal: []string{"0", "1", "2"}, + wantErr: nil, + }, + { + name: "same", + k: []int{0, 0, 1, 2, 2, 3}, + v: []string{"0", "999", "1", "998", "2", "3"}, + wantKey: []int{0, 1, 2, 3}, + wantVal: []string{"999", "1", "2", "3"}, + wantErr: nil, + }, + { + name: "same", + k: []int{0, 0}, + v: []string{"0", "999"}, + wantKey: []int{0}, + wantVal: []string{"999"}, + wantErr: nil, + }, + { + name: "disorder", + k: []int{1, 2, 0, 3, 5, 4}, + v: []string{"1", "2", "0", "3", "5", "4"}, + wantKey: []int{0, 1, 2, 3, 4, 5}, + wantVal: []string{"0", "1", "2", "3", "4", "5"}, + wantErr: nil, + }, + { + name: "disorder-same", + k: []int{1, 3, 2, 0, 2, 3}, + v: []string{"1", "2", "998", "0", "3", "997"}, + wantKey: []int{0, 1, 2, 3}, + wantVal: []string{"0", "1", "3", "997"}, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + treeMap, _ := NewTreeMap[int, string](compare()) + for i := 0; i < len(tt.k); i++ { + err := treeMap.Put(tt.k[i], tt.v[i]) + if err != nil { + assert.Equal(t, tt.wantErr, err) + return + } + } + for i := 0; i < len(tt.wantKey); i++ { + v, b := treeMap.Get(tt.wantKey[i]) + assert.Equal(t, true, b) + assert.Equal(t, tt.wantVal[i], v) + } + + }) + } + subTests := []struct { + name string + k []int + v []string + wantKey []int + wantVal []string + wantErr error + }{ + { + name: "nil", + k: []int{0}, + v: nil, + wantKey: []int{0}, + wantVal: []string(nil), + }, + { + name: "nil", + k: []int{0}, + v: []string{"0"}, + wantKey: []int{0}, + wantVal: []string{"0"}, + }, + } + for _, tt := range subTests { + t.Run(tt.name, func(t *testing.T) { + treeMap, _ := NewTreeMap[int, []string](compare()) + for i := 0; i < len(tt.k); i++ { + err := treeMap.Put(tt.k[i], tt.v) + if err != nil { + assert.Equal(t, tt.wantErr, err) + return + } + } + for i := 0; i < len(tt.wantKey); i++ { + v, b := treeMap.Get(tt.wantKey[i]) + assert.Equal(t, true, b) + assert.Equal(t, tt.wantVal, v) + } + + }) + } +} + +func TestTreeMap_Remove(t *testing.T) { + var tests = []struct { + name string + m map[int]int + delKey int + wantVal int + wantBool bool + }{ + { + name: "empty-TreeMap", + m: map[int]int{}, + delKey: 0, + wantVal: 0, + }, + { + name: "find", + m: map[int]int{ + 1: 1, + 2: 2, + 0: 0, + 3: 3, + 5: 5, + 4: 4, + }, + delKey: 2, + wantVal: 0, + }, + { + name: "not-find", + m: map[int]int{ + 1: 1, + 2: 2, + 0: 0, + 3: 3, + 5: 5, + 4: 4, + }, + delKey: 6, + wantVal: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + treeMap, _ := NewTreeMap[int, int](compare()) + treeMap.Remove(tt.delKey) + val, err := treeMap.Get(tt.delKey) + assert.Equal(t, tt.wantBool, err) + assert.Equal(t, tt.wantVal, val) + }) + } +} + +func compare() ekit.Comparator[int] { + return ekit.ComparatorRealNumber[int] +} + +// goos: windows +// goarch: amd64 +// pkg: github.com/gotomicro/ekit/mapx +// cpu: Intel(R) Core(TM) i5-7500 CPU @ 3.40GHz +// BenchmarkTreeMap/treeMap_put-4 10000 250.6 ns/op 95 B/op 1 allocs/op +// BenchmarkTreeMap/map_put-4 10000 103.0 ns/op 68 B/op 0 allocs/op +// BenchmarkTreeMap/hashMap_put-4 10000 250.6 ns/op 107 B/op 1 allocs/op +// BenchmarkTreeMap/treeMap_get-4 10000 52.16 ns/op 0 B/op 0 allocs/op +// BenchmarkTreeMap/map_get-4 10000 0 B/op 0 allocs/op +// BenchmarkTreeMap/hashMap_get-4 10000 52.89 ns/op 7 B/op 0 allocs/op +// PASS +// ok github.com/gotomicro/ekit/mapx 0.797s +func BenchmarkTreeMap(b *testing.B) { + hashmap := NewHashMap[hashInt, int](10) + treeMap, _ := NewTreeMap[uint64, int](ekit.ComparatorRealNumber[uint64]) + m := make(map[uint64]int, 10) + b.Run("treeMap_put", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = treeMap.Put(uint64(i), i) + } + }) + b.Run("map_put", func(b *testing.B) { + for i := 0; i < b.N; i++ { + m[uint64(i)] = i + } + }) + b.Run("hashMap_put", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = hashmap.Put(hashInt(uint64(i)), i) + } + }) + b.Run("treeMap_get", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = treeMap.Get(uint64(i)) + } + }) + b.Run("map_get", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = m[uint64(i)] + } + }) + b.Run("hashMap_get", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = hashmap.Get(hashInt(uint64(i))) + } + }) + +} From 4c1af75c0998832ad773be28e09bba5508f8604a Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Tue, 31 Jan 2023 13:59:43 +0800 Subject: [PATCH 9/9] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20mapx=20=E7=9A=84?= =?UTF-8?q?=E4=BE=8B=E5=AD=90=EF=BC=8C=E5=87=86=E5=A4=87=20v0.0.6=20(#143)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 2 ++ mapx/hashmap.go | 3 ++ mapx/map_example_test.go | 64 ++++++++++++++++++++++++++++++++++++ mapx/treemap_example_test.go | 31 +++++++++++++++++ 4 files changed, 100 insertions(+) create mode 100644 mapx/map_example_test.go create mode 100644 mapx/treemap_example_test.go diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 0f66fa12..5b9d6628 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -1,4 +1,6 @@ # 开发中 + +# v0.0.6 - [queue: 基于semaphore的并发阻塞队列实现](https://github.com/gotomicro/ekit/pull/129) - [mapx: hashmap实现](https://github.com/gotomicro/ekit/pull/132) - [mapx: 添加 Keys 方法](https://github.com/gotomicro/ekit/pull/134) diff --git a/mapx/hashmap.go b/mapx/hashmap.go index 556af16b..2f29cd3b 100644 --- a/mapx/hashmap.go +++ b/mapx/hashmap.go @@ -30,7 +30,10 @@ func (m *HashMap[T, ValType]) newNode(key Hashable, val ValType) *node[T, ValTyp } type Hashable interface { + // Code 返回该元素的哈希值 + // 注意:哈希值应该尽可能的均匀以避免冲突 Code() uint64 + // Equals 比较两个元素是否相等。如果返回 true,那么我们会认为两个键是一样的。 Equals(key any) bool } diff --git a/mapx/map_example_test.go b/mapx/map_example_test.go new file mode 100644 index 00000000..791e2cfd --- /dev/null +++ b/mapx/map_example_test.go @@ -0,0 +1,64 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx_test + +import ( + "fmt" + + "github.com/gotomicro/ekit/mapx" +) + +func ExampleNewHashMap() { + m := mapx.NewHashMap[MockKey, int](10) + _ = m.Put(MockKey{}, 123) + val, _ := m.Get(MockKey{}) + fmt.Println(val) + // Output: + // 123 +} + +type MockKey struct { + values []int +} + +func (m MockKey) Code() uint64 { + res := 3 + for _, v := range m.values { + res += v * 7 + } + return uint64(res) +} + +func (m MockKey) Equals(key any) bool { + k, ok := key.(MockKey) + if !ok { + return false + } + if len(k.values) != len(m.values) { + return false + } + if k.values == nil && m.values != nil { + return false + } + if k.values != nil && m.values == nil { + return false + } + for i, v := range m.values { + if v != k.values[i] { + return false + } + } + return true +} diff --git a/mapx/treemap_example_test.go b/mapx/treemap_example_test.go new file mode 100644 index 00000000..bcaf57c1 --- /dev/null +++ b/mapx/treemap_example_test.go @@ -0,0 +1,31 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapx_test + +import ( + "fmt" + + "github.com/gotomicro/ekit" + "github.com/gotomicro/ekit/mapx" +) + +func ExampleNewTreeMap() { + m, _ := mapx.NewTreeMap[int, int](ekit.ComparatorRealNumber[int]) + _ = m.Put(1, 11) + val, _ := m.Get(1) + fmt.Println(val) + // Output: + // 11 +}