From 696a7d192e5eefa53230168a4b200ec0560c8a10 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Sun, 5 Feb 2023 00:31:29 -0700 Subject: containers: Rethink the RBTree interface to be simpler --- lib/btrfs/btrfsvol/chunk.go | 5 +- lib/btrfs/btrfsvol/devext.go | 5 + lib/btrfs/btrfsvol/lvm.go | 67 ++++----- .../btrfsinspect/rebuildmappings/logicalsums.go | 14 +- .../btrfsinspect/rebuildnodes/rebuild.go | 56 ++++--- lib/btrfsprogs/btrfsinspect/scandevices.go | 6 + lib/btrfsprogs/btrfsutil/broken_btree.go | 53 ++++--- lib/containers/intervaltree.go | 132 +++++++++-------- lib/containers/intervaltree_test.go | 28 ++-- lib/containers/rbtree.go | 161 +++++++++------------ lib/containers/rbtree_test.go | 66 ++++----- lib/containers/sortedmap.go | 76 +++------- 12 files changed, 320 insertions(+), 349 deletions(-) diff --git a/lib/btrfs/btrfsvol/chunk.go b/lib/btrfs/btrfsvol/chunk.go index 08a0b2b..a112fd3 100644 --- a/lib/btrfs/btrfsvol/chunk.go +++ b/lib/btrfs/btrfsvol/chunk.go @@ -22,7 +22,10 @@ type chunkMapping struct { Flags containers.Optional[BlockGroupFlags] } -type ChunkMapping = chunkMapping +// Compare implements containers.Ordered. +func (a chunkMapping) Compare(b chunkMapping) int { + return containers.NativeCompare(a.LAddr, b.LAddr) +} // return -1 if 'a' is wholly to the left of 'b' // return 0 if there is some overlap between 'a' and 'b' diff --git a/lib/btrfs/btrfsvol/devext.go b/lib/btrfs/btrfsvol/devext.go index 037021c..3324476 100644 --- a/lib/btrfs/btrfsvol/devext.go +++ b/lib/btrfs/btrfsvol/devext.go @@ -20,6 +20,11 @@ type devextMapping struct { Flags containers.Optional[BlockGroupFlags] } +// Compare implements containers.Ordered. +func (a devextMapping) Compare(b devextMapping) int { + return containers.NativeCompare(a.PAddr, b.PAddr) +} + // return -1 if 'a' is wholly to the left of 'b' // return 0 if there is some overlap between 'a' and 'b' // return 1 if 'a is wholly to the right of 'b' diff --git a/lib/btrfs/btrfsvol/lvm.go b/lib/btrfs/btrfsvol/lvm.go index 59ca609..93ec438 100644 --- a/lib/btrfs/btrfsvol/lvm.go +++ b/lib/btrfs/btrfsvol/lvm.go @@ -22,8 +22,8 @@ type LogicalVolume[PhysicalVolume diskio.File[PhysicalAddr]] struct { id2pv map[DeviceID]PhysicalVolume - logical2physical *containers.RBTree[containers.NativeOrdered[LogicalAddr], chunkMapping] - physical2logical map[DeviceID]*containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping] + logical2physical *containers.RBTree[chunkMapping] + physical2logical map[DeviceID]*containers.RBTree[devextMapping] } var _ diskio.File[LogicalAddr] = (*LogicalVolume[diskio.File[PhysicalAddr]])(nil) @@ -33,22 +33,14 @@ func (lv *LogicalVolume[PhysicalVolume]) init() { lv.id2pv = make(map[DeviceID]PhysicalVolume) } if lv.logical2physical == nil { - lv.logical2physical = &containers.RBTree[containers.NativeOrdered[LogicalAddr], chunkMapping]{ - KeyFn: func(chunk chunkMapping) containers.NativeOrdered[LogicalAddr] { - return containers.NativeOrdered[LogicalAddr]{Val: chunk.LAddr} - }, - } + lv.logical2physical = new(containers.RBTree[chunkMapping]) } if lv.physical2logical == nil { - lv.physical2logical = make(map[DeviceID]*containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping], len(lv.id2pv)) + lv.physical2logical = make(map[DeviceID]*containers.RBTree[devextMapping], len(lv.id2pv)) } for devid := range lv.id2pv { if _, ok := lv.physical2logical[devid]; !ok { - lv.physical2logical[devid] = &containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]{ - KeyFn: func(ext devextMapping) containers.NativeOrdered[PhysicalAddr] { - return containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr} - }, - } + lv.physical2logical[devid] = new(containers.RBTree[devextMapping]) } } } @@ -90,11 +82,7 @@ func (lv *LogicalVolume[PhysicalVolume]) AddPhysicalVolume(id DeviceID, dev Phys lv, dev.Name(), other.Name(), id) } lv.id2pv[id] = dev - lv.physical2logical[id] = &containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]{ - KeyFn: func(ext devextMapping) containers.NativeOrdered[PhysicalAddr] { - return containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr} - }, - } + lv.physical2logical[id] = new(containers.RBTree[devextMapping]) return nil } @@ -143,11 +131,13 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro SizeLocked: m.SizeLocked, Flags: m.Flags, } - logicalOverlaps := lv.logical2physical.SearchRange(newChunk.compareRange) + var logicalOverlaps []chunkMapping numOverlappingStripes := 0 - for _, chunk := range logicalOverlaps { - numOverlappingStripes += len(chunk.PAddrs) - } + lv.logical2physical.Subrange(newChunk.compareRange, func(node *containers.RBNode[chunkMapping]) bool { + logicalOverlaps = append(logicalOverlaps, node.Value) + numOverlappingStripes += len(node.Value.PAddrs) + return true + }) var err error newChunk, err = newChunk.union(logicalOverlaps...) if err != nil { @@ -162,7 +152,11 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro SizeLocked: m.SizeLocked, Flags: m.Flags, } - physicalOverlaps := lv.physical2logical[m.PAddr.Dev].SearchRange(newExt.compareRange) + var physicalOverlaps []devextMapping + lv.physical2logical[m.PAddr.Dev].Subrange(newExt.compareRange, func(node *containers.RBNode[devextMapping]) bool { + physicalOverlaps = append(physicalOverlaps, node.Value) + return true + }) newExt, err = newExt.union(physicalOverlaps...) if err != nil { return fmt.Errorf("(%p).AddMapping: %w", lv, err) @@ -202,13 +196,13 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro // logical2physical for _, chunk := range logicalOverlaps { - lv.logical2physical.Delete(containers.NativeOrdered[LogicalAddr]{Val: chunk.LAddr}) + lv.logical2physical.Delete(lv.logical2physical.Search(chunk.Compare)) } lv.logical2physical.Insert(newChunk) // physical2logical for _, ext := range physicalOverlaps { - lv.physical2logical[m.PAddr.Dev].Delete(containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr}) + lv.physical2logical[m.PAddr.Dev].Delete(lv.physical2logical[m.PAddr.Dev].Search(ext.Compare)) } lv.physical2logical[m.PAddr.Dev].Insert(newExt) @@ -227,20 +221,18 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro } func (lv *LogicalVolume[PhysicalVolume]) fsck() error { - physical2logical := make(map[DeviceID]*containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]) - if err := lv.logical2physical.Walk(func(node *containers.RBNode[chunkMapping]) error { + physical2logical := make(map[DeviceID]*containers.RBTree[devextMapping]) + var err error + lv.logical2physical.Range(func(node *containers.RBNode[chunkMapping]) bool { chunk := node.Value for _, stripe := range chunk.PAddrs { if _, devOK := lv.id2pv[stripe.Dev]; !devOK { - return fmt.Errorf("(%p).fsck: chunk references physical volume %v which does not exist", + err = fmt.Errorf("(%p).fsck: chunk references physical volume %v which does not exist", lv, stripe.Dev) + return false } if _, exists := physical2logical[stripe.Dev]; !exists { - physical2logical[stripe.Dev] = &containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]{ - KeyFn: func(ext devextMapping) containers.NativeOrdered[PhysicalAddr] { - return containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr} - }, - } + physical2logical[stripe.Dev] = new(containers.RBTree[devextMapping]) } physical2logical[stripe.Dev].Insert(devextMapping{ PAddr: stripe.Addr, @@ -249,8 +241,9 @@ func (lv *LogicalVolume[PhysicalVolume]) fsck() error { Flags: chunk.Flags, }) } - return nil - }); err != nil { + return true + }) + if err != nil { return err } @@ -270,7 +263,7 @@ func (lv *LogicalVolume[PhysicalVolume]) fsck() error { func (lv *LogicalVolume[PhysicalVolume]) Mappings() []Mapping { var ret []Mapping - _ = lv.logical2physical.Walk(func(node *containers.RBNode[chunkMapping]) error { + lv.logical2physical.Range(func(node *containers.RBNode[chunkMapping]) bool { chunk := node.Value for _, slice := range chunk.PAddrs { ret = append(ret, Mapping{ @@ -280,7 +273,7 @@ func (lv *LogicalVolume[PhysicalVolume]) Mappings() []Mapping { Flags: chunk.Flags, }) } - return nil + return true }) return ret } diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go index 69d14c7..7c02d05 100644 --- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go +++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go @@ -53,11 +53,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice // "AAAAAAA" shouldn't be present, and if we just discard "BBBBBBBB" // because it conflicts with "CCCCCCC", then we would erroneously // include "AAAAAAA". - addrspace := &containers.RBTree[containers.NativeOrdered[btrfsvol.LogicalAddr], btrfsinspect.SysExtentCSum]{ - KeyFn: func(item btrfsinspect.SysExtentCSum) containers.NativeOrdered[btrfsvol.LogicalAddr] { - return containers.NativeOrdered[btrfsvol.LogicalAddr]{Val: item.Sums.Addr} - }, - } + addrspace := new(containers.RBTree[btrfsinspect.SysExtentCSum]) for _, newRecord := range records { for { conflict := addrspace.Search(func(oldRecord btrfsinspect.SysExtentCSum) int { @@ -85,7 +81,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice } if oldRecord.Generation < newRecord.Generation { // Newer generation wins. - addrspace.Delete(containers.NativeOrdered[btrfsvol.LogicalAddr]{Val: oldRecord.Sums.Addr}) + addrspace.Delete(conflict) // loop around to check for more conflicts continue } @@ -142,7 +138,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice }, }, } - addrspace.Delete(containers.NativeOrdered[btrfsvol.LogicalAddr]{Val: oldRecord.Sums.Addr}) + addrspace.Delete(conflict) newRecord = unionRecord // loop around to check for more conflicts } @@ -152,7 +148,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice var flattened SumRunWithGaps[btrfsvol.LogicalAddr] var curAddr btrfsvol.LogicalAddr var curSums strings.Builder - _ = addrspace.Walk(func(node *containers.RBNode[btrfsinspect.SysExtentCSum]) error { + addrspace.Range(func(node *containers.RBNode[btrfsinspect.SysExtentCSum]) bool { curEnd := curAddr + (btrfsvol.LogicalAddr(curSums.Len()/sumSize) * btrfssum.BlockSize) if node.Value.Sums.Addr != curEnd { if curSums.Len() > 0 { @@ -166,7 +162,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice curSums.Reset() } curSums.WriteString(string(node.Value.Sums.Sums)) - return nil + return true }) if curSums.Len() > 0 { flattened.Runs = append(flattened.Runs, btrfssum.SumRun[btrfsvol.LogicalAddr]{ diff --git a/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go b/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go index ee5950d..bd29278 100644 --- a/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go +++ b/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go @@ -770,6 +770,16 @@ func (o *rebuilder) _walkRange( }) } +type gap struct { + // range is [Beg,End) + Beg, End uint64 +} + +// Compare implements containers.Ordered. +func (a gap) Compare(b gap) int { + return containers.NativeCompare(a.Beg, b.Beg) +} + func (o *rebuilder) _wantRange( ctx context.Context, treeID btrfsprim.ObjID, objID btrfsprim.ObjID, typ btrfsprim.ItemType, @@ -787,15 +797,7 @@ func (o *rebuilder) _wantRange( // // Start with a gap of the whole range, then subtract each run // from it. - type gap struct { - // range is [Beg,End) - Beg, End uint64 - } - gaps := &containers.RBTree[containers.NativeOrdered[uint64], gap]{ - KeyFn: func(gap gap) containers.NativeOrdered[uint64] { - return containers.NativeOrdered[uint64]{Val: gap.Beg} - }, - } + gaps := new(containers.RBTree[gap]) gaps.Insert(gap{ Beg: beg, End: end, @@ -805,23 +807,29 @@ func (o *rebuilder) _wantRange( o.rebuilt.Tree(ctx, treeID).Items(ctx), treeID, objID, typ, beg, end, func(runKey btrfsprim.Key, _ keyio.ItemPtr, runBeg, runEnd uint64) { - overlappingGaps := gaps.SearchRange(func(gap gap) int { - switch { - case gap.End <= runBeg: - return 1 - case runEnd <= gap.Beg: - return -1 - default: - return 0 - } - }) + var overlappingGaps []*containers.RBNode[gap] + gaps.Subrange( + func(gap gap) int { + switch { + case gap.End <= runBeg: + return 1 + case runEnd <= gap.Beg: + return -1 + default: + return 0 + } + }, + func(node *containers.RBNode[gap]) bool { + overlappingGaps = append(overlappingGaps, node) + return true + }) if len(overlappingGaps) == 0 { return } - gapsBeg := overlappingGaps[0].Beg - gapsEnd := overlappingGaps[len(overlappingGaps)-1].End + gapsBeg := overlappingGaps[0].Value.Beg + gapsEnd := overlappingGaps[len(overlappingGaps)-1].Value.End for _, gap := range overlappingGaps { - gaps.Delete(containers.NativeOrdered[uint64]{Val: gap.Beg}) + gaps.Delete(gap) } if gapsBeg < runBeg { gaps.Insert(gap{ @@ -842,7 +850,7 @@ func (o *rebuilder) _wantRange( return } potentialItems := o.rebuilt.Tree(ctx, treeID).PotentialItems(ctx) - _ = gaps.Walk(func(rbNode *containers.RBNode[gap]) error { + gaps.Range(func(rbNode *containers.RBNode[gap]) bool { gap := rbNode.Value last := gap.Beg o._walkRange( @@ -874,7 +882,7 @@ func (o *rebuilder) _wantRange( o.wantAugment(wantCtx, treeID, wantKey, nil) } } - return nil + return true }) } diff --git a/lib/btrfsprogs/btrfsinspect/scandevices.go b/lib/btrfsprogs/btrfsinspect/scandevices.go index 7668a83..9b8360c 100644 --- a/lib/btrfsprogs/btrfsinspect/scandevices.go +++ b/lib/btrfsprogs/btrfsinspect/scandevices.go @@ -22,6 +22,7 @@ import ( "git.lukeshu.com/btrfs-progs-ng/lib/btrfs/btrfssum" "git.lukeshu.com/btrfs-progs-ng/lib/btrfs/btrfstree" "git.lukeshu.com/btrfs-progs-ng/lib/btrfs/btrfsvol" + "git.lukeshu.com/btrfs-progs-ng/lib/containers" "git.lukeshu.com/btrfs-progs-ng/lib/textui" ) @@ -79,6 +80,11 @@ type SysExtentCSum struct { Sums btrfsitem.ExtentCSum } +// Compare implements containers.Ordered. +func (a SysExtentCSum) Compare(b SysExtentCSum) int { + return containers.NativeCompare(a.Sums.Addr, b.Sums.Addr) +} + type scanStats struct { textui.Portion[btrfsvol.PhysicalAddr] diff --git a/lib/btrfsprogs/btrfsutil/broken_btree.go b/lib/btrfsprogs/btrfsutil/broken_btree.go index 8261119..7ea31ce 100644 --- a/lib/btrfsprogs/btrfsutil/broken_btree.go +++ b/lib/btrfsprogs/btrfsutil/broken_btree.go @@ -23,7 +23,7 @@ import ( type treeIndex struct { TreeRootErr error - Items *containers.RBTree[btrfsprim.Key, treeIndexValue] + Items *containers.RBTree[treeIndexValue] Errors *containers.IntervalTree[btrfsprim.Key, treeIndexError] } @@ -38,13 +38,14 @@ type treeIndexValue struct { ItemSize uint32 } +// Compare implements containers.Ordered. +func (a treeIndexValue) Compare(b treeIndexValue) int { + return a.Key.Compare(b.Key) +} + func newTreeIndex(arena *SkinnyPathArena) treeIndex { return treeIndex{ - Items: &containers.RBTree[btrfsprim.Key, treeIndexValue]{ - KeyFn: func(iv treeIndexValue) btrfsprim.Key { - return iv.Key - }, - }, + Items: new(containers.RBTree[treeIndexValue]), Errors: &containers.IntervalTree[btrfsprim.Key, treeIndexError]{ MinFn: func(err treeIndexError) btrfsprim.Key { return arena.Inflate(err.Path).Node(-1).ToKey @@ -173,7 +174,7 @@ func (bt *brokenTrees) rawTreeWalk(root btrfstree.TreeRoot, cacheEntry treeIndex }, btrfstree.TreeWalkHandler{ Item: func(path btrfstree.TreePath, item btrfstree.Item) error { - if cacheEntry.Items.Lookup(item.Key) != nil { + if cacheEntry.Items.Search(func(v treeIndexValue) int { return item.Key.Compare(v.Key) }) != nil { // This is a panic because I'm not really sure what the best way to // handle this is, and so if this happens I want the program to crash // and force me to figure out how to handle it. @@ -203,15 +204,15 @@ func (bt *brokenTrees) TreeLookup(treeID btrfsprim.ObjID, key btrfsprim.Key) (bt func (bt *brokenTrees) addErrs(index treeIndex, fn func(btrfsprim.Key, uint32) int, err error) error { var errs derror.MultiError - if _errs := index.Errors.SearchAll(func(k btrfsprim.Key) int { return fn(k, 0) }); len(_errs) > 0 { - errs = make(derror.MultiError, len(_errs)) - for i := range _errs { - errs[i] = &btrfstree.TreeError{ - Path: bt.arena.Inflate(_errs[i].Path), - Err: _errs[i].Err, - } - } - } + index.Errors.Subrange( + func(k btrfsprim.Key) int { return fn(k, 0) }, + func(v treeIndexError) bool { + errs = append(errs, &btrfstree.TreeError{ + Path: bt.arena.Inflate(v.Path), + Err: v.Err, + }) + return true + }) if len(errs) == 0 { return err } @@ -253,9 +254,13 @@ func (bt *brokenTrees) TreeSearchAll(treeID btrfsprim.ObjID, fn func(btrfsprim.K return nil, index.TreeRootErr } - indexItems := index.Items.SearchRange(func(indexItem treeIndexValue) int { - return fn(indexItem.Key, indexItem.ItemSize) - }) + var indexItems []treeIndexValue + index.Items.Subrange( + func(indexItem treeIndexValue) int { return fn(indexItem.Key, indexItem.ItemSize) }, + func(node *containers.RBNode[treeIndexValue]) bool { + indexItems = append(indexItems, node.Value) + return true + }) if len(indexItems) == 0 { return nil, bt.addErrs(index, fn, iofs.ErrNotExist) } @@ -290,12 +295,12 @@ func (bt *brokenTrees) TreeWalk(ctx context.Context, treeID btrfsprim.ObjID, err return } var node *diskio.Ref[btrfsvol.LogicalAddr, btrfstree.Node] - _ = index.Items.Walk(func(indexItem *containers.RBNode[treeIndexValue]) error { + index.Items.Range(func(indexItem *containers.RBNode[treeIndexValue]) bool { if ctx.Err() != nil { - return ctx.Err() + return false } if bt.ctx.Err() != nil { - return bt.ctx.Err() + return false } if cbs.Item != nil { itemPath := bt.arena.Inflate(indexItem.Value.Path) @@ -304,7 +309,7 @@ func (bt *brokenTrees) TreeWalk(ctx context.Context, treeID btrfsprim.ObjID, err node, err = bt.inner.ReadNode(itemPath.Parent()) if err != nil { errHandle(&btrfstree.TreeError{Path: itemPath, Err: err}) - return nil //nolint:nilerr // We already called errHandle(). + return true } } item := node.Data.BodyLeaf[itemPath.Node(-1).FromItemIdx] @@ -312,7 +317,7 @@ func (bt *brokenTrees) TreeWalk(ctx context.Context, treeID btrfsprim.ObjID, err errHandle(&btrfstree.TreeError{Path: itemPath, Err: err}) } } - return nil + return true }) } diff --git a/lib/containers/intervaltree.go b/lib/containers/intervaltree.go index b7ff866..7b96526 100644 --- a/lib/containers/intervaltree.go +++ b/lib/containers/intervaltree.go @@ -4,81 +4,80 @@ package containers -type intervalKey[K Ordered[K]] struct { +type interval[K Ordered[K]] struct { Min, Max K } -func (ival intervalKey[K]) ContainsFn(fn func(K) int) bool { - return fn(ival.Min) >= 0 && fn(ival.Max) <= 0 -} - -func (a intervalKey[K]) Compare(b intervalKey[K]) int { +// Compare implements Ordered. +func (a interval[K]) Compare(b interval[K]) int { if d := a.Min.Compare(b.Min); d != 0 { return d } return a.Max.Compare(b.Max) } +// ContainsFn returns whether this interval contains the range matched +// by the given function. +func (ival interval[K]) ContainsFn(fn func(K) int) bool { + return fn(ival.Min) >= 0 && fn(ival.Max) <= 0 +} + type intervalValue[K Ordered[K], V any] struct { - Val V - SpanOfChildren intervalKey[K] + Val V + ValSpan interval[K] + ChildSpan interval[K] +} + +// Compare implements Ordered. +func (a intervalValue[K, V]) Compare(b intervalValue[K, V]) int { + return a.ValSpan.Compare(b.ValSpan) } type IntervalTree[K Ordered[K], V any] struct { MinFn func(V) K MaxFn func(V) K - inner RBTree[intervalKey[K], intervalValue[K, V]] -} - -func (t *IntervalTree[K, V]) keyFn(v intervalValue[K, V]) intervalKey[K] { - return intervalKey[K]{ - Min: t.MinFn(v.Val), - Max: t.MaxFn(v.Val), - } + inner RBTree[intervalValue[K, V]] } func (t *IntervalTree[K, V]) attrFn(node *RBNode[intervalValue[K, V]]) { - max := t.MaxFn(node.Value.Val) - if node.Left != nil && node.Left.Value.SpanOfChildren.Max.Compare(max) > 0 { - max = node.Left.Value.SpanOfChildren.Max + max := node.Value.ValSpan.Max + if node.Left != nil && node.Left.Value.ChildSpan.Max.Compare(max) > 0 { + max = node.Left.Value.ChildSpan.Max } - if node.Right != nil && node.Right.Value.SpanOfChildren.Max.Compare(max) > 0 { - max = node.Right.Value.SpanOfChildren.Max + if node.Right != nil && node.Right.Value.ChildSpan.Max.Compare(max) > 0 { + max = node.Right.Value.ChildSpan.Max } - node.Value.SpanOfChildren.Max = max + node.Value.ChildSpan.Max = max - min := t.MinFn(node.Value.Val) - if node.Left != nil && node.Left.Value.SpanOfChildren.Min.Compare(min) < 0 { - min = node.Left.Value.SpanOfChildren.Min + min := node.Value.ValSpan.Min + if node.Left != nil && node.Left.Value.ChildSpan.Min.Compare(min) < 0 { + min = node.Left.Value.ChildSpan.Min } - if node.Right != nil && node.Right.Value.SpanOfChildren.Min.Compare(min) < 0 { - min = node.Right.Value.SpanOfChildren.Min + if node.Right != nil && node.Right.Value.ChildSpan.Min.Compare(min) < 0 { + min = node.Right.Value.ChildSpan.Min } - node.Value.SpanOfChildren.Min = min + node.Value.ChildSpan.Min = min } func (t *IntervalTree[K, V]) init() { - if t.inner.KeyFn == nil { - t.inner.KeyFn = t.keyFn + if t.inner.AttrFn == nil { t.inner.AttrFn = t.attrFn } } -func (t *IntervalTree[K, V]) Delete(min, max K) { - t.init() - t.inner.Delete(intervalKey[K]{ - Min: min, - Max: max, - }) -} - func (t *IntervalTree[K, V]) Equal(u *IntervalTree[K, V]) bool { return t.inner.Equal(&u.inner) } func (t *IntervalTree[K, V]) Insert(val V) { t.init() - t.inner.Insert(intervalValue[K, V]{Val: val}) + t.inner.Insert(intervalValue[K, V]{ + Val: val, + ValSpan: interval[K]{ + Min: t.MinFn(val), + Max: t.MaxFn(val), + }, + }) } func (t *IntervalTree[K, V]) Min() (K, bool) { @@ -86,7 +85,7 @@ func (t *IntervalTree[K, V]) Min() (K, bool) { var zero K return zero, false } - return t.inner.root.Value.SpanOfChildren.Min, true + return t.inner.root.Value.ChildSpan.Min, true } func (t *IntervalTree[K, V]) Max() (K, bool) { @@ -94,22 +93,18 @@ func (t *IntervalTree[K, V]) Max() (K, bool) { var zero K return zero, false } - return t.inner.root.Value.SpanOfChildren.Max, true -} - -func (t *IntervalTree[K, V]) Lookup(k K) (V, bool) { - return t.Search(k.Compare) + return t.inner.root.Value.ChildSpan.Max, true } func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) { node := t.inner.root for node != nil { switch { - case t.keyFn(node.Value).ContainsFn(fn): + case node.Value.ValSpan.ContainsFn(fn): return node.Value.Val, true - case node.Left != nil && node.Left.Value.SpanOfChildren.ContainsFn(fn): + case node.Left != nil && node.Left.Value.ChildSpan.ContainsFn(fn): node = node.Left - case node.Right != nil && node.Right.Value.SpanOfChildren.ContainsFn(fn): + case node.Right != nil && node.Right.Value.ChildSpan.ContainsFn(fn): node = node.Right default: node = nil @@ -119,24 +114,33 @@ func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) { return zero, false } -func (t *IntervalTree[K, V]) searchAll(fn func(K) int, node *RBNode[intervalValue[K, V]], ret *[]V) { +func (t *IntervalTree[K, V]) Range(fn func(V) bool) { + t.inner.Range(func(node *RBNode[intervalValue[K, V]]) bool { + return fn(node.Value.Val) + }) +} + +func (t *IntervalTree[K, V]) Subrange(rangeFn func(K) int, handleFn func(V) bool) { + t.subrange(t.inner.root, rangeFn, handleFn) +} + +func (t *IntervalTree[K, V]) subrange(node *RBNode[intervalValue[K, V]], rangeFn func(K) int, handleFn func(V) bool) bool { if node == nil { - return + return true } - if !node.Value.SpanOfChildren.ContainsFn(fn) { - return + if !node.Value.ChildSpan.ContainsFn(rangeFn) { + return true } - t.searchAll(fn, node.Left, ret) - if t.keyFn(node.Value).ContainsFn(fn) { - *ret = append(*ret, node.Value.Val) + if !t.subrange(node.Left, rangeFn, handleFn) { + return false } - t.searchAll(fn, node.Right, ret) -} - -func (t *IntervalTree[K, V]) SearchAll(fn func(K) int) []V { - var ret []V - t.searchAll(fn, t.inner.root, &ret) - return ret + if node.Value.ValSpan.ContainsFn(rangeFn) { + if !handleFn(node.Value.Val) { + return false + } + } + if !t.subrange(node.Right, rangeFn, handleFn) { + return false + } + return true } - -// TODO: func (t *IntervalTree[K, V]) Walk(fn func(*RBNode[V]) error) error diff --git a/lib/containers/intervaltree_test.go b/lib/containers/intervaltree_test.go index 45743a2..30885bd 100644 --- a/lib/containers/intervaltree_test.go +++ b/lib/containers/intervaltree_test.go @@ -18,8 +18,8 @@ func (t *IntervalTree[K, V]) ASCIIArt() string { func (v intervalValue[K, V]) String() string { return fmt.Sprintf("%v) ([%v,%v]", v.Val, - v.SpanOfChildren.Min, - v.SpanOfChildren.Max) + v.ChildSpan.Min, + v.ChildSpan.Max) } func (v NativeOrdered[T]) String() string { @@ -60,15 +60,21 @@ func TestIntervalTree(t *testing.T) { t.Log("\n" + tree.ASCIIArt()) // find intervals that touch [9,20] - intervals := tree.SearchAll(func(k NativeOrdered[int]) int { - if k.Val < 9 { - return 1 - } - if k.Val > 20 { - return -1 - } - return 0 - }) + var intervals []SimpleInterval + tree.Subrange( + func(k NativeOrdered[int]) int { + if k.Val < 9 { + return 1 + } + if k.Val > 20 { + return -1 + } + return 0 + }, + func(v SimpleInterval) bool { + intervals = append(intervals, v) + return true + }) assert.Equal(t, []SimpleInterval{ {6, 10}, diff --git a/lib/containers/rbtree.go b/lib/containers/rbtree.go index 1fdb799..4125847 100644 --- a/lib/containers/rbtree.go +++ b/lib/containers/rbtree.go @@ -7,8 +7,6 @@ package containers import ( "fmt" "reflect" - - "git.lukeshu.com/btrfs-progs-ng/lib/slices" ) type Color bool @@ -18,50 +16,49 @@ const ( Red = Color(true) ) -type RBNode[V any] struct { - Parent, Left, Right *RBNode[V] +type RBNode[T Ordered[T]] struct { + Parent, Left, Right *RBNode[T] Color Color - Value V + Value T } -func (node *RBNode[V]) getColor() Color { +func (node *RBNode[T]) getColor() Color { if node == nil { return Black } return node.Color } -type RBTree[K Ordered[K], V any] struct { - KeyFn func(V) K - AttrFn func(*RBNode[V]) - root *RBNode[V] +type RBTree[T Ordered[T]] struct { + AttrFn func(*RBNode[T]) + root *RBNode[T] len int } -func (t *RBTree[K, V]) Len() int { +func (t *RBTree[T]) Len() int { return t.len } -func (t *RBTree[K, V]) Walk(fn func(*RBNode[V]) error) error { - return t.root.walk(fn) +func (t *RBTree[T]) Range(fn func(*RBNode[T]) bool) { + t.root._range(fn) } -func (node *RBNode[V]) walk(fn func(*RBNode[V]) error) error { +func (node *RBNode[T]) _range(fn func(*RBNode[T]) bool) bool { if node == nil { - return nil + return true } - if err := node.Left.walk(fn); err != nil { - return err + if !node.Left._range(fn) { + return false } - if err := fn(node); err != nil { - return err + if !fn(node) { + return false } - if err := node.Right.walk(fn); err != nil { - return err + if !node.Right._range(fn) { + return false } - return nil + return true } // Search the tree for a value that satisfied the given callbackk @@ -81,16 +78,13 @@ func (node *RBNode[V]) walk(fn func(*RBNode[V]) error) error { // +---+ +---+ // // Returns nil if no such value is found. -// -// Search is good for advanced lookup, like when a range of values is -// acceptable. For simple exact-value lookup, use Lookup. -func (t *RBTree[K, V]) Search(fn func(V) int) *RBNode[V] { +func (t *RBTree[T]) Search(fn func(T) int) *RBNode[T] { ret, _ := t.root.search(fn) return ret } -func (node *RBNode[V]) search(fn func(V) int) (exact, nearest *RBNode[V]) { - var prev *RBNode[V] +func (node *RBNode[T]) search(fn func(T) int) (exact, nearest *RBNode[T]) { + var prev *RBNode[T] for { if node == nil { return nil, prev @@ -108,26 +102,13 @@ func (node *RBNode[V]) search(fn func(V) int) (exact, nearest *RBNode[V]) { } } -func (t *RBTree[K, V]) exactKey(key K) func(V) int { - return func(val V) int { - valKey := t.KeyFn(val) - return key.Compare(valKey) - } -} - -// Lookup looks up the value for an exact key. If no such value -// exists, nil is returned. -func (t *RBTree[K, V]) Lookup(key K) *RBNode[V] { - return t.Search(t.exactKey(key)) -} - // Min returns the minimum value stored in the tree, or nil if the // tree is empty. -func (t *RBTree[K, V]) Min() *RBNode[V] { +func (t *RBTree[T]) Min() *RBNode[T] { return t.root.min() } -func (node *RBNode[V]) min() *RBNode[V] { +func (node *RBNode[T]) min() *RBNode[T] { if node == nil { return nil } @@ -141,11 +122,11 @@ func (node *RBNode[V]) min() *RBNode[V] { // Max returns the maximum value stored in the tree, or nil if the // tree is empty. -func (t *RBTree[K, V]) Max() *RBNode[V] { +func (t *RBTree[T]) Max() *RBNode[T] { return t.root.max() } -func (node *RBNode[V]) max() *RBNode[V] { +func (node *RBNode[T]) max() *RBNode[T] { if node == nil { return nil } @@ -157,11 +138,7 @@ func (node *RBNode[V]) max() *RBNode[V] { } } -func (t *RBTree[K, V]) Next(cur *RBNode[V]) *RBNode[V] { - return cur.next() -} - -func (cur *RBNode[V]) next() *RBNode[V] { +func (cur *RBNode[T]) Next() *RBNode[T] { if cur.Right != nil { return cur.Right.min() } @@ -172,11 +149,7 @@ func (cur *RBNode[V]) next() *RBNode[V] { return parent } -func (t *RBTree[K, V]) Prev(cur *RBNode[V]) *RBNode[V] { - return cur.prev() -} - -func (cur *RBNode[V]) prev() *RBNode[V] { +func (cur *RBNode[T]) Prev() *RBNode[T] { if cur.Left != nil { return cur.Left.max() } @@ -187,48 +160,56 @@ func (cur *RBNode[V]) prev() *RBNode[V] { return parent } -// SearchRange is like Search, but returns all nodes that match the -// function; assuming that they are contiguous. -func (t *RBTree[K, V]) SearchRange(fn func(V) int) []V { - middle := t.Search(fn) - if middle == nil { - return nil - } - ret := []V{middle.Value} - for node := t.Prev(middle); node != nil && fn(node.Value) == 0; node = t.Prev(node) { - ret = append(ret, node.Value) +// Subrange is like Search, but for when there may be more than one +// result. +func (t *RBTree[T]) Subrange(rangeFn func(T) int, handleFn func(*RBNode[T]) bool) { + // Find the left-most acceptable node. + _, node := t.root.search(func(v T) int { + if rangeFn(v) <= 0 { + return -1 + } else { + return 1 + } + }) + for node != nil && rangeFn(node.Value) > 0 { + node = node.Next() } - slices.Reverse(ret) - for node := t.Next(middle); node != nil && fn(node.Value) == 0; node = t.Next(node) { - ret = append(ret, node.Value) + // Now walk forward until we hit the end. + for node != nil && rangeFn(node.Value) == 0 { + if keepGoing := handleFn(node); !keepGoing { + return + } + node = node.Next() } - return ret } -func (t *RBTree[K, V]) Equal(u *RBTree[K, V]) bool { +func (t *RBTree[T]) Equal(u *RBTree[T]) bool { if (t == nil) != (u == nil) { return false } if t == nil { return true } + if t.len != u.len { + return false + } - var tSlice []V - _ = t.Walk(func(node *RBNode[V]) error { + tSlice := make([]T, 0, t.len) + t.Range(func(node *RBNode[T]) bool { tSlice = append(tSlice, node.Value) - return nil + return true }) - var uSlice []V - _ = u.Walk(func(node *RBNode[V]) error { + uSlice := make([]T, 0, u.len) + u.Range(func(node *RBNode[T]) bool { uSlice = append(uSlice, node.Value) - return nil + return true }) return reflect.DeepEqual(tSlice, uSlice) } -func (t *RBTree[K, V]) parentChild(node *RBNode[V]) **RBNode[V] { +func (t *RBTree[T]) parentChild(node *RBNode[T]) **RBNode[T] { switch { case node.Parent == nil: return &t.root @@ -241,7 +222,7 @@ func (t *RBTree[K, V]) parentChild(node *RBNode[V]) **RBNode[V] { } } -func (t *RBTree[K, V]) updateAttr(node *RBNode[V]) { +func (t *RBTree[T]) updateAttr(node *RBNode[T]) { if t.AttrFn == nil { return } @@ -251,7 +232,7 @@ func (t *RBTree[K, V]) updateAttr(node *RBNode[V]) { } } -func (t *RBTree[K, V]) leftRotate(x *RBNode[V]) { +func (t *RBTree[T]) leftRotate(x *RBNode[T]) { // p p // | | // +---+ +---+ @@ -286,7 +267,7 @@ func (t *RBTree[K, V]) leftRotate(x *RBNode[V]) { t.updateAttr(x) } -func (t *RBTree[K, V]) rightRotate(y *RBNode[V]) { +func (t *RBTree[T]) rightRotate(y *RBNode[T]) { //nolint:dupword // // | | @@ -322,18 +303,17 @@ func (t *RBTree[K, V]) rightRotate(y *RBNode[V]) { t.updateAttr(y) } -func (t *RBTree[K, V]) Insert(val V) { +func (t *RBTree[T]) Insert(val T) { // Naive-insert - key := t.KeyFn(val) - exact, parent := t.root.search(t.exactKey(key)) + exact, parent := t.root.search(val.Compare) if exact != nil { exact.Value = val return } t.len++ - node := &RBNode[V]{ + node := &RBNode[T]{ Color: Red, Parent: parent, Value: val, @@ -341,7 +321,7 @@ func (t *RBTree[K, V]) Insert(val V) { switch { case parent == nil: t.root = node - case key.Compare(t.KeyFn(parent.Value)) < 0: + case val.Compare(parent.Value) < 0: parent.Left = node default: parent.Right = node @@ -391,15 +371,14 @@ func (t *RBTree[K, V]) Insert(val V) { t.root.Color = Black } -func (t *RBTree[K, V]) transplant(oldNode, newNode *RBNode[V]) { +func (t *RBTree[T]) transplant(oldNode, newNode *RBNode[T]) { *t.parentChild(oldNode) = newNode if newNode != nil { newNode.Parent = oldNode.Parent } } -func (t *RBTree[K, V]) Delete(key K) { - nodeToDelete := t.Lookup(key) +func (t *RBTree[T]) Delete(nodeToDelete *RBNode[T]) { if nodeToDelete == nil { return } @@ -410,8 +389,8 @@ func (t *RBTree[K, V]) Delete(key K) { // phase 1 - var nodeToRebalance *RBNode[V] - var nodeToRebalanceParent *RBNode[V] // in case 'nodeToRebalance' is nil, which it can be + var nodeToRebalance *RBNode[T] + var nodeToRebalanceParent *RBNode[T] // in case 'nodeToRebalance' is nil, which it can be needsRebalance := nodeToDelete.Color == Black switch { @@ -427,7 +406,7 @@ func (t *RBTree[K, V]) Delete(key K) { // The node being deleted has a child on both sides, // so we've go to reshuffle the parents a bit to make // room for those children. - next := nodeToDelete.next() + next := nodeToDelete.Next() if next.Parent == nodeToDelete { // p p // | | diff --git a/lib/containers/rbtree_test.go b/lib/containers/rbtree_test.go index e42410e..d2fe931 100644 --- a/lib/containers/rbtree_test.go +++ b/lib/containers/rbtree_test.go @@ -7,21 +7,22 @@ package containers import ( "fmt" "io" - "sort" "strings" "testing" "github.com/stretchr/testify/require" "golang.org/x/exp/constraints" + + "git.lukeshu.com/btrfs-progs-ng/lib/slices" ) -func (t *RBTree[K, V]) ASCIIArt() string { +func (t *RBTree[T]) ASCIIArt() string { var out strings.Builder t.root.asciiArt(&out, "", "", "") return out.String() } -func (node *RBNode[V]) String() string { +func (node *RBNode[T]) String() string { switch { case node == nil: return "nil" @@ -32,7 +33,7 @@ func (node *RBNode[V]) String() string { } } -func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) { +func (node *RBNode[T]) asciiArt(w io.Writer, u, m, l string) { if node == nil { fmt.Fprintf(w, "%snil\n", m) return @@ -43,7 +44,7 @@ func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) { node.Left.asciiArt(w, l+" | ", l+" `--", l+" ") } -func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], tree *RBTree[NativeOrdered[K], V]) { +func checkRBTree[T constraints.Ordered](t *testing.T, expectedSet Set[T], tree *RBTree[NativeOrdered[T]]) { // 1. Every node is either red or black // 2. The root is black. @@ -52,19 +53,19 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], // 3. Every nil is black. // 4. If a node is red, then both its children are black. - require.NoError(t, tree.Walk(func(node *RBNode[V]) error { + tree.Range(func(node *RBNode[NativeOrdered[T]]) bool { if node.getColor() == Red { require.Equal(t, Black, node.Left.getColor()) require.Equal(t, Black, node.Right.getColor()) } - return nil - })) + return true + }) // 5. For each node, all simple paths from the node to // descendent leaves contain the same number of black // nodes. - var walkCnt func(node *RBNode[V], cnt int, leafFn func(int)) - walkCnt = func(node *RBNode[V], cnt int, leafFn func(int)) { + var walkCnt func(node *RBNode[NativeOrdered[T]], cnt int, leafFn func(int)) + walkCnt = func(node *RBNode[NativeOrdered[T]], cnt int, leafFn func(int)) { if node.getColor() == Black { cnt++ } @@ -75,7 +76,7 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], walkCnt(node.Left, cnt, leafFn) walkCnt(node.Right, cnt, leafFn) } - require.NoError(t, tree.Walk(func(node *RBNode[V]) error { + tree.Range(func(node *RBNode[NativeOrdered[T]]) bool { var cnts []int walkCnt(node, 0, func(cnt int) { cnts = append(cnts, cnt) @@ -86,27 +87,24 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], break } } - return nil - })) + return true + }) // expected contents - expectedOrder := make([]K, 0, len(expectedSet)) - for k := range expectedSet { - expectedOrder = append(expectedOrder, k) - node := tree.Lookup(NativeOrdered[K]{Val: k}) + expectedOrder := make([]T, 0, len(expectedSet)) + for v := range expectedSet { + expectedOrder = append(expectedOrder, v) + node := tree.Search(NativeOrdered[T]{Val: v}.Compare) require.NotNil(t, tree, node) - require.Equal(t, k, tree.KeyFn(node.Value).Val) } - sort.Slice(expectedOrder, func(i, j int) bool { - return expectedOrder[i] < expectedOrder[j] + slices.Sort(expectedOrder) + actOrder := make([]T, 0, len(expectedSet)) + tree.Range(func(node *RBNode[NativeOrdered[T]]) bool { + actOrder = append(actOrder, node.Value.Val) + return true }) - actOrder := make([]K, 0, len(expectedSet)) - require.NoError(t, tree.Walk(func(node *RBNode[V]) error { - actOrder = append(actOrder, tree.KeyFn(node.Value).Val) - return nil - })) require.Equal(t, expectedOrder, actOrder) - require.Equal(t, tree.Len(), len(expectedSet)) + require.Equal(t, len(expectedSet), tree.Len()) } func FuzzRBTree(f *testing.F) { @@ -132,11 +130,7 @@ func FuzzRBTree(f *testing.F) { }) f.Fuzz(func(t *testing.T, dat []uint8) { - tree := &RBTree[NativeOrdered[uint8], uint8]{ - KeyFn: func(x uint8) NativeOrdered[uint8] { - return NativeOrdered[uint8]{Val: x} - }, - } + tree := new(RBTree[NativeOrdered[uint8]]) set := make(Set[uint8]) checkRBTree(t, set, tree) t.Logf("\n%s\n", tree.ASCIIArt()) @@ -145,18 +139,18 @@ func FuzzRBTree(f *testing.F) { val := (b & 0b0011_1111) if ins { t.Logf("Insert(%v)", val) - tree.Insert(val) + tree.Insert(NativeOrdered[uint8]{Val: val}) set.Insert(val) t.Logf("\n%s\n", tree.ASCIIArt()) - node := tree.Lookup(NativeOrdered[uint8]{Val: val}) + node := tree.Search(NativeOrdered[uint8]{Val: val}.Compare) require.NotNil(t, node) - require.Equal(t, val, node.Value) + require.Equal(t, val, node.Value.Val) } else { t.Logf("Delete(%v)", val) - tree.Delete(NativeOrdered[uint8]{Val: val}) + tree.Delete(tree.Search(NativeOrdered[uint8]{Val: val}.Compare)) delete(set, val) t.Logf("\n%s\n", tree.ASCIIArt()) - require.Nil(t, tree.Lookup(NativeOrdered[uint8]{Val: val})) + require.Nil(t, tree.Search(NativeOrdered[uint8]{Val: val}.Compare)) } checkRBTree(t, set, tree) } diff --git a/lib/containers/sortedmap.go b/lib/containers/sortedmap.go index 52308c9..d104274 100644 --- a/lib/containers/sortedmap.go +++ b/lib/containers/sortedmap.go @@ -1,40 +1,32 @@ -// Copyright (C) 2022 Luke Shumaker +// Copyright (C) 2022-2023 Luke Shumaker // // SPDX-License-Identifier: GPL-2.0-or-later package containers -import ( - "errors" -) - -type OrderedKV[K Ordered[K], V any] struct { +type orderedKV[K Ordered[K], V any] struct { K K V V } -type SortedMap[K Ordered[K], V any] struct { - inner RBTree[K, OrderedKV[K, V]] -} - -func (m *SortedMap[K, V]) init() { - if m.inner.KeyFn == nil { - m.inner.KeyFn = m.keyFn - } +func (a orderedKV[K, V]) Compare(b orderedKV[K, V]) int { + return a.K.Compare(b.K) } -func (m *SortedMap[K, V]) keyFn(kv OrderedKV[K, V]) K { - return kv.K +type SortedMap[K Ordered[K], V any] struct { + inner RBTree[orderedKV[K, V]] } func (m *SortedMap[K, V]) Delete(key K) { - m.init() - m.inner.Delete(key) + m.inner.Delete(m.inner.Search(func(kv orderedKV[K, V]) int { + return key.Compare(kv.K) + })) } func (m *SortedMap[K, V]) Load(key K) (value V, ok bool) { - m.init() - node := m.inner.Lookup(key) + node := m.inner.Search(func(kv orderedKV[K, V]) int { + return key.Compare(kv.K) + }) if node == nil { var zero V return zero, false @@ -42,41 +34,27 @@ func (m *SortedMap[K, V]) Load(key K) (value V, ok bool) { return node.Value.V, true } -var errStop = errors.New("stop") - -func (m *SortedMap[K, V]) Range(f func(key K, value V) bool) { - m.init() - _ = m.inner.Walk(func(node *RBNode[OrderedKV[K, V]]) error { - if f(node.Value.K, node.Value.V) { - return nil - } else { - return errStop - } +func (m *SortedMap[K, V]) Store(key K, value V) { + m.inner.Insert(orderedKV[K, V]{ + K: key, + V: value, }) } -func (m *SortedMap[K, V]) Subrange(rangeFn func(K, V) int, handleFn func(K, V) bool) { - m.init() - kvs := m.inner.SearchRange(func(kv OrderedKV[K, V]) int { - return rangeFn(kv.K, kv.V) +func (m *SortedMap[K, V]) Range(fn func(key K, value V) bool) { + m.inner.Range(func(node *RBNode[orderedKV[K, V]]) bool { + return fn(node.Value.K, node.Value.V) }) - for _, kv := range kvs { - if !handleFn(kv.K, kv.V) { - break - } - } } -func (m *SortedMap[K, V]) Store(key K, value V) { - m.init() - m.inner.Insert(OrderedKV[K, V]{ - K: key, - V: value, - }) +func (m *SortedMap[K, V]) Subrange(rangeFn func(K, V) int, handleFn func(K, V) bool) { + m.inner.Subrange( + func(kv orderedKV[K, V]) int { return rangeFn(kv.K, kv.V) }, + func(node *RBNode[orderedKV[K, V]]) bool { return handleFn(node.Value.K, node.Value.V) }) } func (m *SortedMap[K, V]) Search(fn func(K, V) int) (K, V, bool) { - node := m.inner.Search(func(kv OrderedKV[K, V]) int { + node := m.inner.Search(func(kv orderedKV[K, V]) int { return fn(kv.K, kv.V) }) if node == nil { @@ -87,12 +65,6 @@ func (m *SortedMap[K, V]) Search(fn func(K, V) int) (K, V, bool) { return node.Value.K, node.Value.V, true } -func (m *SortedMap[K, V]) SearchAll(fn func(K, V) int) []OrderedKV[K, V] { - return m.inner.SearchRange(func(kv OrderedKV[K, V]) int { - return fn(kv.K, kv.V) - }) -} - func (m *SortedMap[K, V]) Len() int { return m.inner.Len() } -- cgit v1.2.3-54-g00ecf