summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/btrfs/btrfsvol/chunk.go5
-rw-r--r--lib/btrfs/btrfsvol/devext.go5
-rw-r--r--lib/btrfs/btrfsvol/lvm.go67
-rw-r--r--lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go14
-rw-r--r--lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go56
-rw-r--r--lib/btrfsprogs/btrfsinspect/scandevices.go6
-rw-r--r--lib/btrfsprogs/btrfsutil/broken_btree.go53
-rw-r--r--lib/containers/intervaltree.go132
-rw-r--r--lib/containers/intervaltree_test.go28
-rw-r--r--lib/containers/rbtree.go161
-rw-r--r--lib/containers/rbtree_test.go66
-rw-r--r--lib/containers/sortedmap.go76
12 files changed, 320 insertions, 349 deletions
diff --git a/lib/btrfs/btrfsvol/chunk.go b/lib/btrfs/btrfsvol/chunk.go
index 08a0b2b..a112fd3 100644
--- a/lib/btrfs/btrfsvol/chunk.go
+++ b/lib/btrfs/btrfsvol/chunk.go
@@ -22,7 +22,10 @@ type chunkMapping struct {
Flags containers.Optional[BlockGroupFlags]
}
-type ChunkMapping = chunkMapping
+// Compare implements containers.Ordered.
+func (a chunkMapping) Compare(b chunkMapping) int {
+ return containers.NativeCompare(a.LAddr, b.LAddr)
+}
// return -1 if 'a' is wholly to the left of 'b'
// return 0 if there is some overlap between 'a' and 'b'
diff --git a/lib/btrfs/btrfsvol/devext.go b/lib/btrfs/btrfsvol/devext.go
index 037021c..3324476 100644
--- a/lib/btrfs/btrfsvol/devext.go
+++ b/lib/btrfs/btrfsvol/devext.go
@@ -20,6 +20,11 @@ type devextMapping struct {
Flags containers.Optional[BlockGroupFlags]
}
+// Compare implements containers.Ordered.
+func (a devextMapping) Compare(b devextMapping) int {
+ return containers.NativeCompare(a.PAddr, b.PAddr)
+}
+
// return -1 if 'a' is wholly to the left of 'b'
// return 0 if there is some overlap between 'a' and 'b'
// return 1 if 'a is wholly to the right of 'b'
diff --git a/lib/btrfs/btrfsvol/lvm.go b/lib/btrfs/btrfsvol/lvm.go
index 59ca609..93ec438 100644
--- a/lib/btrfs/btrfsvol/lvm.go
+++ b/lib/btrfs/btrfsvol/lvm.go
@@ -22,8 +22,8 @@ type LogicalVolume[PhysicalVolume diskio.File[PhysicalAddr]] struct {
id2pv map[DeviceID]PhysicalVolume
- logical2physical *containers.RBTree[containers.NativeOrdered[LogicalAddr], chunkMapping]
- physical2logical map[DeviceID]*containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]
+ logical2physical *containers.RBTree[chunkMapping]
+ physical2logical map[DeviceID]*containers.RBTree[devextMapping]
}
var _ diskio.File[LogicalAddr] = (*LogicalVolume[diskio.File[PhysicalAddr]])(nil)
@@ -33,22 +33,14 @@ func (lv *LogicalVolume[PhysicalVolume]) init() {
lv.id2pv = make(map[DeviceID]PhysicalVolume)
}
if lv.logical2physical == nil {
- lv.logical2physical = &containers.RBTree[containers.NativeOrdered[LogicalAddr], chunkMapping]{
- KeyFn: func(chunk chunkMapping) containers.NativeOrdered[LogicalAddr] {
- return containers.NativeOrdered[LogicalAddr]{Val: chunk.LAddr}
- },
- }
+ lv.logical2physical = new(containers.RBTree[chunkMapping])
}
if lv.physical2logical == nil {
- lv.physical2logical = make(map[DeviceID]*containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping], len(lv.id2pv))
+ lv.physical2logical = make(map[DeviceID]*containers.RBTree[devextMapping], len(lv.id2pv))
}
for devid := range lv.id2pv {
if _, ok := lv.physical2logical[devid]; !ok {
- lv.physical2logical[devid] = &containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]{
- KeyFn: func(ext devextMapping) containers.NativeOrdered[PhysicalAddr] {
- return containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr}
- },
- }
+ lv.physical2logical[devid] = new(containers.RBTree[devextMapping])
}
}
}
@@ -90,11 +82,7 @@ func (lv *LogicalVolume[PhysicalVolume]) AddPhysicalVolume(id DeviceID, dev Phys
lv, dev.Name(), other.Name(), id)
}
lv.id2pv[id] = dev
- lv.physical2logical[id] = &containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]{
- KeyFn: func(ext devextMapping) containers.NativeOrdered[PhysicalAddr] {
- return containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr}
- },
- }
+ lv.physical2logical[id] = new(containers.RBTree[devextMapping])
return nil
}
@@ -143,11 +131,13 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro
SizeLocked: m.SizeLocked,
Flags: m.Flags,
}
- logicalOverlaps := lv.logical2physical.SearchRange(newChunk.compareRange)
+ var logicalOverlaps []chunkMapping
numOverlappingStripes := 0
- for _, chunk := range logicalOverlaps {
- numOverlappingStripes += len(chunk.PAddrs)
- }
+ lv.logical2physical.Subrange(newChunk.compareRange, func(node *containers.RBNode[chunkMapping]) bool {
+ logicalOverlaps = append(logicalOverlaps, node.Value)
+ numOverlappingStripes += len(node.Value.PAddrs)
+ return true
+ })
var err error
newChunk, err = newChunk.union(logicalOverlaps...)
if err != nil {
@@ -162,7 +152,11 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro
SizeLocked: m.SizeLocked,
Flags: m.Flags,
}
- physicalOverlaps := lv.physical2logical[m.PAddr.Dev].SearchRange(newExt.compareRange)
+ var physicalOverlaps []devextMapping
+ lv.physical2logical[m.PAddr.Dev].Subrange(newExt.compareRange, func(node *containers.RBNode[devextMapping]) bool {
+ physicalOverlaps = append(physicalOverlaps, node.Value)
+ return true
+ })
newExt, err = newExt.union(physicalOverlaps...)
if err != nil {
return fmt.Errorf("(%p).AddMapping: %w", lv, err)
@@ -202,13 +196,13 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro
// logical2physical
for _, chunk := range logicalOverlaps {
- lv.logical2physical.Delete(containers.NativeOrdered[LogicalAddr]{Val: chunk.LAddr})
+ lv.logical2physical.Delete(lv.logical2physical.Search(chunk.Compare))
}
lv.logical2physical.Insert(newChunk)
// physical2logical
for _, ext := range physicalOverlaps {
- lv.physical2logical[m.PAddr.Dev].Delete(containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr})
+ lv.physical2logical[m.PAddr.Dev].Delete(lv.physical2logical[m.PAddr.Dev].Search(ext.Compare))
}
lv.physical2logical[m.PAddr.Dev].Insert(newExt)
@@ -227,20 +221,18 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro
}
func (lv *LogicalVolume[PhysicalVolume]) fsck() error {
- physical2logical := make(map[DeviceID]*containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping])
- if err := lv.logical2physical.Walk(func(node *containers.RBNode[chunkMapping]) error {
+ physical2logical := make(map[DeviceID]*containers.RBTree[devextMapping])
+ var err error
+ lv.logical2physical.Range(func(node *containers.RBNode[chunkMapping]) bool {
chunk := node.Value
for _, stripe := range chunk.PAddrs {
if _, devOK := lv.id2pv[stripe.Dev]; !devOK {
- return fmt.Errorf("(%p).fsck: chunk references physical volume %v which does not exist",
+ err = fmt.Errorf("(%p).fsck: chunk references physical volume %v which does not exist",
lv, stripe.Dev)
+ return false
}
if _, exists := physical2logical[stripe.Dev]; !exists {
- physical2logical[stripe.Dev] = &containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]{
- KeyFn: func(ext devextMapping) containers.NativeOrdered[PhysicalAddr] {
- return containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr}
- },
- }
+ physical2logical[stripe.Dev] = new(containers.RBTree[devextMapping])
}
physical2logical[stripe.Dev].Insert(devextMapping{
PAddr: stripe.Addr,
@@ -249,8 +241,9 @@ func (lv *LogicalVolume[PhysicalVolume]) fsck() error {
Flags: chunk.Flags,
})
}
- return nil
- }); err != nil {
+ return true
+ })
+ if err != nil {
return err
}
@@ -270,7 +263,7 @@ func (lv *LogicalVolume[PhysicalVolume]) fsck() error {
func (lv *LogicalVolume[PhysicalVolume]) Mappings() []Mapping {
var ret []Mapping
- _ = lv.logical2physical.Walk(func(node *containers.RBNode[chunkMapping]) error {
+ lv.logical2physical.Range(func(node *containers.RBNode[chunkMapping]) bool {
chunk := node.Value
for _, slice := range chunk.PAddrs {
ret = append(ret, Mapping{
@@ -280,7 +273,7 @@ func (lv *LogicalVolume[PhysicalVolume]) Mappings() []Mapping {
Flags: chunk.Flags,
})
}
- return nil
+ return true
})
return ret
}
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go
index 69d14c7..7c02d05 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go
@@ -53,11 +53,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice
// "AAAAAAA" shouldn't be present, and if we just discard "BBBBBBBB"
// because it conflicts with "CCCCCCC", then we would erroneously
// include "AAAAAAA".
- addrspace := &containers.RBTree[containers.NativeOrdered[btrfsvol.LogicalAddr], btrfsinspect.SysExtentCSum]{
- KeyFn: func(item btrfsinspect.SysExtentCSum) containers.NativeOrdered[btrfsvol.LogicalAddr] {
- return containers.NativeOrdered[btrfsvol.LogicalAddr]{Val: item.Sums.Addr}
- },
- }
+ addrspace := new(containers.RBTree[btrfsinspect.SysExtentCSum])
for _, newRecord := range records {
for {
conflict := addrspace.Search(func(oldRecord btrfsinspect.SysExtentCSum) int {
@@ -85,7 +81,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice
}
if oldRecord.Generation < newRecord.Generation {
// Newer generation wins.
- addrspace.Delete(containers.NativeOrdered[btrfsvol.LogicalAddr]{Val: oldRecord.Sums.Addr})
+ addrspace.Delete(conflict)
// loop around to check for more conflicts
continue
}
@@ -142,7 +138,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice
},
},
}
- addrspace.Delete(containers.NativeOrdered[btrfsvol.LogicalAddr]{Val: oldRecord.Sums.Addr})
+ addrspace.Delete(conflict)
newRecord = unionRecord
// loop around to check for more conflicts
}
@@ -152,7 +148,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice
var flattened SumRunWithGaps[btrfsvol.LogicalAddr]
var curAddr btrfsvol.LogicalAddr
var curSums strings.Builder
- _ = addrspace.Walk(func(node *containers.RBNode[btrfsinspect.SysExtentCSum]) error {
+ addrspace.Range(func(node *containers.RBNode[btrfsinspect.SysExtentCSum]) bool {
curEnd := curAddr + (btrfsvol.LogicalAddr(curSums.Len()/sumSize) * btrfssum.BlockSize)
if node.Value.Sums.Addr != curEnd {
if curSums.Len() > 0 {
@@ -166,7 +162,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice
curSums.Reset()
}
curSums.WriteString(string(node.Value.Sums.Sums))
- return nil
+ return true
})
if curSums.Len() > 0 {
flattened.Runs = append(flattened.Runs, btrfssum.SumRun[btrfsvol.LogicalAddr]{
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go b/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go
index ee5950d..bd29278 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go
@@ -770,6 +770,16 @@ func (o *rebuilder) _walkRange(
})
}
+type gap struct {
+ // range is [Beg,End)
+ Beg, End uint64
+}
+
+// Compare implements containers.Ordered.
+func (a gap) Compare(b gap) int {
+ return containers.NativeCompare(a.Beg, b.Beg)
+}
+
func (o *rebuilder) _wantRange(
ctx context.Context,
treeID btrfsprim.ObjID, objID btrfsprim.ObjID, typ btrfsprim.ItemType,
@@ -787,15 +797,7 @@ func (o *rebuilder) _wantRange(
//
// Start with a gap of the whole range, then subtract each run
// from it.
- type gap struct {
- // range is [Beg,End)
- Beg, End uint64
- }
- gaps := &containers.RBTree[containers.NativeOrdered[uint64], gap]{
- KeyFn: func(gap gap) containers.NativeOrdered[uint64] {
- return containers.NativeOrdered[uint64]{Val: gap.Beg}
- },
- }
+ gaps := new(containers.RBTree[gap])
gaps.Insert(gap{
Beg: beg,
End: end,
@@ -805,23 +807,29 @@ func (o *rebuilder) _wantRange(
o.rebuilt.Tree(ctx, treeID).Items(ctx),
treeID, objID, typ, beg, end,
func(runKey btrfsprim.Key, _ keyio.ItemPtr, runBeg, runEnd uint64) {
- overlappingGaps := gaps.SearchRange(func(gap gap) int {
- switch {
- case gap.End <= runBeg:
- return 1
- case runEnd <= gap.Beg:
- return -1
- default:
- return 0
- }
- })
+ var overlappingGaps []*containers.RBNode[gap]
+ gaps.Subrange(
+ func(gap gap) int {
+ switch {
+ case gap.End <= runBeg:
+ return 1
+ case runEnd <= gap.Beg:
+ return -1
+ default:
+ return 0
+ }
+ },
+ func(node *containers.RBNode[gap]) bool {
+ overlappingGaps = append(overlappingGaps, node)
+ return true
+ })
if len(overlappingGaps) == 0 {
return
}
- gapsBeg := overlappingGaps[0].Beg
- gapsEnd := overlappingGaps[len(overlappingGaps)-1].End
+ gapsBeg := overlappingGaps[0].Value.Beg
+ gapsEnd := overlappingGaps[len(overlappingGaps)-1].Value.End
for _, gap := range overlappingGaps {
- gaps.Delete(containers.NativeOrdered[uint64]{Val: gap.Beg})
+ gaps.Delete(gap)
}
if gapsBeg < runBeg {
gaps.Insert(gap{
@@ -842,7 +850,7 @@ func (o *rebuilder) _wantRange(
return
}
potentialItems := o.rebuilt.Tree(ctx, treeID).PotentialItems(ctx)
- _ = gaps.Walk(func(rbNode *containers.RBNode[gap]) error {
+ gaps.Range(func(rbNode *containers.RBNode[gap]) bool {
gap := rbNode.Value
last := gap.Beg
o._walkRange(
@@ -874,7 +882,7 @@ func (o *rebuilder) _wantRange(
o.wantAugment(wantCtx, treeID, wantKey, nil)
}
}
- return nil
+ return true
})
}
diff --git a/lib/btrfsprogs/btrfsinspect/scandevices.go b/lib/btrfsprogs/btrfsinspect/scandevices.go
index 7668a83..9b8360c 100644
--- a/lib/btrfsprogs/btrfsinspect/scandevices.go
+++ b/lib/btrfsprogs/btrfsinspect/scandevices.go
@@ -22,6 +22,7 @@ import (
"git.lukeshu.com/btrfs-progs-ng/lib/btrfs/btrfssum"
"git.lukeshu.com/btrfs-progs-ng/lib/btrfs/btrfstree"
"git.lukeshu.com/btrfs-progs-ng/lib/btrfs/btrfsvol"
+ "git.lukeshu.com/btrfs-progs-ng/lib/containers"
"git.lukeshu.com/btrfs-progs-ng/lib/textui"
)
@@ -79,6 +80,11 @@ type SysExtentCSum struct {
Sums btrfsitem.ExtentCSum
}
+// Compare implements containers.Ordered.
+func (a SysExtentCSum) Compare(b SysExtentCSum) int {
+ return containers.NativeCompare(a.Sums.Addr, b.Sums.Addr)
+}
+
type scanStats struct {
textui.Portion[btrfsvol.PhysicalAddr]
diff --git a/lib/btrfsprogs/btrfsutil/broken_btree.go b/lib/btrfsprogs/btrfsutil/broken_btree.go
index 8261119..7ea31ce 100644
--- a/lib/btrfsprogs/btrfsutil/broken_btree.go
+++ b/lib/btrfsprogs/btrfsutil/broken_btree.go
@@ -23,7 +23,7 @@ import (
type treeIndex struct {
TreeRootErr error
- Items *containers.RBTree[btrfsprim.Key, treeIndexValue]
+ Items *containers.RBTree[treeIndexValue]
Errors *containers.IntervalTree[btrfsprim.Key, treeIndexError]
}
@@ -38,13 +38,14 @@ type treeIndexValue struct {
ItemSize uint32
}
+// Compare implements containers.Ordered.
+func (a treeIndexValue) Compare(b treeIndexValue) int {
+ return a.Key.Compare(b.Key)
+}
+
func newTreeIndex(arena *SkinnyPathArena) treeIndex {
return treeIndex{
- Items: &containers.RBTree[btrfsprim.Key, treeIndexValue]{
- KeyFn: func(iv treeIndexValue) btrfsprim.Key {
- return iv.Key
- },
- },
+ Items: new(containers.RBTree[treeIndexValue]),
Errors: &containers.IntervalTree[btrfsprim.Key, treeIndexError]{
MinFn: func(err treeIndexError) btrfsprim.Key {
return arena.Inflate(err.Path).Node(-1).ToKey
@@ -173,7 +174,7 @@ func (bt *brokenTrees) rawTreeWalk(root btrfstree.TreeRoot, cacheEntry treeIndex
},
btrfstree.TreeWalkHandler{
Item: func(path btrfstree.TreePath, item btrfstree.Item) error {
- if cacheEntry.Items.Lookup(item.Key) != nil {
+ if cacheEntry.Items.Search(func(v treeIndexValue) int { return item.Key.Compare(v.Key) }) != nil {
// This is a panic because I'm not really sure what the best way to
// handle this is, and so if this happens I want the program to crash
// and force me to figure out how to handle it.
@@ -203,15 +204,15 @@ func (bt *brokenTrees) TreeLookup(treeID btrfsprim.ObjID, key btrfsprim.Key) (bt
func (bt *brokenTrees) addErrs(index treeIndex, fn func(btrfsprim.Key, uint32) int, err error) error {
var errs derror.MultiError
- if _errs := index.Errors.SearchAll(func(k btrfsprim.Key) int { return fn(k, 0) }); len(_errs) > 0 {
- errs = make(derror.MultiError, len(_errs))
- for i := range _errs {
- errs[i] = &btrfstree.TreeError{
- Path: bt.arena.Inflate(_errs[i].Path),
- Err: _errs[i].Err,
- }
- }
- }
+ index.Errors.Subrange(
+ func(k btrfsprim.Key) int { return fn(k, 0) },
+ func(v treeIndexError) bool {
+ errs = append(errs, &btrfstree.TreeError{
+ Path: bt.arena.Inflate(v.Path),
+ Err: v.Err,
+ })
+ return true
+ })
if len(errs) == 0 {
return err
}
@@ -253,9 +254,13 @@ func (bt *brokenTrees) TreeSearchAll(treeID btrfsprim.ObjID, fn func(btrfsprim.K
return nil, index.TreeRootErr
}
- indexItems := index.Items.SearchRange(func(indexItem treeIndexValue) int {
- return fn(indexItem.Key, indexItem.ItemSize)
- })
+ var indexItems []treeIndexValue
+ index.Items.Subrange(
+ func(indexItem treeIndexValue) int { return fn(indexItem.Key, indexItem.ItemSize) },
+ func(node *containers.RBNode[treeIndexValue]) bool {
+ indexItems = append(indexItems, node.Value)
+ return true
+ })
if len(indexItems) == 0 {
return nil, bt.addErrs(index, fn, iofs.ErrNotExist)
}
@@ -290,12 +295,12 @@ func (bt *brokenTrees) TreeWalk(ctx context.Context, treeID btrfsprim.ObjID, err
return
}
var node *diskio.Ref[btrfsvol.LogicalAddr, btrfstree.Node]
- _ = index.Items.Walk(func(indexItem *containers.RBNode[treeIndexValue]) error {
+ index.Items.Range(func(indexItem *containers.RBNode[treeIndexValue]) bool {
if ctx.Err() != nil {
- return ctx.Err()
+ return false
}
if bt.ctx.Err() != nil {
- return bt.ctx.Err()
+ return false
}
if cbs.Item != nil {
itemPath := bt.arena.Inflate(indexItem.Value.Path)
@@ -304,7 +309,7 @@ func (bt *brokenTrees) TreeWalk(ctx context.Context, treeID btrfsprim.ObjID, err
node, err = bt.inner.ReadNode(itemPath.Parent())
if err != nil {
errHandle(&btrfstree.TreeError{Path: itemPath, Err: err})
- return nil //nolint:nilerr // We already called errHandle().
+ return true
}
}
item := node.Data.BodyLeaf[itemPath.Node(-1).FromItemIdx]
@@ -312,7 +317,7 @@ func (bt *brokenTrees) TreeWalk(ctx context.Context, treeID btrfsprim.ObjID, err
errHandle(&btrfstree.TreeError{Path: itemPath, Err: err})
}
}
- return nil
+ return true
})
}
diff --git a/lib/containers/intervaltree.go b/lib/containers/intervaltree.go
index b7ff866..7b96526 100644
--- a/lib/containers/intervaltree.go
+++ b/lib/containers/intervaltree.go
@@ -4,81 +4,80 @@
package containers
-type intervalKey[K Ordered[K]] struct {
+type interval[K Ordered[K]] struct {
Min, Max K
}
-func (ival intervalKey[K]) ContainsFn(fn func(K) int) bool {
- return fn(ival.Min) >= 0 && fn(ival.Max) <= 0
-}
-
-func (a intervalKey[K]) Compare(b intervalKey[K]) int {
+// Compare implements Ordered.
+func (a interval[K]) Compare(b interval[K]) int {
if d := a.Min.Compare(b.Min); d != 0 {
return d
}
return a.Max.Compare(b.Max)
}
+// ContainsFn returns whether this interval contains the range matched
+// by the given function.
+func (ival interval[K]) ContainsFn(fn func(K) int) bool {
+ return fn(ival.Min) >= 0 && fn(ival.Max) <= 0
+}
+
type intervalValue[K Ordered[K], V any] struct {
- Val V
- SpanOfChildren intervalKey[K]
+ Val V
+ ValSpan interval[K]
+ ChildSpan interval[K]
+}
+
+// Compare implements Ordered.
+func (a intervalValue[K, V]) Compare(b intervalValue[K, V]) int {
+ return a.ValSpan.Compare(b.ValSpan)
}
type IntervalTree[K Ordered[K], V any] struct {
MinFn func(V) K
MaxFn func(V) K
- inner RBTree[intervalKey[K], intervalValue[K, V]]
-}
-
-func (t *IntervalTree[K, V]) keyFn(v intervalValue[K, V]) intervalKey[K] {
- return intervalKey[K]{
- Min: t.MinFn(v.Val),
- Max: t.MaxFn(v.Val),
- }
+ inner RBTree[intervalValue[K, V]]
}
func (t *IntervalTree[K, V]) attrFn(node *RBNode[intervalValue[K, V]]) {
- max := t.MaxFn(node.Value.Val)
- if node.Left != nil && node.Left.Value.SpanOfChildren.Max.Compare(max) > 0 {
- max = node.Left.Value.SpanOfChildren.Max
+ max := node.Value.ValSpan.Max
+ if node.Left != nil && node.Left.Value.ChildSpan.Max.Compare(max) > 0 {
+ max = node.Left.Value.ChildSpan.Max
}
- if node.Right != nil && node.Right.Value.SpanOfChildren.Max.Compare(max) > 0 {
- max = node.Right.Value.SpanOfChildren.Max
+ if node.Right != nil && node.Right.Value.ChildSpan.Max.Compare(max) > 0 {
+ max = node.Right.Value.ChildSpan.Max
}
- node.Value.SpanOfChildren.Max = max
+ node.Value.ChildSpan.Max = max
- min := t.MinFn(node.Value.Val)
- if node.Left != nil && node.Left.Value.SpanOfChildren.Min.Compare(min) < 0 {
- min = node.Left.Value.SpanOfChildren.Min
+ min := node.Value.ValSpan.Min
+ if node.Left != nil && node.Left.Value.ChildSpan.Min.Compare(min) < 0 {
+ min = node.Left.Value.ChildSpan.Min
}
- if node.Right != nil && node.Right.Value.SpanOfChildren.Min.Compare(min) < 0 {
- min = node.Right.Value.SpanOfChildren.Min
+ if node.Right != nil && node.Right.Value.ChildSpan.Min.Compare(min) < 0 {
+ min = node.Right.Value.ChildSpan.Min
}
- node.Value.SpanOfChildren.Min = min
+ node.Value.ChildSpan.Min = min
}
func (t *IntervalTree[K, V]) init() {
- if t.inner.KeyFn == nil {
- t.inner.KeyFn = t.keyFn
+ if t.inner.AttrFn == nil {
t.inner.AttrFn = t.attrFn
}
}
-func (t *IntervalTree[K, V]) Delete(min, max K) {
- t.init()
- t.inner.Delete(intervalKey[K]{
- Min: min,
- Max: max,
- })
-}
-
func (t *IntervalTree[K, V]) Equal(u *IntervalTree[K, V]) bool {
return t.inner.Equal(&u.inner)
}
func (t *IntervalTree[K, V]) Insert(val V) {
t.init()
- t.inner.Insert(intervalValue[K, V]{Val: val})
+ t.inner.Insert(intervalValue[K, V]{
+ Val: val,
+ ValSpan: interval[K]{
+ Min: t.MinFn(val),
+ Max: t.MaxFn(val),
+ },
+ })
}
func (t *IntervalTree[K, V]) Min() (K, bool) {
@@ -86,7 +85,7 @@ func (t *IntervalTree[K, V]) Min() (K, bool) {
var zero K
return zero, false
}
- return t.inner.root.Value.SpanOfChildren.Min, true
+ return t.inner.root.Value.ChildSpan.Min, true
}
func (t *IntervalTree[K, V]) Max() (K, bool) {
@@ -94,22 +93,18 @@ func (t *IntervalTree[K, V]) Max() (K, bool) {
var zero K
return zero, false
}
- return t.inner.root.Value.SpanOfChildren.Max, true
-}
-
-func (t *IntervalTree[K, V]) Lookup(k K) (V, bool) {
- return t.Search(k.Compare)
+ return t.inner.root.Value.ChildSpan.Max, true
}
func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) {
node := t.inner.root
for node != nil {
switch {
- case t.keyFn(node.Value).ContainsFn(fn):
+ case node.Value.ValSpan.ContainsFn(fn):
return node.Value.Val, true
- case node.Left != nil && node.Left.Value.SpanOfChildren.ContainsFn(fn):
+ case node.Left != nil && node.Left.Value.ChildSpan.ContainsFn(fn):
node = node.Left
- case node.Right != nil && node.Right.Value.SpanOfChildren.ContainsFn(fn):
+ case node.Right != nil && node.Right.Value.ChildSpan.ContainsFn(fn):
node = node.Right
default:
node = nil
@@ -119,24 +114,33 @@ func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) {
return zero, false
}
-func (t *IntervalTree[K, V]) searchAll(fn func(K) int, node *RBNode[intervalValue[K, V]], ret *[]V) {
+func (t *IntervalTree[K, V]) Range(fn func(V) bool) {
+ t.inner.Range(func(node *RBNode[intervalValue[K, V]]) bool {
+ return fn(node.Value.Val)
+ })
+}
+
+func (t *IntervalTree[K, V]) Subrange(rangeFn func(K) int, handleFn func(V) bool) {
+ t.subrange(t.inner.root, rangeFn, handleFn)
+}
+
+func (t *IntervalTree[K, V]) subrange(node *RBNode[intervalValue[K, V]], rangeFn func(K) int, handleFn func(V) bool) bool {
if node == nil {
- return
+ return true
}
- if !node.Value.SpanOfChildren.ContainsFn(fn) {
- return
+ if !node.Value.ChildSpan.ContainsFn(rangeFn) {
+ return true
}
- t.searchAll(fn, node.Left, ret)
- if t.keyFn(node.Value).ContainsFn(fn) {
- *ret = append(*ret, node.Value.Val)
+ if !t.subrange(node.Left, rangeFn, handleFn) {
+ return false
}
- t.searchAll(fn, node.Right, ret)
-}
-
-func (t *IntervalTree[K, V]) SearchAll(fn func(K) int) []V {
- var ret []V
- t.searchAll(fn, t.inner.root, &ret)
- return ret
+ if node.Value.ValSpan.ContainsFn(rangeFn) {
+ if !handleFn(node.Value.Val) {
+ return false
+ }
+ }
+ if !t.subrange(node.Right, rangeFn, handleFn) {
+ return false
+ }
+ return true
}
-
-// TODO: func (t *IntervalTree[K, V]) Walk(fn func(*RBNode[V]) error) error
diff --git a/lib/containers/intervaltree_test.go b/lib/containers/intervaltree_test.go
index 45743a2..30885bd 100644
--- a/lib/containers/intervaltree_test.go
+++ b/lib/containers/intervaltree_test.go
@@ -18,8 +18,8 @@ func (t *IntervalTree[K, V]) ASCIIArt() string {
func (v intervalValue[K, V]) String() string {
return fmt.Sprintf("%v) ([%v,%v]",
v.Val,
- v.SpanOfChildren.Min,
- v.SpanOfChildren.Max)
+ v.ChildSpan.Min,
+ v.ChildSpan.Max)
}
func (v NativeOrdered[T]) String() string {
@@ -60,15 +60,21 @@ func TestIntervalTree(t *testing.T) {
t.Log("\n" + tree.ASCIIArt())
// find intervals that touch [9,20]
- intervals := tree.SearchAll(func(k NativeOrdered[int]) int {
- if k.Val < 9 {
- return 1
- }
- if k.Val > 20 {
- return -1
- }
- return 0
- })
+ var intervals []SimpleInterval
+ tree.Subrange(
+ func(k NativeOrdered[int]) int {
+ if k.Val < 9 {
+ return 1
+ }
+ if k.Val > 20 {
+ return -1
+ }
+ return 0
+ },
+ func(v SimpleInterval) bool {
+ intervals = append(intervals, v)
+ return true
+ })
assert.Equal(t,
[]SimpleInterval{
{6, 10},
diff --git a/lib/containers/rbtree.go b/lib/containers/rbtree.go
index 1fdb799..4125847 100644
--- a/lib/containers/rbtree.go
+++ b/lib/containers/rbtree.go
@@ -7,8 +7,6 @@ package containers
import (
"fmt"
"reflect"
-
- "git.lukeshu.com/btrfs-progs-ng/lib/slices"
)
type Color bool
@@ -18,50 +16,49 @@ const (
Red = Color(true)
)
-type RBNode[V any] struct {
- Parent, Left, Right *RBNode[V]
+type RBNode[T Ordered[T]] struct {
+ Parent, Left, Right *RBNode[T]
Color Color
- Value V
+ Value T
}
-func (node *RBNode[V]) getColor() Color {
+func (node *RBNode[T]) getColor() Color {
if node == nil {
return Black
}
return node.Color
}
-type RBTree[K Ordered[K], V any] struct {
- KeyFn func(V) K
- AttrFn func(*RBNode[V])
- root *RBNode[V]
+type RBTree[T Ordered[T]] struct {
+ AttrFn func(*RBNode[T])
+ root *RBNode[T]
len int
}
-func (t *RBTree[K, V]) Len() int {
+func (t *RBTree[T]) Len() int {
return t.len
}
-func (t *RBTree[K, V]) Walk(fn func(*RBNode[V]) error) error {
- return t.root.walk(fn)
+func (t *RBTree[T]) Range(fn func(*RBNode[T]) bool) {
+ t.root._range(fn)
}
-func (node *RBNode[V]) walk(fn func(*RBNode[V]) error) error {
+func (node *RBNode[T]) _range(fn func(*RBNode[T]) bool) bool {
if node == nil {
- return nil
+ return true
}
- if err := node.Left.walk(fn); err != nil {
- return err
+ if !node.Left._range(fn) {
+ return false
}
- if err := fn(node); err != nil {
- return err
+ if !fn(node) {
+ return false
}
- if err := node.Right.walk(fn); err != nil {
- return err
+ if !node.Right._range(fn) {
+ return false
}
- return nil
+ return true
}
// Search the tree for a value that satisfied the given callbackk
@@ -81,16 +78,13 @@ func (node *RBNode[V]) walk(fn func(*RBNode[V]) error) error {
// +---+ +---+
//
// Returns nil if no such value is found.
-//
-// Search is good for advanced lookup, like when a range of values is
-// acceptable. For simple exact-value lookup, use Lookup.
-func (t *RBTree[K, V]) Search(fn func(V) int) *RBNode[V] {
+func (t *RBTree[T]) Search(fn func(T) int) *RBNode[T] {
ret, _ := t.root.search(fn)
return ret
}
-func (node *RBNode[V]) search(fn func(V) int) (exact, nearest *RBNode[V]) {
- var prev *RBNode[V]
+func (node *RBNode[T]) search(fn func(T) int) (exact, nearest *RBNode[T]) {
+ var prev *RBNode[T]
for {
if node == nil {
return nil, prev
@@ -108,26 +102,13 @@ func (node *RBNode[V]) search(fn func(V) int) (exact, nearest *RBNode[V]) {
}
}
-func (t *RBTree[K, V]) exactKey(key K) func(V) int {
- return func(val V) int {
- valKey := t.KeyFn(val)
- return key.Compare(valKey)
- }
-}
-
-// Lookup looks up the value for an exact key. If no such value
-// exists, nil is returned.
-func (t *RBTree[K, V]) Lookup(key K) *RBNode[V] {
- return t.Search(t.exactKey(key))
-}
-
// Min returns the minimum value stored in the tree, or nil if the
// tree is empty.
-func (t *RBTree[K, V]) Min() *RBNode[V] {
+func (t *RBTree[T]) Min() *RBNode[T] {
return t.root.min()
}
-func (node *RBNode[V]) min() *RBNode[V] {
+func (node *RBNode[T]) min() *RBNode[T] {
if node == nil {
return nil
}
@@ -141,11 +122,11 @@ func (node *RBNode[V]) min() *RBNode[V] {
// Max returns the maximum value stored in the tree, or nil if the
// tree is empty.
-func (t *RBTree[K, V]) Max() *RBNode[V] {
+func (t *RBTree[T]) Max() *RBNode[T] {
return t.root.max()
}
-func (node *RBNode[V]) max() *RBNode[V] {
+func (node *RBNode[T]) max() *RBNode[T] {
if node == nil {
return nil
}
@@ -157,11 +138,7 @@ func (node *RBNode[V]) max() *RBNode[V] {
}
}
-func (t *RBTree[K, V]) Next(cur *RBNode[V]) *RBNode[V] {
- return cur.next()
-}
-
-func (cur *RBNode[V]) next() *RBNode[V] {
+func (cur *RBNode[T]) Next() *RBNode[T] {
if cur.Right != nil {
return cur.Right.min()
}
@@ -172,11 +149,7 @@ func (cur *RBNode[V]) next() *RBNode[V] {
return parent
}
-func (t *RBTree[K, V]) Prev(cur *RBNode[V]) *RBNode[V] {
- return cur.prev()
-}
-
-func (cur *RBNode[V]) prev() *RBNode[V] {
+func (cur *RBNode[T]) Prev() *RBNode[T] {
if cur.Left != nil {
return cur.Left.max()
}
@@ -187,48 +160,56 @@ func (cur *RBNode[V]) prev() *RBNode[V] {
return parent
}
-// SearchRange is like Search, but returns all nodes that match the
-// function; assuming that they are contiguous.
-func (t *RBTree[K, V]) SearchRange(fn func(V) int) []V {
- middle := t.Search(fn)
- if middle == nil {
- return nil
- }
- ret := []V{middle.Value}
- for node := t.Prev(middle); node != nil && fn(node.Value) == 0; node = t.Prev(node) {
- ret = append(ret, node.Value)
+// Subrange is like Search, but for when there may be more than one
+// result.
+func (t *RBTree[T]) Subrange(rangeFn func(T) int, handleFn func(*RBNode[T]) bool) {
+ // Find the left-most acceptable node.
+ _, node := t.root.search(func(v T) int {
+ if rangeFn(v) <= 0 {
+ return -1
+ } else {
+ return 1
+ }
+ })
+ for node != nil && rangeFn(node.Value) > 0 {
+ node = node.Next()
}
- slices.Reverse(ret)
- for node := t.Next(middle); node != nil && fn(node.Value) == 0; node = t.Next(node) {
- ret = append(ret, node.Value)
+ // Now walk forward until we hit the end.
+ for node != nil && rangeFn(node.Value) == 0 {
+ if keepGoing := handleFn(node); !keepGoing {
+ return
+ }
+ node = node.Next()
}
- return ret
}
-func (t *RBTree[K, V]) Equal(u *RBTree[K, V]) bool {
+func (t *RBTree[T]) Equal(u *RBTree[T]) bool {
if (t == nil) != (u == nil) {
return false
}
if t == nil {
return true
}
+ if t.len != u.len {
+ return false
+ }
- var tSlice []V
- _ = t.Walk(func(node *RBNode[V]) error {
+ tSlice := make([]T, 0, t.len)
+ t.Range(func(node *RBNode[T]) bool {
tSlice = append(tSlice, node.Value)
- return nil
+ return true
})
- var uSlice []V
- _ = u.Walk(func(node *RBNode[V]) error {
+ uSlice := make([]T, 0, u.len)
+ u.Range(func(node *RBNode[T]) bool {
uSlice = append(uSlice, node.Value)
- return nil
+ return true
})
return reflect.DeepEqual(tSlice, uSlice)
}
-func (t *RBTree[K, V]) parentChild(node *RBNode[V]) **RBNode[V] {
+func (t *RBTree[T]) parentChild(node *RBNode[T]) **RBNode[T] {
switch {
case node.Parent == nil:
return &t.root
@@ -241,7 +222,7 @@ func (t *RBTree[K, V]) parentChild(node *RBNode[V]) **RBNode[V] {
}
}
-func (t *RBTree[K, V]) updateAttr(node *RBNode[V]) {
+func (t *RBTree[T]) updateAttr(node *RBNode[T]) {
if t.AttrFn == nil {
return
}
@@ -251,7 +232,7 @@ func (t *RBTree[K, V]) updateAttr(node *RBNode[V]) {
}
}
-func (t *RBTree[K, V]) leftRotate(x *RBNode[V]) {
+func (t *RBTree[T]) leftRotate(x *RBNode[T]) {
// p p
// | |
// +---+ +---+
@@ -286,7 +267,7 @@ func (t *RBTree[K, V]) leftRotate(x *RBNode[V]) {
t.updateAttr(x)
}
-func (t *RBTree[K, V]) rightRotate(y *RBNode[V]) {
+func (t *RBTree[T]) rightRotate(y *RBNode[T]) {
//nolint:dupword
//
// | |
@@ -322,18 +303,17 @@ func (t *RBTree[K, V]) rightRotate(y *RBNode[V]) {
t.updateAttr(y)
}
-func (t *RBTree[K, V]) Insert(val V) {
+func (t *RBTree[T]) Insert(val T) {
// Naive-insert
- key := t.KeyFn(val)
- exact, parent := t.root.search(t.exactKey(key))
+ exact, parent := t.root.search(val.Compare)
if exact != nil {
exact.Value = val
return
}
t.len++
- node := &RBNode[V]{
+ node := &RBNode[T]{
Color: Red,
Parent: parent,
Value: val,
@@ -341,7 +321,7 @@ func (t *RBTree[K, V]) Insert(val V) {
switch {
case parent == nil:
t.root = node
- case key.Compare(t.KeyFn(parent.Value)) < 0:
+ case val.Compare(parent.Value) < 0:
parent.Left = node
default:
parent.Right = node
@@ -391,15 +371,14 @@ func (t *RBTree[K, V]) Insert(val V) {
t.root.Color = Black
}
-func (t *RBTree[K, V]) transplant(oldNode, newNode *RBNode[V]) {
+func (t *RBTree[T]) transplant(oldNode, newNode *RBNode[T]) {
*t.parentChild(oldNode) = newNode
if newNode != nil {
newNode.Parent = oldNode.Parent
}
}
-func (t *RBTree[K, V]) Delete(key K) {
- nodeToDelete := t.Lookup(key)
+func (t *RBTree[T]) Delete(nodeToDelete *RBNode[T]) {
if nodeToDelete == nil {
return
}
@@ -410,8 +389,8 @@ func (t *RBTree[K, V]) Delete(key K) {
// phase 1
- var nodeToRebalance *RBNode[V]
- var nodeToRebalanceParent *RBNode[V] // in case 'nodeToRebalance' is nil, which it can be
+ var nodeToRebalance *RBNode[T]
+ var nodeToRebalanceParent *RBNode[T] // in case 'nodeToRebalance' is nil, which it can be
needsRebalance := nodeToDelete.Color == Black
switch {
@@ -427,7 +406,7 @@ func (t *RBTree[K, V]) Delete(key K) {
// The node being deleted has a child on both sides,
// so we've go to reshuffle the parents a bit to make
// room for those children.
- next := nodeToDelete.next()
+ next := nodeToDelete.Next()
if next.Parent == nodeToDelete {
// p p
// | |
diff --git a/lib/containers/rbtree_test.go b/lib/containers/rbtree_test.go
index e42410e..d2fe931 100644
--- a/lib/containers/rbtree_test.go
+++ b/lib/containers/rbtree_test.go
@@ -7,21 +7,22 @@ package containers
import (
"fmt"
"io"
- "sort"
"strings"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/exp/constraints"
+
+ "git.lukeshu.com/btrfs-progs-ng/lib/slices"
)
-func (t *RBTree[K, V]) ASCIIArt() string {
+func (t *RBTree[T]) ASCIIArt() string {
var out strings.Builder
t.root.asciiArt(&out, "", "", "")
return out.String()
}
-func (node *RBNode[V]) String() string {
+func (node *RBNode[T]) String() string {
switch {
case node == nil:
return "nil"
@@ -32,7 +33,7 @@ func (node *RBNode[V]) String() string {
}
}
-func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) {
+func (node *RBNode[T]) asciiArt(w io.Writer, u, m, l string) {
if node == nil {
fmt.Fprintf(w, "%snil\n", m)
return
@@ -43,7 +44,7 @@ func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) {
node.Left.asciiArt(w, l+" | ", l+" `--", l+" ")
}
-func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], tree *RBTree[NativeOrdered[K], V]) {
+func checkRBTree[T constraints.Ordered](t *testing.T, expectedSet Set[T], tree *RBTree[NativeOrdered[T]]) {
// 1. Every node is either red or black
// 2. The root is black.
@@ -52,19 +53,19 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K],
// 3. Every nil is black.
// 4. If a node is red, then both its children are black.
- require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
+ tree.Range(func(node *RBNode[NativeOrdered[T]]) bool {
if node.getColor() == Red {
require.Equal(t, Black, node.Left.getColor())
require.Equal(t, Black, node.Right.getColor())
}
- return nil
- }))
+ return true
+ })
// 5. For each node, all simple paths from the node to
// descendent leaves contain the same number of black
// nodes.
- var walkCnt func(node *RBNode[V], cnt int, leafFn func(int))
- walkCnt = func(node *RBNode[V], cnt int, leafFn func(int)) {
+ var walkCnt func(node *RBNode[NativeOrdered[T]], cnt int, leafFn func(int))
+ walkCnt = func(node *RBNode[NativeOrdered[T]], cnt int, leafFn func(int)) {
if node.getColor() == Black {
cnt++
}
@@ -75,7 +76,7 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K],
walkCnt(node.Left, cnt, leafFn)
walkCnt(node.Right, cnt, leafFn)
}
- require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
+ tree.Range(func(node *RBNode[NativeOrdered[T]]) bool {
var cnts []int
walkCnt(node, 0, func(cnt int) {
cnts = append(cnts, cnt)
@@ -86,27 +87,24 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K],
break
}
}
- return nil
- }))
+ return true
+ })
// expected contents
- expectedOrder := make([]K, 0, len(expectedSet))
- for k := range expectedSet {
- expectedOrder = append(expectedOrder, k)
- node := tree.Lookup(NativeOrdered[K]{Val: k})
+ expectedOrder := make([]T, 0, len(expectedSet))
+ for v := range expectedSet {
+ expectedOrder = append(expectedOrder, v)
+ node := tree.Search(NativeOrdered[T]{Val: v}.Compare)
require.NotNil(t, tree, node)
- require.Equal(t, k, tree.KeyFn(node.Value).Val)
}
- sort.Slice(expectedOrder, func(i, j int) bool {
- return expectedOrder[i] < expectedOrder[j]
+ slices.Sort(expectedOrder)
+ actOrder := make([]T, 0, len(expectedSet))
+ tree.Range(func(node *RBNode[NativeOrdered[T]]) bool {
+ actOrder = append(actOrder, node.Value.Val)
+ return true
})
- actOrder := make([]K, 0, len(expectedSet))
- require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
- actOrder = append(actOrder, tree.KeyFn(node.Value).Val)
- return nil
- }))
require.Equal(t, expectedOrder, actOrder)
- require.Equal(t, tree.Len(), len(expectedSet))
+ require.Equal(t, len(expectedSet), tree.Len())
}
func FuzzRBTree(f *testing.F) {
@@ -132,11 +130,7 @@ func FuzzRBTree(f *testing.F) {
})
f.Fuzz(func(t *testing.T, dat []uint8) {
- tree := &RBTree[NativeOrdered[uint8], uint8]{
- KeyFn: func(x uint8) NativeOrdered[uint8] {
- return NativeOrdered[uint8]{Val: x}
- },
- }
+ tree := new(RBTree[NativeOrdered[uint8]])
set := make(Set[uint8])
checkRBTree(t, set, tree)
t.Logf("\n%s\n", tree.ASCIIArt())
@@ -145,18 +139,18 @@ func FuzzRBTree(f *testing.F) {
val := (b & 0b0011_1111)
if ins {
t.Logf("Insert(%v)", val)
- tree.Insert(val)
+ tree.Insert(NativeOrdered[uint8]{Val: val})
set.Insert(val)
t.Logf("\n%s\n", tree.ASCIIArt())
- node := tree.Lookup(NativeOrdered[uint8]{Val: val})
+ node := tree.Search(NativeOrdered[uint8]{Val: val}.Compare)
require.NotNil(t, node)
- require.Equal(t, val, node.Value)
+ require.Equal(t, val, node.Value.Val)
} else {
t.Logf("Delete(%v)", val)
- tree.Delete(NativeOrdered[uint8]{Val: val})
+ tree.Delete(tree.Search(NativeOrdered[uint8]{Val: val}.Compare))
delete(set, val)
t.Logf("\n%s\n", tree.ASCIIArt())
- require.Nil(t, tree.Lookup(NativeOrdered[uint8]{Val: val}))
+ require.Nil(t, tree.Search(NativeOrdered[uint8]{Val: val}.Compare))
}
checkRBTree(t, set, tree)
}
diff --git a/lib/containers/sortedmap.go b/lib/containers/sortedmap.go
index 52308c9..d104274 100644
--- a/lib/containers/sortedmap.go
+++ b/lib/containers/sortedmap.go
@@ -1,40 +1,32 @@
-// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later
package containers
-import (
- "errors"
-)
-
-type OrderedKV[K Ordered[K], V any] struct {
+type orderedKV[K Ordered[K], V any] struct {
K K
V V
}
-type SortedMap[K Ordered[K], V any] struct {
- inner RBTree[K, OrderedKV[K, V]]
-}
-
-func (m *SortedMap[K, V]) init() {
- if m.inner.KeyFn == nil {
- m.inner.KeyFn = m.keyFn
- }
+func (a orderedKV[K, V]) Compare(b orderedKV[K, V]) int {
+ return a.K.Compare(b.K)
}
-func (m *SortedMap[K, V]) keyFn(kv OrderedKV[K, V]) K {
- return kv.K
+type SortedMap[K Ordered[K], V any] struct {
+ inner RBTree[orderedKV[K, V]]
}
func (m *SortedMap[K, V]) Delete(key K) {
- m.init()
- m.inner.Delete(key)
+ m.inner.Delete(m.inner.Search(func(kv orderedKV[K, V]) int {
+ return key.Compare(kv.K)
+ }))
}
func (m *SortedMap[K, V]) Load(key K) (value V, ok bool) {
- m.init()
- node := m.inner.Lookup(key)
+ node := m.inner.Search(func(kv orderedKV[K, V]) int {
+ return key.Compare(kv.K)
+ })
if node == nil {
var zero V
return zero, false
@@ -42,41 +34,27 @@ func (m *SortedMap[K, V]) Load(key K) (value V, ok bool) {
return node.Value.V, true
}
-var errStop = errors.New("stop")
-
-func (m *SortedMap[K, V]) Range(f func(key K, value V) bool) {
- m.init()
- _ = m.inner.Walk(func(node *RBNode[OrderedKV[K, V]]) error {
- if f(node.Value.K, node.Value.V) {
- return nil
- } else {
- return errStop
- }
+func (m *SortedMap[K, V]) Store(key K, value V) {
+ m.inner.Insert(orderedKV[K, V]{
+ K: key,
+ V: value,
})
}
-func (m *SortedMap[K, V]) Subrange(rangeFn func(K, V) int, handleFn func(K, V) bool) {
- m.init()
- kvs := m.inner.SearchRange(func(kv OrderedKV[K, V]) int {
- return rangeFn(kv.K, kv.V)
+func (m *SortedMap[K, V]) Range(fn func(key K, value V) bool) {
+ m.inner.Range(func(node *RBNode[orderedKV[K, V]]) bool {
+ return fn(node.Value.K, node.Value.V)
})
- for _, kv := range kvs {
- if !handleFn(kv.K, kv.V) {
- break
- }
- }
}
-func (m *SortedMap[K, V]) Store(key K, value V) {
- m.init()
- m.inner.Insert(OrderedKV[K, V]{
- K: key,
- V: value,
- })
+func (m *SortedMap[K, V]) Subrange(rangeFn func(K, V) int, handleFn func(K, V) bool) {
+ m.inner.Subrange(
+ func(kv orderedKV[K, V]) int { return rangeFn(kv.K, kv.V) },
+ func(node *RBNode[orderedKV[K, V]]) bool { return handleFn(node.Value.K, node.Value.V) })
}
func (m *SortedMap[K, V]) Search(fn func(K, V) int) (K, V, bool) {
- node := m.inner.Search(func(kv OrderedKV[K, V]) int {
+ node := m.inner.Search(func(kv orderedKV[K, V]) int {
return fn(kv.K, kv.V)
})
if node == nil {
@@ -87,12 +65,6 @@ func (m *SortedMap[K, V]) Search(fn func(K, V) int) (K, V, bool) {
return node.Value.K, node.Value.V, true
}
-func (m *SortedMap[K, V]) SearchAll(fn func(K, V) int) []OrderedKV[K, V] {
- return m.inner.SearchRange(func(kv OrderedKV[K, V]) int {
- return fn(kv.K, kv.V)
- })
-}
-
func (m *SortedMap[K, V]) Len() int {
return m.inner.Len()
}