From 696a7d192e5eefa53230168a4b200ec0560c8a10 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Sun, 5 Feb 2023 00:31:29 -0700 Subject: containers: Rethink the RBTree interface to be simpler --- lib/containers/intervaltree.go | 132 +++++++++++++++-------------- lib/containers/intervaltree_test.go | 28 ++++--- lib/containers/rbtree.go | 161 ++++++++++++++++-------------------- lib/containers/rbtree_test.go | 66 +++++++-------- lib/containers/sortedmap.go | 76 ++++++----------- 5 files changed, 209 insertions(+), 254 deletions(-) (limited to 'lib/containers') diff --git a/lib/containers/intervaltree.go b/lib/containers/intervaltree.go index b7ff866..7b96526 100644 --- a/lib/containers/intervaltree.go +++ b/lib/containers/intervaltree.go @@ -4,81 +4,80 @@ package containers -type intervalKey[K Ordered[K]] struct { +type interval[K Ordered[K]] struct { Min, Max K } -func (ival intervalKey[K]) ContainsFn(fn func(K) int) bool { - return fn(ival.Min) >= 0 && fn(ival.Max) <= 0 -} - -func (a intervalKey[K]) Compare(b intervalKey[K]) int { +// Compare implements Ordered. +func (a interval[K]) Compare(b interval[K]) int { if d := a.Min.Compare(b.Min); d != 0 { return d } return a.Max.Compare(b.Max) } +// ContainsFn returns whether this interval contains the range matched +// by the given function. +func (ival interval[K]) ContainsFn(fn func(K) int) bool { + return fn(ival.Min) >= 0 && fn(ival.Max) <= 0 +} + type intervalValue[K Ordered[K], V any] struct { - Val V - SpanOfChildren intervalKey[K] + Val V + ValSpan interval[K] + ChildSpan interval[K] +} + +// Compare implements Ordered. +func (a intervalValue[K, V]) Compare(b intervalValue[K, V]) int { + return a.ValSpan.Compare(b.ValSpan) } type IntervalTree[K Ordered[K], V any] struct { MinFn func(V) K MaxFn func(V) K - inner RBTree[intervalKey[K], intervalValue[K, V]] -} - -func (t *IntervalTree[K, V]) keyFn(v intervalValue[K, V]) intervalKey[K] { - return intervalKey[K]{ - Min: t.MinFn(v.Val), - Max: t.MaxFn(v.Val), - } + inner RBTree[intervalValue[K, V]] } func (t *IntervalTree[K, V]) attrFn(node *RBNode[intervalValue[K, V]]) { - max := t.MaxFn(node.Value.Val) - if node.Left != nil && node.Left.Value.SpanOfChildren.Max.Compare(max) > 0 { - max = node.Left.Value.SpanOfChildren.Max + max := node.Value.ValSpan.Max + if node.Left != nil && node.Left.Value.ChildSpan.Max.Compare(max) > 0 { + max = node.Left.Value.ChildSpan.Max } - if node.Right != nil && node.Right.Value.SpanOfChildren.Max.Compare(max) > 0 { - max = node.Right.Value.SpanOfChildren.Max + if node.Right != nil && node.Right.Value.ChildSpan.Max.Compare(max) > 0 { + max = node.Right.Value.ChildSpan.Max } - node.Value.SpanOfChildren.Max = max + node.Value.ChildSpan.Max = max - min := t.MinFn(node.Value.Val) - if node.Left != nil && node.Left.Value.SpanOfChildren.Min.Compare(min) < 0 { - min = node.Left.Value.SpanOfChildren.Min + min := node.Value.ValSpan.Min + if node.Left != nil && node.Left.Value.ChildSpan.Min.Compare(min) < 0 { + min = node.Left.Value.ChildSpan.Min } - if node.Right != nil && node.Right.Value.SpanOfChildren.Min.Compare(min) < 0 { - min = node.Right.Value.SpanOfChildren.Min + if node.Right != nil && node.Right.Value.ChildSpan.Min.Compare(min) < 0 { + min = node.Right.Value.ChildSpan.Min } - node.Value.SpanOfChildren.Min = min + node.Value.ChildSpan.Min = min } func (t *IntervalTree[K, V]) init() { - if t.inner.KeyFn == nil { - t.inner.KeyFn = t.keyFn + if t.inner.AttrFn == nil { t.inner.AttrFn = t.attrFn } } -func (t *IntervalTree[K, V]) Delete(min, max K) { - t.init() - t.inner.Delete(intervalKey[K]{ - Min: min, - Max: max, - }) -} - func (t *IntervalTree[K, V]) Equal(u *IntervalTree[K, V]) bool { return t.inner.Equal(&u.inner) } func (t *IntervalTree[K, V]) Insert(val V) { t.init() - t.inner.Insert(intervalValue[K, V]{Val: val}) + t.inner.Insert(intervalValue[K, V]{ + Val: val, + ValSpan: interval[K]{ + Min: t.MinFn(val), + Max: t.MaxFn(val), + }, + }) } func (t *IntervalTree[K, V]) Min() (K, bool) { @@ -86,7 +85,7 @@ func (t *IntervalTree[K, V]) Min() (K, bool) { var zero K return zero, false } - return t.inner.root.Value.SpanOfChildren.Min, true + return t.inner.root.Value.ChildSpan.Min, true } func (t *IntervalTree[K, V]) Max() (K, bool) { @@ -94,22 +93,18 @@ func (t *IntervalTree[K, V]) Max() (K, bool) { var zero K return zero, false } - return t.inner.root.Value.SpanOfChildren.Max, true -} - -func (t *IntervalTree[K, V]) Lookup(k K) (V, bool) { - return t.Search(k.Compare) + return t.inner.root.Value.ChildSpan.Max, true } func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) { node := t.inner.root for node != nil { switch { - case t.keyFn(node.Value).ContainsFn(fn): + case node.Value.ValSpan.ContainsFn(fn): return node.Value.Val, true - case node.Left != nil && node.Left.Value.SpanOfChildren.ContainsFn(fn): + case node.Left != nil && node.Left.Value.ChildSpan.ContainsFn(fn): node = node.Left - case node.Right != nil && node.Right.Value.SpanOfChildren.ContainsFn(fn): + case node.Right != nil && node.Right.Value.ChildSpan.ContainsFn(fn): node = node.Right default: node = nil @@ -119,24 +114,33 @@ func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) { return zero, false } -func (t *IntervalTree[K, V]) searchAll(fn func(K) int, node *RBNode[intervalValue[K, V]], ret *[]V) { +func (t *IntervalTree[K, V]) Range(fn func(V) bool) { + t.inner.Range(func(node *RBNode[intervalValue[K, V]]) bool { + return fn(node.Value.Val) + }) +} + +func (t *IntervalTree[K, V]) Subrange(rangeFn func(K) int, handleFn func(V) bool) { + t.subrange(t.inner.root, rangeFn, handleFn) +} + +func (t *IntervalTree[K, V]) subrange(node *RBNode[intervalValue[K, V]], rangeFn func(K) int, handleFn func(V) bool) bool { if node == nil { - return + return true } - if !node.Value.SpanOfChildren.ContainsFn(fn) { - return + if !node.Value.ChildSpan.ContainsFn(rangeFn) { + return true } - t.searchAll(fn, node.Left, ret) - if t.keyFn(node.Value).ContainsFn(fn) { - *ret = append(*ret, node.Value.Val) + if !t.subrange(node.Left, rangeFn, handleFn) { + return false } - t.searchAll(fn, node.Right, ret) -} - -func (t *IntervalTree[K, V]) SearchAll(fn func(K) int) []V { - var ret []V - t.searchAll(fn, t.inner.root, &ret) - return ret + if node.Value.ValSpan.ContainsFn(rangeFn) { + if !handleFn(node.Value.Val) { + return false + } + } + if !t.subrange(node.Right, rangeFn, handleFn) { + return false + } + return true } - -// TODO: func (t *IntervalTree[K, V]) Walk(fn func(*RBNode[V]) error) error diff --git a/lib/containers/intervaltree_test.go b/lib/containers/intervaltree_test.go index 45743a2..30885bd 100644 --- a/lib/containers/intervaltree_test.go +++ b/lib/containers/intervaltree_test.go @@ -18,8 +18,8 @@ func (t *IntervalTree[K, V]) ASCIIArt() string { func (v intervalValue[K, V]) String() string { return fmt.Sprintf("%v) ([%v,%v]", v.Val, - v.SpanOfChildren.Min, - v.SpanOfChildren.Max) + v.ChildSpan.Min, + v.ChildSpan.Max) } func (v NativeOrdered[T]) String() string { @@ -60,15 +60,21 @@ func TestIntervalTree(t *testing.T) { t.Log("\n" + tree.ASCIIArt()) // find intervals that touch [9,20] - intervals := tree.SearchAll(func(k NativeOrdered[int]) int { - if k.Val < 9 { - return 1 - } - if k.Val > 20 { - return -1 - } - return 0 - }) + var intervals []SimpleInterval + tree.Subrange( + func(k NativeOrdered[int]) int { + if k.Val < 9 { + return 1 + } + if k.Val > 20 { + return -1 + } + return 0 + }, + func(v SimpleInterval) bool { + intervals = append(intervals, v) + return true + }) assert.Equal(t, []SimpleInterval{ {6, 10}, diff --git a/lib/containers/rbtree.go b/lib/containers/rbtree.go index 1fdb799..4125847 100644 --- a/lib/containers/rbtree.go +++ b/lib/containers/rbtree.go @@ -7,8 +7,6 @@ package containers import ( "fmt" "reflect" - - "git.lukeshu.com/btrfs-progs-ng/lib/slices" ) type Color bool @@ -18,50 +16,49 @@ const ( Red = Color(true) ) -type RBNode[V any] struct { - Parent, Left, Right *RBNode[V] +type RBNode[T Ordered[T]] struct { + Parent, Left, Right *RBNode[T] Color Color - Value V + Value T } -func (node *RBNode[V]) getColor() Color { +func (node *RBNode[T]) getColor() Color { if node == nil { return Black } return node.Color } -type RBTree[K Ordered[K], V any] struct { - KeyFn func(V) K - AttrFn func(*RBNode[V]) - root *RBNode[V] +type RBTree[T Ordered[T]] struct { + AttrFn func(*RBNode[T]) + root *RBNode[T] len int } -func (t *RBTree[K, V]) Len() int { +func (t *RBTree[T]) Len() int { return t.len } -func (t *RBTree[K, V]) Walk(fn func(*RBNode[V]) error) error { - return t.root.walk(fn) +func (t *RBTree[T]) Range(fn func(*RBNode[T]) bool) { + t.root._range(fn) } -func (node *RBNode[V]) walk(fn func(*RBNode[V]) error) error { +func (node *RBNode[T]) _range(fn func(*RBNode[T]) bool) bool { if node == nil { - return nil + return true } - if err := node.Left.walk(fn); err != nil { - return err + if !node.Left._range(fn) { + return false } - if err := fn(node); err != nil { - return err + if !fn(node) { + return false } - if err := node.Right.walk(fn); err != nil { - return err + if !node.Right._range(fn) { + return false } - return nil + return true } // Search the tree for a value that satisfied the given callbackk @@ -81,16 +78,13 @@ func (node *RBNode[V]) walk(fn func(*RBNode[V]) error) error { // +---+ +---+ // // Returns nil if no such value is found. -// -// Search is good for advanced lookup, like when a range of values is -// acceptable. For simple exact-value lookup, use Lookup. -func (t *RBTree[K, V]) Search(fn func(V) int) *RBNode[V] { +func (t *RBTree[T]) Search(fn func(T) int) *RBNode[T] { ret, _ := t.root.search(fn) return ret } -func (node *RBNode[V]) search(fn func(V) int) (exact, nearest *RBNode[V]) { - var prev *RBNode[V] +func (node *RBNode[T]) search(fn func(T) int) (exact, nearest *RBNode[T]) { + var prev *RBNode[T] for { if node == nil { return nil, prev @@ -108,26 +102,13 @@ func (node *RBNode[V]) search(fn func(V) int) (exact, nearest *RBNode[V]) { } } -func (t *RBTree[K, V]) exactKey(key K) func(V) int { - return func(val V) int { - valKey := t.KeyFn(val) - return key.Compare(valKey) - } -} - -// Lookup looks up the value for an exact key. If no such value -// exists, nil is returned. -func (t *RBTree[K, V]) Lookup(key K) *RBNode[V] { - return t.Search(t.exactKey(key)) -} - // Min returns the minimum value stored in the tree, or nil if the // tree is empty. -func (t *RBTree[K, V]) Min() *RBNode[V] { +func (t *RBTree[T]) Min() *RBNode[T] { return t.root.min() } -func (node *RBNode[V]) min() *RBNode[V] { +func (node *RBNode[T]) min() *RBNode[T] { if node == nil { return nil } @@ -141,11 +122,11 @@ func (node *RBNode[V]) min() *RBNode[V] { // Max returns the maximum value stored in the tree, or nil if the // tree is empty. -func (t *RBTree[K, V]) Max() *RBNode[V] { +func (t *RBTree[T]) Max() *RBNode[T] { return t.root.max() } -func (node *RBNode[V]) max() *RBNode[V] { +func (node *RBNode[T]) max() *RBNode[T] { if node == nil { return nil } @@ -157,11 +138,7 @@ func (node *RBNode[V]) max() *RBNode[V] { } } -func (t *RBTree[K, V]) Next(cur *RBNode[V]) *RBNode[V] { - return cur.next() -} - -func (cur *RBNode[V]) next() *RBNode[V] { +func (cur *RBNode[T]) Next() *RBNode[T] { if cur.Right != nil { return cur.Right.min() } @@ -172,11 +149,7 @@ func (cur *RBNode[V]) next() *RBNode[V] { return parent } -func (t *RBTree[K, V]) Prev(cur *RBNode[V]) *RBNode[V] { - return cur.prev() -} - -func (cur *RBNode[V]) prev() *RBNode[V] { +func (cur *RBNode[T]) Prev() *RBNode[T] { if cur.Left != nil { return cur.Left.max() } @@ -187,48 +160,56 @@ func (cur *RBNode[V]) prev() *RBNode[V] { return parent } -// SearchRange is like Search, but returns all nodes that match the -// function; assuming that they are contiguous. -func (t *RBTree[K, V]) SearchRange(fn func(V) int) []V { - middle := t.Search(fn) - if middle == nil { - return nil - } - ret := []V{middle.Value} - for node := t.Prev(middle); node != nil && fn(node.Value) == 0; node = t.Prev(node) { - ret = append(ret, node.Value) +// Subrange is like Search, but for when there may be more than one +// result. +func (t *RBTree[T]) Subrange(rangeFn func(T) int, handleFn func(*RBNode[T]) bool) { + // Find the left-most acceptable node. + _, node := t.root.search(func(v T) int { + if rangeFn(v) <= 0 { + return -1 + } else { + return 1 + } + }) + for node != nil && rangeFn(node.Value) > 0 { + node = node.Next() } - slices.Reverse(ret) - for node := t.Next(middle); node != nil && fn(node.Value) == 0; node = t.Next(node) { - ret = append(ret, node.Value) + // Now walk forward until we hit the end. + for node != nil && rangeFn(node.Value) == 0 { + if keepGoing := handleFn(node); !keepGoing { + return + } + node = node.Next() } - return ret } -func (t *RBTree[K, V]) Equal(u *RBTree[K, V]) bool { +func (t *RBTree[T]) Equal(u *RBTree[T]) bool { if (t == nil) != (u == nil) { return false } if t == nil { return true } + if t.len != u.len { + return false + } - var tSlice []V - _ = t.Walk(func(node *RBNode[V]) error { + tSlice := make([]T, 0, t.len) + t.Range(func(node *RBNode[T]) bool { tSlice = append(tSlice, node.Value) - return nil + return true }) - var uSlice []V - _ = u.Walk(func(node *RBNode[V]) error { + uSlice := make([]T, 0, u.len) + u.Range(func(node *RBNode[T]) bool { uSlice = append(uSlice, node.Value) - return nil + return true }) return reflect.DeepEqual(tSlice, uSlice) } -func (t *RBTree[K, V]) parentChild(node *RBNode[V]) **RBNode[V] { +func (t *RBTree[T]) parentChild(node *RBNode[T]) **RBNode[T] { switch { case node.Parent == nil: return &t.root @@ -241,7 +222,7 @@ func (t *RBTree[K, V]) parentChild(node *RBNode[V]) **RBNode[V] { } } -func (t *RBTree[K, V]) updateAttr(node *RBNode[V]) { +func (t *RBTree[T]) updateAttr(node *RBNode[T]) { if t.AttrFn == nil { return } @@ -251,7 +232,7 @@ func (t *RBTree[K, V]) updateAttr(node *RBNode[V]) { } } -func (t *RBTree[K, V]) leftRotate(x *RBNode[V]) { +func (t *RBTree[T]) leftRotate(x *RBNode[T]) { // p p // | | // +---+ +---+ @@ -286,7 +267,7 @@ func (t *RBTree[K, V]) leftRotate(x *RBNode[V]) { t.updateAttr(x) } -func (t *RBTree[K, V]) rightRotate(y *RBNode[V]) { +func (t *RBTree[T]) rightRotate(y *RBNode[T]) { //nolint:dupword // // | | @@ -322,18 +303,17 @@ func (t *RBTree[K, V]) rightRotate(y *RBNode[V]) { t.updateAttr(y) } -func (t *RBTree[K, V]) Insert(val V) { +func (t *RBTree[T]) Insert(val T) { // Naive-insert - key := t.KeyFn(val) - exact, parent := t.root.search(t.exactKey(key)) + exact, parent := t.root.search(val.Compare) if exact != nil { exact.Value = val return } t.len++ - node := &RBNode[V]{ + node := &RBNode[T]{ Color: Red, Parent: parent, Value: val, @@ -341,7 +321,7 @@ func (t *RBTree[K, V]) Insert(val V) { switch { case parent == nil: t.root = node - case key.Compare(t.KeyFn(parent.Value)) < 0: + case val.Compare(parent.Value) < 0: parent.Left = node default: parent.Right = node @@ -391,15 +371,14 @@ func (t *RBTree[K, V]) Insert(val V) { t.root.Color = Black } -func (t *RBTree[K, V]) transplant(oldNode, newNode *RBNode[V]) { +func (t *RBTree[T]) transplant(oldNode, newNode *RBNode[T]) { *t.parentChild(oldNode) = newNode if newNode != nil { newNode.Parent = oldNode.Parent } } -func (t *RBTree[K, V]) Delete(key K) { - nodeToDelete := t.Lookup(key) +func (t *RBTree[T]) Delete(nodeToDelete *RBNode[T]) { if nodeToDelete == nil { return } @@ -410,8 +389,8 @@ func (t *RBTree[K, V]) Delete(key K) { // phase 1 - var nodeToRebalance *RBNode[V] - var nodeToRebalanceParent *RBNode[V] // in case 'nodeToRebalance' is nil, which it can be + var nodeToRebalance *RBNode[T] + var nodeToRebalanceParent *RBNode[T] // in case 'nodeToRebalance' is nil, which it can be needsRebalance := nodeToDelete.Color == Black switch { @@ -427,7 +406,7 @@ func (t *RBTree[K, V]) Delete(key K) { // The node being deleted has a child on both sides, // so we've go to reshuffle the parents a bit to make // room for those children. - next := nodeToDelete.next() + next := nodeToDelete.Next() if next.Parent == nodeToDelete { // p p // | | diff --git a/lib/containers/rbtree_test.go b/lib/containers/rbtree_test.go index e42410e..d2fe931 100644 --- a/lib/containers/rbtree_test.go +++ b/lib/containers/rbtree_test.go @@ -7,21 +7,22 @@ package containers import ( "fmt" "io" - "sort" "strings" "testing" "github.com/stretchr/testify/require" "golang.org/x/exp/constraints" + + "git.lukeshu.com/btrfs-progs-ng/lib/slices" ) -func (t *RBTree[K, V]) ASCIIArt() string { +func (t *RBTree[T]) ASCIIArt() string { var out strings.Builder t.root.asciiArt(&out, "", "", "") return out.String() } -func (node *RBNode[V]) String() string { +func (node *RBNode[T]) String() string { switch { case node == nil: return "nil" @@ -32,7 +33,7 @@ func (node *RBNode[V]) String() string { } } -func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) { +func (node *RBNode[T]) asciiArt(w io.Writer, u, m, l string) { if node == nil { fmt.Fprintf(w, "%snil\n", m) return @@ -43,7 +44,7 @@ func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) { node.Left.asciiArt(w, l+" | ", l+" `--", l+" ") } -func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], tree *RBTree[NativeOrdered[K], V]) { +func checkRBTree[T constraints.Ordered](t *testing.T, expectedSet Set[T], tree *RBTree[NativeOrdered[T]]) { // 1. Every node is either red or black // 2. The root is black. @@ -52,19 +53,19 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], // 3. Every nil is black. // 4. If a node is red, then both its children are black. - require.NoError(t, tree.Walk(func(node *RBNode[V]) error { + tree.Range(func(node *RBNode[NativeOrdered[T]]) bool { if node.getColor() == Red { require.Equal(t, Black, node.Left.getColor()) require.Equal(t, Black, node.Right.getColor()) } - return nil - })) + return true + }) // 5. For each node, all simple paths from the node to // descendent leaves contain the same number of black // nodes. - var walkCnt func(node *RBNode[V], cnt int, leafFn func(int)) - walkCnt = func(node *RBNode[V], cnt int, leafFn func(int)) { + var walkCnt func(node *RBNode[NativeOrdered[T]], cnt int, leafFn func(int)) + walkCnt = func(node *RBNode[NativeOrdered[T]], cnt int, leafFn func(int)) { if node.getColor() == Black { cnt++ } @@ -75,7 +76,7 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], walkCnt(node.Left, cnt, leafFn) walkCnt(node.Right, cnt, leafFn) } - require.NoError(t, tree.Walk(func(node *RBNode[V]) error { + tree.Range(func(node *RBNode[NativeOrdered[T]]) bool { var cnts []int walkCnt(node, 0, func(cnt int) { cnts = append(cnts, cnt) @@ -86,27 +87,24 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], break } } - return nil - })) + return true + }) // expected contents - expectedOrder := make([]K, 0, len(expectedSet)) - for k := range expectedSet { - expectedOrder = append(expectedOrder, k) - node := tree.Lookup(NativeOrdered[K]{Val: k}) + expectedOrder := make([]T, 0, len(expectedSet)) + for v := range expectedSet { + expectedOrder = append(expectedOrder, v) + node := tree.Search(NativeOrdered[T]{Val: v}.Compare) require.NotNil(t, tree, node) - require.Equal(t, k, tree.KeyFn(node.Value).Val) } - sort.Slice(expectedOrder, func(i, j int) bool { - return expectedOrder[i] < expectedOrder[j] + slices.Sort(expectedOrder) + actOrder := make([]T, 0, len(expectedSet)) + tree.Range(func(node *RBNode[NativeOrdered[T]]) bool { + actOrder = append(actOrder, node.Value.Val) + return true }) - actOrder := make([]K, 0, len(expectedSet)) - require.NoError(t, tree.Walk(func(node *RBNode[V]) error { - actOrder = append(actOrder, tree.KeyFn(node.Value).Val) - return nil - })) require.Equal(t, expectedOrder, actOrder) - require.Equal(t, tree.Len(), len(expectedSet)) + require.Equal(t, len(expectedSet), tree.Len()) } func FuzzRBTree(f *testing.F) { @@ -132,11 +130,7 @@ func FuzzRBTree(f *testing.F) { }) f.Fuzz(func(t *testing.T, dat []uint8) { - tree := &RBTree[NativeOrdered[uint8], uint8]{ - KeyFn: func(x uint8) NativeOrdered[uint8] { - return NativeOrdered[uint8]{Val: x} - }, - } + tree := new(RBTree[NativeOrdered[uint8]]) set := make(Set[uint8]) checkRBTree(t, set, tree) t.Logf("\n%s\n", tree.ASCIIArt()) @@ -145,18 +139,18 @@ func FuzzRBTree(f *testing.F) { val := (b & 0b0011_1111) if ins { t.Logf("Insert(%v)", val) - tree.Insert(val) + tree.Insert(NativeOrdered[uint8]{Val: val}) set.Insert(val) t.Logf("\n%s\n", tree.ASCIIArt()) - node := tree.Lookup(NativeOrdered[uint8]{Val: val}) + node := tree.Search(NativeOrdered[uint8]{Val: val}.Compare) require.NotNil(t, node) - require.Equal(t, val, node.Value) + require.Equal(t, val, node.Value.Val) } else { t.Logf("Delete(%v)", val) - tree.Delete(NativeOrdered[uint8]{Val: val}) + tree.Delete(tree.Search(NativeOrdered[uint8]{Val: val}.Compare)) delete(set, val) t.Logf("\n%s\n", tree.ASCIIArt()) - require.Nil(t, tree.Lookup(NativeOrdered[uint8]{Val: val})) + require.Nil(t, tree.Search(NativeOrdered[uint8]{Val: val}.Compare)) } checkRBTree(t, set, tree) } diff --git a/lib/containers/sortedmap.go b/lib/containers/sortedmap.go index 52308c9..d104274 100644 --- a/lib/containers/sortedmap.go +++ b/lib/containers/sortedmap.go @@ -1,40 +1,32 @@ -// Copyright (C) 2022 Luke Shumaker +// Copyright (C) 2022-2023 Luke Shumaker // // SPDX-License-Identifier: GPL-2.0-or-later package containers -import ( - "errors" -) - -type OrderedKV[K Ordered[K], V any] struct { +type orderedKV[K Ordered[K], V any] struct { K K V V } -type SortedMap[K Ordered[K], V any] struct { - inner RBTree[K, OrderedKV[K, V]] -} - -func (m *SortedMap[K, V]) init() { - if m.inner.KeyFn == nil { - m.inner.KeyFn = m.keyFn - } +func (a orderedKV[K, V]) Compare(b orderedKV[K, V]) int { + return a.K.Compare(b.K) } -func (m *SortedMap[K, V]) keyFn(kv OrderedKV[K, V]) K { - return kv.K +type SortedMap[K Ordered[K], V any] struct { + inner RBTree[orderedKV[K, V]] } func (m *SortedMap[K, V]) Delete(key K) { - m.init() - m.inner.Delete(key) + m.inner.Delete(m.inner.Search(func(kv orderedKV[K, V]) int { + return key.Compare(kv.K) + })) } func (m *SortedMap[K, V]) Load(key K) (value V, ok bool) { - m.init() - node := m.inner.Lookup(key) + node := m.inner.Search(func(kv orderedKV[K, V]) int { + return key.Compare(kv.K) + }) if node == nil { var zero V return zero, false @@ -42,41 +34,27 @@ func (m *SortedMap[K, V]) Load(key K) (value V, ok bool) { return node.Value.V, true } -var errStop = errors.New("stop") - -func (m *SortedMap[K, V]) Range(f func(key K, value V) bool) { - m.init() - _ = m.inner.Walk(func(node *RBNode[OrderedKV[K, V]]) error { - if f(node.Value.K, node.Value.V) { - return nil - } else { - return errStop - } +func (m *SortedMap[K, V]) Store(key K, value V) { + m.inner.Insert(orderedKV[K, V]{ + K: key, + V: value, }) } -func (m *SortedMap[K, V]) Subrange(rangeFn func(K, V) int, handleFn func(K, V) bool) { - m.init() - kvs := m.inner.SearchRange(func(kv OrderedKV[K, V]) int { - return rangeFn(kv.K, kv.V) +func (m *SortedMap[K, V]) Range(fn func(key K, value V) bool) { + m.inner.Range(func(node *RBNode[orderedKV[K, V]]) bool { + return fn(node.Value.K, node.Value.V) }) - for _, kv := range kvs { - if !handleFn(kv.K, kv.V) { - break - } - } } -func (m *SortedMap[K, V]) Store(key K, value V) { - m.init() - m.inner.Insert(OrderedKV[K, V]{ - K: key, - V: value, - }) +func (m *SortedMap[K, V]) Subrange(rangeFn func(K, V) int, handleFn func(K, V) bool) { + m.inner.Subrange( + func(kv orderedKV[K, V]) int { return rangeFn(kv.K, kv.V) }, + func(node *RBNode[orderedKV[K, V]]) bool { return handleFn(node.Value.K, node.Value.V) }) } func (m *SortedMap[K, V]) Search(fn func(K, V) int) (K, V, bool) { - node := m.inner.Search(func(kv OrderedKV[K, V]) int { + node := m.inner.Search(func(kv orderedKV[K, V]) int { return fn(kv.K, kv.V) }) if node == nil { @@ -87,12 +65,6 @@ func (m *SortedMap[K, V]) Search(fn func(K, V) int) (K, V, bool) { return node.Value.K, node.Value.V, true } -func (m *SortedMap[K, V]) SearchAll(fn func(K, V) int) []OrderedKV[K, V] { - return m.inner.SearchRange(func(kv OrderedKV[K, V]) int { - return fn(kv.K, kv.V) - }) -} - func (m *SortedMap[K, V]) Len() int { return m.inner.Len() } -- cgit v1.2.3-54-g00ecf