summaryrefslogtreecommitdiff
path: root/lib/containers
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2023-02-05 00:31:29 -0700
committerLuke Shumaker <lukeshu@lukeshu.com>2023-02-12 02:43:16 -0700
commit696a7d192e5eefa53230168a4b200ec0560c8a10 (patch)
tree039c3c549414c21b15d58c3d695ee87c3feb1402 /lib/containers
parentb608e4cf9c9e6e5bf5a333e8d78b2800ffcb0c91 (diff)
containers: Rethink the RBTree interface to be simpler
Diffstat (limited to 'lib/containers')
-rw-r--r--lib/containers/intervaltree.go132
-rw-r--r--lib/containers/intervaltree_test.go28
-rw-r--r--lib/containers/rbtree.go161
-rw-r--r--lib/containers/rbtree_test.go66
-rw-r--r--lib/containers/sortedmap.go76
5 files changed, 209 insertions, 254 deletions
diff --git a/lib/containers/intervaltree.go b/lib/containers/intervaltree.go
index b7ff866..7b96526 100644
--- a/lib/containers/intervaltree.go
+++ b/lib/containers/intervaltree.go
@@ -4,81 +4,80 @@
package containers
-type intervalKey[K Ordered[K]] struct {
+type interval[K Ordered[K]] struct {
Min, Max K
}
-func (ival intervalKey[K]) ContainsFn(fn func(K) int) bool {
- return fn(ival.Min) >= 0 && fn(ival.Max) <= 0
-}
-
-func (a intervalKey[K]) Compare(b intervalKey[K]) int {
+// Compare implements Ordered.
+func (a interval[K]) Compare(b interval[K]) int {
if d := a.Min.Compare(b.Min); d != 0 {
return d
}
return a.Max.Compare(b.Max)
}
+// ContainsFn returns whether this interval contains the range matched
+// by the given function.
+func (ival interval[K]) ContainsFn(fn func(K) int) bool {
+ return fn(ival.Min) >= 0 && fn(ival.Max) <= 0
+}
+
type intervalValue[K Ordered[K], V any] struct {
- Val V
- SpanOfChildren intervalKey[K]
+ Val V
+ ValSpan interval[K]
+ ChildSpan interval[K]
+}
+
+// Compare implements Ordered.
+func (a intervalValue[K, V]) Compare(b intervalValue[K, V]) int {
+ return a.ValSpan.Compare(b.ValSpan)
}
type IntervalTree[K Ordered[K], V any] struct {
MinFn func(V) K
MaxFn func(V) K
- inner RBTree[intervalKey[K], intervalValue[K, V]]
-}
-
-func (t *IntervalTree[K, V]) keyFn(v intervalValue[K, V]) intervalKey[K] {
- return intervalKey[K]{
- Min: t.MinFn(v.Val),
- Max: t.MaxFn(v.Val),
- }
+ inner RBTree[intervalValue[K, V]]
}
func (t *IntervalTree[K, V]) attrFn(node *RBNode[intervalValue[K, V]]) {
- max := t.MaxFn(node.Value.Val)
- if node.Left != nil && node.Left.Value.SpanOfChildren.Max.Compare(max) > 0 {
- max = node.Left.Value.SpanOfChildren.Max
+ max := node.Value.ValSpan.Max
+ if node.Left != nil && node.Left.Value.ChildSpan.Max.Compare(max) > 0 {
+ max = node.Left.Value.ChildSpan.Max
}
- if node.Right != nil && node.Right.Value.SpanOfChildren.Max.Compare(max) > 0 {
- max = node.Right.Value.SpanOfChildren.Max
+ if node.Right != nil && node.Right.Value.ChildSpan.Max.Compare(max) > 0 {
+ max = node.Right.Value.ChildSpan.Max
}
- node.Value.SpanOfChildren.Max = max
+ node.Value.ChildSpan.Max = max
- min := t.MinFn(node.Value.Val)
- if node.Left != nil && node.Left.Value.SpanOfChildren.Min.Compare(min) < 0 {
- min = node.Left.Value.SpanOfChildren.Min
+ min := node.Value.ValSpan.Min
+ if node.Left != nil && node.Left.Value.ChildSpan.Min.Compare(min) < 0 {
+ min = node.Left.Value.ChildSpan.Min
}
- if node.Right != nil && node.Right.Value.SpanOfChildren.Min.Compare(min) < 0 {
- min = node.Right.Value.SpanOfChildren.Min
+ if node.Right != nil && node.Right.Value.ChildSpan.Min.Compare(min) < 0 {
+ min = node.Right.Value.ChildSpan.Min
}
- node.Value.SpanOfChildren.Min = min
+ node.Value.ChildSpan.Min = min
}
func (t *IntervalTree[K, V]) init() {
- if t.inner.KeyFn == nil {
- t.inner.KeyFn = t.keyFn
+ if t.inner.AttrFn == nil {
t.inner.AttrFn = t.attrFn
}
}
-func (t *IntervalTree[K, V]) Delete(min, max K) {
- t.init()
- t.inner.Delete(intervalKey[K]{
- Min: min,
- Max: max,
- })
-}
-
func (t *IntervalTree[K, V]) Equal(u *IntervalTree[K, V]) bool {
return t.inner.Equal(&u.inner)
}
func (t *IntervalTree[K, V]) Insert(val V) {
t.init()
- t.inner.Insert(intervalValue[K, V]{Val: val})
+ t.inner.Insert(intervalValue[K, V]{
+ Val: val,
+ ValSpan: interval[K]{
+ Min: t.MinFn(val),
+ Max: t.MaxFn(val),
+ },
+ })
}
func (t *IntervalTree[K, V]) Min() (K, bool) {
@@ -86,7 +85,7 @@ func (t *IntervalTree[K, V]) Min() (K, bool) {
var zero K
return zero, false
}
- return t.inner.root.Value.SpanOfChildren.Min, true
+ return t.inner.root.Value.ChildSpan.Min, true
}
func (t *IntervalTree[K, V]) Max() (K, bool) {
@@ -94,22 +93,18 @@ func (t *IntervalTree[K, V]) Max() (K, bool) {
var zero K
return zero, false
}
- return t.inner.root.Value.SpanOfChildren.Max, true
-}
-
-func (t *IntervalTree[K, V]) Lookup(k K) (V, bool) {
- return t.Search(k.Compare)
+ return t.inner.root.Value.ChildSpan.Max, true
}
func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) {
node := t.inner.root
for node != nil {
switch {
- case t.keyFn(node.Value).ContainsFn(fn):
+ case node.Value.ValSpan.ContainsFn(fn):
return node.Value.Val, true
- case node.Left != nil && node.Left.Value.SpanOfChildren.ContainsFn(fn):
+ case node.Left != nil && node.Left.Value.ChildSpan.ContainsFn(fn):
node = node.Left
- case node.Right != nil && node.Right.Value.SpanOfChildren.ContainsFn(fn):
+ case node.Right != nil && node.Right.Value.ChildSpan.ContainsFn(fn):
node = node.Right
default:
node = nil
@@ -119,24 +114,33 @@ func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) {
return zero, false
}
-func (t *IntervalTree[K, V]) searchAll(fn func(K) int, node *RBNode[intervalValue[K, V]], ret *[]V) {
+func (t *IntervalTree[K, V]) Range(fn func(V) bool) {
+ t.inner.Range(func(node *RBNode[intervalValue[K, V]]) bool {
+ return fn(node.Value.Val)
+ })
+}
+
+func (t *IntervalTree[K, V]) Subrange(rangeFn func(K) int, handleFn func(V) bool) {
+ t.subrange(t.inner.root, rangeFn, handleFn)
+}
+
+func (t *IntervalTree[K, V]) subrange(node *RBNode[intervalValue[K, V]], rangeFn func(K) int, handleFn func(V) bool) bool {
if node == nil {
- return
+ return true
}
- if !node.Value.SpanOfChildren.ContainsFn(fn) {
- return
+ if !node.Value.ChildSpan.ContainsFn(rangeFn) {
+ return true
}
- t.searchAll(fn, node.Left, ret)
- if t.keyFn(node.Value).ContainsFn(fn) {
- *ret = append(*ret, node.Value.Val)
+ if !t.subrange(node.Left, rangeFn, handleFn) {
+ return false
}
- t.searchAll(fn, node.Right, ret)
-}
-
-func (t *IntervalTree[K, V]) SearchAll(fn func(K) int) []V {
- var ret []V
- t.searchAll(fn, t.inner.root, &ret)
- return ret
+ if node.Value.ValSpan.ContainsFn(rangeFn) {
+ if !handleFn(node.Value.Val) {
+ return false
+ }
+ }
+ if !t.subrange(node.Right, rangeFn, handleFn) {
+ return false
+ }
+ return true
}
-
-// TODO: func (t *IntervalTree[K, V]) Walk(fn func(*RBNode[V]) error) error
diff --git a/lib/containers/intervaltree_test.go b/lib/containers/intervaltree_test.go
index 45743a2..30885bd 100644
--- a/lib/containers/intervaltree_test.go
+++ b/lib/containers/intervaltree_test.go
@@ -18,8 +18,8 @@ func (t *IntervalTree[K, V]) ASCIIArt() string {
func (v intervalValue[K, V]) String() string {
return fmt.Sprintf("%v) ([%v,%v]",
v.Val,
- v.SpanOfChildren.Min,
- v.SpanOfChildren.Max)
+ v.ChildSpan.Min,
+ v.ChildSpan.Max)
}
func (v NativeOrdered[T]) String() string {
@@ -60,15 +60,21 @@ func TestIntervalTree(t *testing.T) {
t.Log("\n" + tree.ASCIIArt())
// find intervals that touch [9,20]
- intervals := tree.SearchAll(func(k NativeOrdered[int]) int {
- if k.Val < 9 {
- return 1
- }
- if k.Val > 20 {
- return -1
- }
- return 0
- })
+ var intervals []SimpleInterval
+ tree.Subrange(
+ func(k NativeOrdered[int]) int {
+ if k.Val < 9 {
+ return 1
+ }
+ if k.Val > 20 {
+ return -1
+ }
+ return 0
+ },
+ func(v SimpleInterval) bool {
+ intervals = append(intervals, v)
+ return true
+ })
assert.Equal(t,
[]SimpleInterval{
{6, 10},
diff --git a/lib/containers/rbtree.go b/lib/containers/rbtree.go
index 1fdb799..4125847 100644
--- a/lib/containers/rbtree.go
+++ b/lib/containers/rbtree.go
@@ -7,8 +7,6 @@ package containers
import (
"fmt"
"reflect"
-
- "git.lukeshu.com/btrfs-progs-ng/lib/slices"
)
type Color bool
@@ -18,50 +16,49 @@ const (
Red = Color(true)
)
-type RBNode[V any] struct {
- Parent, Left, Right *RBNode[V]
+type RBNode[T Ordered[T]] struct {
+ Parent, Left, Right *RBNode[T]
Color Color
- Value V
+ Value T
}
-func (node *RBNode[V]) getColor() Color {
+func (node *RBNode[T]) getColor() Color {
if node == nil {
return Black
}
return node.Color
}
-type RBTree[K Ordered[K], V any] struct {
- KeyFn func(V) K
- AttrFn func(*RBNode[V])
- root *RBNode[V]
+type RBTree[T Ordered[T]] struct {
+ AttrFn func(*RBNode[T])
+ root *RBNode[T]
len int
}
-func (t *RBTree[K, V]) Len() int {
+func (t *RBTree[T]) Len() int {
return t.len
}
-func (t *RBTree[K, V]) Walk(fn func(*RBNode[V]) error) error {
- return t.root.walk(fn)
+func (t *RBTree[T]) Range(fn func(*RBNode[T]) bool) {
+ t.root._range(fn)
}
-func (node *RBNode[V]) walk(fn func(*RBNode[V]) error) error {
+func (node *RBNode[T]) _range(fn func(*RBNode[T]) bool) bool {
if node == nil {
- return nil
+ return true
}
- if err := node.Left.walk(fn); err != nil {
- return err
+ if !node.Left._range(fn) {
+ return false
}
- if err := fn(node); err != nil {
- return err
+ if !fn(node) {
+ return false
}
- if err := node.Right.walk(fn); err != nil {
- return err
+ if !node.Right._range(fn) {
+ return false
}
- return nil
+ return true
}
// Search the tree for a value that satisfied the given callbackk
@@ -81,16 +78,13 @@ func (node *RBNode[V]) walk(fn func(*RBNode[V]) error) error {
// +---+ +---+
//
// Returns nil if no such value is found.
-//
-// Search is good for advanced lookup, like when a range of values is
-// acceptable. For simple exact-value lookup, use Lookup.
-func (t *RBTree[K, V]) Search(fn func(V) int) *RBNode[V] {
+func (t *RBTree[T]) Search(fn func(T) int) *RBNode[T] {
ret, _ := t.root.search(fn)
return ret
}
-func (node *RBNode[V]) search(fn func(V) int) (exact, nearest *RBNode[V]) {
- var prev *RBNode[V]
+func (node *RBNode[T]) search(fn func(T) int) (exact, nearest *RBNode[T]) {
+ var prev *RBNode[T]
for {
if node == nil {
return nil, prev
@@ -108,26 +102,13 @@ func (node *RBNode[V]) search(fn func(V) int) (exact, nearest *RBNode[V]) {
}
}
-func (t *RBTree[K, V]) exactKey(key K) func(V) int {
- return func(val V) int {
- valKey := t.KeyFn(val)
- return key.Compare(valKey)
- }
-}
-
-// Lookup looks up the value for an exact key. If no such value
-// exists, nil is returned.
-func (t *RBTree[K, V]) Lookup(key K) *RBNode[V] {
- return t.Search(t.exactKey(key))
-}
-
// Min returns the minimum value stored in the tree, or nil if the
// tree is empty.
-func (t *RBTree[K, V]) Min() *RBNode[V] {
+func (t *RBTree[T]) Min() *RBNode[T] {
return t.root.min()
}
-func (node *RBNode[V]) min() *RBNode[V] {
+func (node *RBNode[T]) min() *RBNode[T] {
if node == nil {
return nil
}
@@ -141,11 +122,11 @@ func (node *RBNode[V]) min() *RBNode[V] {
// Max returns the maximum value stored in the tree, or nil if the
// tree is empty.
-func (t *RBTree[K, V]) Max() *RBNode[V] {
+func (t *RBTree[T]) Max() *RBNode[T] {
return t.root.max()
}
-func (node *RBNode[V]) max() *RBNode[V] {
+func (node *RBNode[T]) max() *RBNode[T] {
if node == nil {
return nil
}
@@ -157,11 +138,7 @@ func (node *RBNode[V]) max() *RBNode[V] {
}
}
-func (t *RBTree[K, V]) Next(cur *RBNode[V]) *RBNode[V] {
- return cur.next()
-}
-
-func (cur *RBNode[V]) next() *RBNode[V] {
+func (cur *RBNode[T]) Next() *RBNode[T] {
if cur.Right != nil {
return cur.Right.min()
}
@@ -172,11 +149,7 @@ func (cur *RBNode[V]) next() *RBNode[V] {
return parent
}
-func (t *RBTree[K, V]) Prev(cur *RBNode[V]) *RBNode[V] {
- return cur.prev()
-}
-
-func (cur *RBNode[V]) prev() *RBNode[V] {
+func (cur *RBNode[T]) Prev() *RBNode[T] {
if cur.Left != nil {
return cur.Left.max()
}
@@ -187,48 +160,56 @@ func (cur *RBNode[V]) prev() *RBNode[V] {
return parent
}
-// SearchRange is like Search, but returns all nodes that match the
-// function; assuming that they are contiguous.
-func (t *RBTree[K, V]) SearchRange(fn func(V) int) []V {
- middle := t.Search(fn)
- if middle == nil {
- return nil
- }
- ret := []V{middle.Value}
- for node := t.Prev(middle); node != nil && fn(node.Value) == 0; node = t.Prev(node) {
- ret = append(ret, node.Value)
+// Subrange is like Search, but for when there may be more than one
+// result.
+func (t *RBTree[T]) Subrange(rangeFn func(T) int, handleFn func(*RBNode[T]) bool) {
+ // Find the left-most acceptable node.
+ _, node := t.root.search(func(v T) int {
+ if rangeFn(v) <= 0 {
+ return -1
+ } else {
+ return 1
+ }
+ })
+ for node != nil && rangeFn(node.Value) > 0 {
+ node = node.Next()
}
- slices.Reverse(ret)
- for node := t.Next(middle); node != nil && fn(node.Value) == 0; node = t.Next(node) {
- ret = append(ret, node.Value)
+ // Now walk forward until we hit the end.
+ for node != nil && rangeFn(node.Value) == 0 {
+ if keepGoing := handleFn(node); !keepGoing {
+ return
+ }
+ node = node.Next()
}
- return ret
}
-func (t *RBTree[K, V]) Equal(u *RBTree[K, V]) bool {
+func (t *RBTree[T]) Equal(u *RBTree[T]) bool {
if (t == nil) != (u == nil) {
return false
}
if t == nil {
return true
}
+ if t.len != u.len {
+ return false
+ }
- var tSlice []V
- _ = t.Walk(func(node *RBNode[V]) error {
+ tSlice := make([]T, 0, t.len)
+ t.Range(func(node *RBNode[T]) bool {
tSlice = append(tSlice, node.Value)
- return nil
+ return true
})
- var uSlice []V
- _ = u.Walk(func(node *RBNode[V]) error {
+ uSlice := make([]T, 0, u.len)
+ u.Range(func(node *RBNode[T]) bool {
uSlice = append(uSlice, node.Value)
- return nil
+ return true
})
return reflect.DeepEqual(tSlice, uSlice)
}
-func (t *RBTree[K, V]) parentChild(node *RBNode[V]) **RBNode[V] {
+func (t *RBTree[T]) parentChild(node *RBNode[T]) **RBNode[T] {
switch {
case node.Parent == nil:
return &t.root
@@ -241,7 +222,7 @@ func (t *RBTree[K, V]) parentChild(node *RBNode[V]) **RBNode[V] {
}
}
-func (t *RBTree[K, V]) updateAttr(node *RBNode[V]) {
+func (t *RBTree[T]) updateAttr(node *RBNode[T]) {
if t.AttrFn == nil {
return
}
@@ -251,7 +232,7 @@ func (t *RBTree[K, V]) updateAttr(node *RBNode[V]) {
}
}
-func (t *RBTree[K, V]) leftRotate(x *RBNode[V]) {
+func (t *RBTree[T]) leftRotate(x *RBNode[T]) {
// p p
// | |
// +---+ +---+
@@ -286,7 +267,7 @@ func (t *RBTree[K, V]) leftRotate(x *RBNode[V]) {
t.updateAttr(x)
}
-func (t *RBTree[K, V]) rightRotate(y *RBNode[V]) {
+func (t *RBTree[T]) rightRotate(y *RBNode[T]) {
//nolint:dupword
//
// | |
@@ -322,18 +303,17 @@ func (t *RBTree[K, V]) rightRotate(y *RBNode[V]) {
t.updateAttr(y)
}
-func (t *RBTree[K, V]) Insert(val V) {
+func (t *RBTree[T]) Insert(val T) {
// Naive-insert
- key := t.KeyFn(val)
- exact, parent := t.root.search(t.exactKey(key))
+ exact, parent := t.root.search(val.Compare)
if exact != nil {
exact.Value = val
return
}
t.len++
- node := &RBNode[V]{
+ node := &RBNode[T]{
Color: Red,
Parent: parent,
Value: val,
@@ -341,7 +321,7 @@ func (t *RBTree[K, V]) Insert(val V) {
switch {
case parent == nil:
t.root = node
- case key.Compare(t.KeyFn(parent.Value)) < 0:
+ case val.Compare(parent.Value) < 0:
parent.Left = node
default:
parent.Right = node
@@ -391,15 +371,14 @@ func (t *RBTree[K, V]) Insert(val V) {
t.root.Color = Black
}
-func (t *RBTree[K, V]) transplant(oldNode, newNode *RBNode[V]) {
+func (t *RBTree[T]) transplant(oldNode, newNode *RBNode[T]) {
*t.parentChild(oldNode) = newNode
if newNode != nil {
newNode.Parent = oldNode.Parent
}
}
-func (t *RBTree[K, V]) Delete(key K) {
- nodeToDelete := t.Lookup(key)
+func (t *RBTree[T]) Delete(nodeToDelete *RBNode[T]) {
if nodeToDelete == nil {
return
}
@@ -410,8 +389,8 @@ func (t *RBTree[K, V]) Delete(key K) {
// phase 1
- var nodeToRebalance *RBNode[V]
- var nodeToRebalanceParent *RBNode[V] // in case 'nodeToRebalance' is nil, which it can be
+ var nodeToRebalance *RBNode[T]
+ var nodeToRebalanceParent *RBNode[T] // in case 'nodeToRebalance' is nil, which it can be
needsRebalance := nodeToDelete.Color == Black
switch {
@@ -427,7 +406,7 @@ func (t *RBTree[K, V]) Delete(key K) {
// The node being deleted has a child on both sides,
// so we've go to reshuffle the parents a bit to make
// room for those children.
- next := nodeToDelete.next()
+ next := nodeToDelete.Next()
if next.Parent == nodeToDelete {
// p p
// | |
diff --git a/lib/containers/rbtree_test.go b/lib/containers/rbtree_test.go
index e42410e..d2fe931 100644
--- a/lib/containers/rbtree_test.go
+++ b/lib/containers/rbtree_test.go
@@ -7,21 +7,22 @@ package containers
import (
"fmt"
"io"
- "sort"
"strings"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/exp/constraints"
+
+ "git.lukeshu.com/btrfs-progs-ng/lib/slices"
)
-func (t *RBTree[K, V]) ASCIIArt() string {
+func (t *RBTree[T]) ASCIIArt() string {
var out strings.Builder
t.root.asciiArt(&out, "", "", "")
return out.String()
}
-func (node *RBNode[V]) String() string {
+func (node *RBNode[T]) String() string {
switch {
case node == nil:
return "nil"
@@ -32,7 +33,7 @@ func (node *RBNode[V]) String() string {
}
}
-func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) {
+func (node *RBNode[T]) asciiArt(w io.Writer, u, m, l string) {
if node == nil {
fmt.Fprintf(w, "%snil\n", m)
return
@@ -43,7 +44,7 @@ func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) {
node.Left.asciiArt(w, l+" | ", l+" `--", l+" ")
}
-func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], tree *RBTree[NativeOrdered[K], V]) {
+func checkRBTree[T constraints.Ordered](t *testing.T, expectedSet Set[T], tree *RBTree[NativeOrdered[T]]) {
// 1. Every node is either red or black
// 2. The root is black.
@@ -52,19 +53,19 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K],
// 3. Every nil is black.
// 4. If a node is red, then both its children are black.
- require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
+ tree.Range(func(node *RBNode[NativeOrdered[T]]) bool {
if node.getColor() == Red {
require.Equal(t, Black, node.Left.getColor())
require.Equal(t, Black, node.Right.getColor())
}
- return nil
- }))
+ return true
+ })
// 5. For each node, all simple paths from the node to
// descendent leaves contain the same number of black
// nodes.
- var walkCnt func(node *RBNode[V], cnt int, leafFn func(int))
- walkCnt = func(node *RBNode[V], cnt int, leafFn func(int)) {
+ var walkCnt func(node *RBNode[NativeOrdered[T]], cnt int, leafFn func(int))
+ walkCnt = func(node *RBNode[NativeOrdered[T]], cnt int, leafFn func(int)) {
if node.getColor() == Black {
cnt++
}
@@ -75,7 +76,7 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K],
walkCnt(node.Left, cnt, leafFn)
walkCnt(node.Right, cnt, leafFn)
}
- require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
+ tree.Range(func(node *RBNode[NativeOrdered[T]]) bool {
var cnts []int
walkCnt(node, 0, func(cnt int) {
cnts = append(cnts, cnt)
@@ -86,27 +87,24 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K],
break
}
}
- return nil
- }))
+ return true
+ })
// expected contents
- expectedOrder := make([]K, 0, len(expectedSet))
- for k := range expectedSet {
- expectedOrder = append(expectedOrder, k)
- node := tree.Lookup(NativeOrdered[K]{Val: k})
+ expectedOrder := make([]T, 0, len(expectedSet))
+ for v := range expectedSet {
+ expectedOrder = append(expectedOrder, v)
+ node := tree.Search(NativeOrdered[T]{Val: v}.Compare)
require.NotNil(t, tree, node)
- require.Equal(t, k, tree.KeyFn(node.Value).Val)
}
- sort.Slice(expectedOrder, func(i, j int) bool {
- return expectedOrder[i] < expectedOrder[j]
+ slices.Sort(expectedOrder)
+ actOrder := make([]T, 0, len(expectedSet))
+ tree.Range(func(node *RBNode[NativeOrdered[T]]) bool {
+ actOrder = append(actOrder, node.Value.Val)
+ return true
})
- actOrder := make([]K, 0, len(expectedSet))
- require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
- actOrder = append(actOrder, tree.KeyFn(node.Value).Val)
- return nil
- }))
require.Equal(t, expectedOrder, actOrder)
- require.Equal(t, tree.Len(), len(expectedSet))
+ require.Equal(t, len(expectedSet), tree.Len())
}
func FuzzRBTree(f *testing.F) {
@@ -132,11 +130,7 @@ func FuzzRBTree(f *testing.F) {
})
f.Fuzz(func(t *testing.T, dat []uint8) {
- tree := &RBTree[NativeOrdered[uint8], uint8]{
- KeyFn: func(x uint8) NativeOrdered[uint8] {
- return NativeOrdered[uint8]{Val: x}
- },
- }
+ tree := new(RBTree[NativeOrdered[uint8]])
set := make(Set[uint8])
checkRBTree(t, set, tree)
t.Logf("\n%s\n", tree.ASCIIArt())
@@ -145,18 +139,18 @@ func FuzzRBTree(f *testing.F) {
val := (b & 0b0011_1111)
if ins {
t.Logf("Insert(%v)", val)
- tree.Insert(val)
+ tree.Insert(NativeOrdered[uint8]{Val: val})
set.Insert(val)
t.Logf("\n%s\n", tree.ASCIIArt())
- node := tree.Lookup(NativeOrdered[uint8]{Val: val})
+ node := tree.Search(NativeOrdered[uint8]{Val: val}.Compare)
require.NotNil(t, node)
- require.Equal(t, val, node.Value)
+ require.Equal(t, val, node.Value.Val)
} else {
t.Logf("Delete(%v)", val)
- tree.Delete(NativeOrdered[uint8]{Val: val})
+ tree.Delete(tree.Search(NativeOrdered[uint8]{Val: val}.Compare))
delete(set, val)
t.Logf("\n%s\n", tree.ASCIIArt())
- require.Nil(t, tree.Lookup(NativeOrdered[uint8]{Val: val}))
+ require.Nil(t, tree.Search(NativeOrdered[uint8]{Val: val}.Compare))
}
checkRBTree(t, set, tree)
}
diff --git a/lib/containers/sortedmap.go b/lib/containers/sortedmap.go
index 52308c9..d104274 100644
--- a/lib/containers/sortedmap.go
+++ b/lib/containers/sortedmap.go
@@ -1,40 +1,32 @@
-// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later
package containers
-import (
- "errors"
-)
-
-type OrderedKV[K Ordered[K], V any] struct {
+type orderedKV[K Ordered[K], V any] struct {
K K
V V
}
-type SortedMap[K Ordered[K], V any] struct {
- inner RBTree[K, OrderedKV[K, V]]
-}
-
-func (m *SortedMap[K, V]) init() {
- if m.inner.KeyFn == nil {
- m.inner.KeyFn = m.keyFn
- }
+func (a orderedKV[K, V]) Compare(b orderedKV[K, V]) int {
+ return a.K.Compare(b.K)
}
-func (m *SortedMap[K, V]) keyFn(kv OrderedKV[K, V]) K {
- return kv.K
+type SortedMap[K Ordered[K], V any] struct {
+ inner RBTree[orderedKV[K, V]]
}
func (m *SortedMap[K, V]) Delete(key K) {
- m.init()
- m.inner.Delete(key)
+ m.inner.Delete(m.inner.Search(func(kv orderedKV[K, V]) int {
+ return key.Compare(kv.K)
+ }))
}
func (m *SortedMap[K, V]) Load(key K) (value V, ok bool) {
- m.init()
- node := m.inner.Lookup(key)
+ node := m.inner.Search(func(kv orderedKV[K, V]) int {
+ return key.Compare(kv.K)
+ })
if node == nil {
var zero V
return zero, false
@@ -42,41 +34,27 @@ func (m *SortedMap[K, V]) Load(key K) (value V, ok bool) {
return node.Value.V, true
}
-var errStop = errors.New("stop")
-
-func (m *SortedMap[K, V]) Range(f func(key K, value V) bool) {
- m.init()
- _ = m.inner.Walk(func(node *RBNode[OrderedKV[K, V]]) error {
- if f(node.Value.K, node.Value.V) {
- return nil
- } else {
- return errStop
- }
+func (m *SortedMap[K, V]) Store(key K, value V) {
+ m.inner.Insert(orderedKV[K, V]{
+ K: key,
+ V: value,
})
}
-func (m *SortedMap[K, V]) Subrange(rangeFn func(K, V) int, handleFn func(K, V) bool) {
- m.init()
- kvs := m.inner.SearchRange(func(kv OrderedKV[K, V]) int {
- return rangeFn(kv.K, kv.V)
+func (m *SortedMap[K, V]) Range(fn func(key K, value V) bool) {
+ m.inner.Range(func(node *RBNode[orderedKV[K, V]]) bool {
+ return fn(node.Value.K, node.Value.V)
})
- for _, kv := range kvs {
- if !handleFn(kv.K, kv.V) {
- break
- }
- }
}
-func (m *SortedMap[K, V]) Store(key K, value V) {
- m.init()
- m.inner.Insert(OrderedKV[K, V]{
- K: key,
- V: value,
- })
+func (m *SortedMap[K, V]) Subrange(rangeFn func(K, V) int, handleFn func(K, V) bool) {
+ m.inner.Subrange(
+ func(kv orderedKV[K, V]) int { return rangeFn(kv.K, kv.V) },
+ func(node *RBNode[orderedKV[K, V]]) bool { return handleFn(node.Value.K, node.Value.V) })
}
func (m *SortedMap[K, V]) Search(fn func(K, V) int) (K, V, bool) {
- node := m.inner.Search(func(kv OrderedKV[K, V]) int {
+ node := m.inner.Search(func(kv orderedKV[K, V]) int {
return fn(kv.K, kv.V)
})
if node == nil {
@@ -87,12 +65,6 @@ func (m *SortedMap[K, V]) Search(fn func(K, V) int) (K, V, bool) {
return node.Value.K, node.Value.V, true
}
-func (m *SortedMap[K, V]) SearchAll(fn func(K, V) int) []OrderedKV[K, V] {
- return m.inner.SearchRange(func(kv OrderedKV[K, V]) int {
- return fn(kv.K, kv.V)
- })
-}
-
func (m *SortedMap[K, V]) Len() int {
return m.inner.Len()
}