summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2023-02-12 02:44:35 -0700
committerLuke Shumaker <lukeshu@lukeshu.com>2023-02-12 02:44:35 -0700
commit128e4d9aa876e14a1203ce98bfaa7ad399ad97c7 (patch)
tree039c3c549414c21b15d58c3d695ee87c3feb1402
parent53d7fbb73869eb5defa1ca5c52b26abd346b13b9 (diff)
parent696a7d192e5eefa53230168a4b200ec0560c8a10 (diff)
Merge branch 'lukeshu/containers'
-rw-r--r--lib/btrfs/btrfsprim/misc.go8
-rw-r--r--lib/btrfs/btrfsprim/uuid.go2
-rw-r--r--lib/btrfs/btrfstree/ops.go2
-rw-r--r--lib/btrfs/btrfstree/root.go2
-rw-r--r--lib/btrfs/btrfstree/types_node.go4
-rw-r--r--lib/btrfs/btrfsvol/addr.go4
-rw-r--r--lib/btrfs/btrfsvol/blockgroupflags.go5
-rw-r--r--lib/btrfs/btrfsvol/chunk.go13
-rw-r--r--lib/btrfs/btrfsvol/devext.go11
-rw-r--r--lib/btrfs/btrfsvol/lvm.go93
-rw-r--r--lib/btrfs/io4_fs.go2
-rw-r--r--lib/btrfsprogs/btrfsinspect/print_addrspace.go4
-rw-r--r--lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go6
-rw-r--r--lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go14
-rw-r--r--lib/btrfsprogs/btrfsinspect/rebuildmappings/physicalsums.go4
-rw-r--r--lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go78
-rw-r--r--lib/btrfsprogs/btrfsinspect/scandevices.go6
-rw-r--r--lib/btrfsprogs/btrfsutil/broken_btree.go55
-rw-r--r--lib/containers/intervaltree.go136
-rw-r--r--lib/containers/intervaltree_test.go28
-rw-r--r--lib/containers/ordered.go33
-rw-r--r--lib/containers/ordered_test.go13
-rw-r--r--lib/containers/rbtree.go161
-rw-r--r--lib/containers/rbtree_test.go66
-rw-r--r--lib/containers/set.go2
-rw-r--r--lib/containers/sortedmap.go76
26 files changed, 433 insertions, 395 deletions
diff --git a/lib/btrfs/btrfsprim/misc.go b/lib/btrfs/btrfsprim/misc.go
index 22939bf..da661f6 100644
--- a/lib/btrfs/btrfsprim/misc.go
+++ b/lib/btrfs/btrfsprim/misc.go
@@ -44,14 +44,14 @@ func (key Key) Mm() Key {
return key
}
-func (a Key) Cmp(b Key) int {
- if d := containers.NativeCmp(a.ObjectID, b.ObjectID); d != 0 {
+func (a Key) Compare(b Key) int {
+ if d := containers.NativeCompare(a.ObjectID, b.ObjectID); d != 0 {
return d
}
- if d := containers.NativeCmp(a.ItemType, b.ItemType); d != 0 {
+ if d := containers.NativeCompare(a.ItemType, b.ItemType); d != 0 {
return d
}
- return containers.NativeCmp(a.Offset, b.Offset)
+ return containers.NativeCompare(a.Offset, b.Offset)
}
var _ containers.Ordered[Key] = Key{}
diff --git a/lib/btrfs/btrfsprim/uuid.go b/lib/btrfs/btrfsprim/uuid.go
index 0103ee4..232ab58 100644
--- a/lib/btrfs/btrfsprim/uuid.go
+++ b/lib/btrfs/btrfsprim/uuid.go
@@ -47,7 +47,7 @@ func (uuid UUID) Format(f fmt.State, verb rune) {
fmtutil.FormatByteArrayStringer(uuid, uuid[:], f, verb)
}
-func (a UUID) Cmp(b UUID) int {
+func (a UUID) Compare(b UUID) int {
for i := range a {
if d := int(a[i]) - int(b[i]); d != 0 {
return d
diff --git a/lib/btrfs/btrfstree/ops.go b/lib/btrfs/btrfstree/ops.go
index 537773a..cdacef9 100644
--- a/lib/btrfs/btrfstree/ops.go
+++ b/lib/btrfs/btrfstree/ops.go
@@ -481,7 +481,7 @@ func KeySearch(fn func(btrfsprim.Key) int) func(btrfsprim.Key, uint32) int {
// TreeLookup implements the 'Trees' interface.
func (fs TreeOperatorImpl) TreeLookup(treeID btrfsprim.ObjID, key btrfsprim.Key) (Item, error) {
- item, err := fs.TreeSearch(treeID, KeySearch(key.Cmp))
+ item, err := fs.TreeSearch(treeID, KeySearch(key.Compare))
if err != nil {
err = fmt.Errorf("item with key=%v: %w", key, err)
}
diff --git a/lib/btrfs/btrfstree/root.go b/lib/btrfs/btrfstree/root.go
index 6ec45b5..319904b 100644
--- a/lib/btrfs/btrfstree/root.go
+++ b/lib/btrfs/btrfstree/root.go
@@ -30,7 +30,7 @@ func RootItemSearchFn(treeID btrfsprim.ObjID) func(btrfsprim.Key, uint32) int {
ObjectID: treeID,
ItemType: btrfsitem.ROOT_ITEM_KEY,
Offset: 0,
- }.Cmp(key)
+ }.Compare(key)
}
}
diff --git a/lib/btrfs/btrfstree/types_node.go b/lib/btrfs/btrfstree/types_node.go
index a26215b..d9d7118 100644
--- a/lib/btrfs/btrfstree/types_node.go
+++ b/lib/btrfs/btrfstree/types_node.go
@@ -507,11 +507,11 @@ func ReadNode[Addr ~int64](fs diskio.File[Addr], sb Superblock, addr Addr, exp N
if nodeRef.Data.Head.NumItems == 0 {
errs = append(errs, fmt.Errorf("has no items"))
} else {
- if minItem, _ := nodeRef.Data.MinItem(); exp.MinItem.OK && exp.MinItem.Val.Cmp(minItem) > 0 {
+ if minItem, _ := nodeRef.Data.MinItem(); exp.MinItem.OK && exp.MinItem.Val.Compare(minItem) > 0 {
errs = append(errs, fmt.Errorf("expected minItem>=%v but node has minItem=%v",
exp.MinItem, minItem))
}
- if maxItem, _ := nodeRef.Data.MaxItem(); exp.MaxItem.OK && exp.MaxItem.Val.Cmp(maxItem) < 0 {
+ if maxItem, _ := nodeRef.Data.MaxItem(); exp.MaxItem.OK && exp.MaxItem.Val.Compare(maxItem) < 0 {
errs = append(errs, fmt.Errorf("expected maxItem<=%v but node has maxItem=%v",
exp.MaxItem, maxItem))
}
diff --git a/lib/btrfs/btrfsvol/addr.go b/lib/btrfs/btrfsvol/addr.go
index 94320ef..655f4e9 100644
--- a/lib/btrfs/btrfsvol/addr.go
+++ b/lib/btrfs/btrfsvol/addr.go
@@ -1,4 +1,4 @@
-// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later
@@ -50,7 +50,7 @@ func (a QualifiedPhysicalAddr) Add(b AddrDelta) QualifiedPhysicalAddr {
}
}
-func (a QualifiedPhysicalAddr) Cmp(b QualifiedPhysicalAddr) int {
+func (a QualifiedPhysicalAddr) Compare(b QualifiedPhysicalAddr) int {
if d := int(a.Dev - b.Dev); d != 0 {
return d
}
diff --git a/lib/btrfs/btrfsvol/blockgroupflags.go b/lib/btrfs/btrfsvol/blockgroupflags.go
index 4ca5544..0125664 100644
--- a/lib/btrfs/btrfsvol/blockgroupflags.go
+++ b/lib/btrfs/btrfsvol/blockgroupflags.go
@@ -23,6 +23,11 @@ const (
BLOCK_GROUP_RAID1C3
BLOCK_GROUP_RAID1C4
+ // BLOCK_GROUP_RAID_MASK is the set of bits that mean
+ // the logical:physical relationship is a one:many
+ // relationship rather than a one:one relationship.
+ //
+ // Notably, this does not include BLOCK_GROUP_RAID0.
BLOCK_GROUP_RAID_MASK = (BLOCK_GROUP_RAID1 | BLOCK_GROUP_DUP | BLOCK_GROUP_RAID10 | BLOCK_GROUP_RAID5 | BLOCK_GROUP_RAID6 | BLOCK_GROUP_RAID1C3 | BLOCK_GROUP_RAID1C4)
)
diff --git a/lib/btrfs/btrfsvol/chunk.go b/lib/btrfs/btrfsvol/chunk.go
index 5f1baa8..a112fd3 100644
--- a/lib/btrfs/btrfsvol/chunk.go
+++ b/lib/btrfs/btrfsvol/chunk.go
@@ -1,4 +1,4 @@
-// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later
@@ -22,12 +22,15 @@ type chunkMapping struct {
Flags containers.Optional[BlockGroupFlags]
}
-type ChunkMapping = chunkMapping
+// Compare implements containers.Ordered.
+func (a chunkMapping) Compare(b chunkMapping) int {
+ return containers.NativeCompare(a.LAddr, b.LAddr)
+}
// return -1 if 'a' is wholly to the left of 'b'
// return 0 if there is some overlap between 'a' and 'b'
// return 1 if 'a' is wholly to the right of 'b'
-func (a chunkMapping) cmpRange(b chunkMapping) int {
+func (a chunkMapping) compareRange(b chunkMapping) int {
switch {
case a.LAddr.Add(a.Size) <= b.LAddr:
// 'a' is wholly to the left of 'b'.
@@ -44,7 +47,7 @@ func (a chunkMapping) cmpRange(b chunkMapping) int {
func (a chunkMapping) union(rest ...chunkMapping) (chunkMapping, error) {
// sanity check
for _, chunk := range rest {
- if a.cmpRange(chunk) != 0 {
+ if a.compareRange(chunk) != 0 {
return chunkMapping{}, fmt.Errorf("chunks don't overlap")
}
}
@@ -79,7 +82,7 @@ func (a chunkMapping) union(rest ...chunkMapping) (chunkMapping, error) {
}
ret.PAddrs = maps.Keys(paddrs)
sort.Slice(ret.PAddrs, func(i, j int) bool {
- return ret.PAddrs[i].Cmp(ret.PAddrs[j]) < 0
+ return ret.PAddrs[i].Compare(ret.PAddrs[j]) < 0
})
// figure out the flags (.Flags)
for _, chunk := range chunks {
diff --git a/lib/btrfs/btrfsvol/devext.go b/lib/btrfs/btrfsvol/devext.go
index e8e5446..3324476 100644
--- a/lib/btrfs/btrfsvol/devext.go
+++ b/lib/btrfs/btrfsvol/devext.go
@@ -1,4 +1,4 @@
-// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later
@@ -20,10 +20,15 @@ type devextMapping struct {
Flags containers.Optional[BlockGroupFlags]
}
+// Compare implements containers.Ordered.
+func (a devextMapping) Compare(b devextMapping) int {
+ return containers.NativeCompare(a.PAddr, b.PAddr)
+}
+
// return -1 if 'a' is wholly to the left of 'b'
// return 0 if there is some overlap between 'a' and 'b'
// return 1 if 'a' is wholly to the right of 'b'
-func (a devextMapping) cmpRange(b devextMapping) int {
+func (a devextMapping) compareRange(b devextMapping) int {
switch {
case a.PAddr.Add(a.Size) <= b.PAddr:
// 'a' is wholly to the left of 'b'.
@@ -40,7 +45,7 @@ func (a devextMapping) cmpRange(b devextMapping) int {
func (a devextMapping) union(rest ...devextMapping) (devextMapping, error) {
// sanity check
for _, ext := range rest {
- if a.cmpRange(ext) != 0 {
+ if a.compareRange(ext) != 0 {
return devextMapping{}, fmt.Errorf("devexts don't overlap")
}
}
diff --git a/lib/btrfs/btrfsvol/lvm.go b/lib/btrfs/btrfsvol/lvm.go
index 1cb1ded..93ec438 100644
--- a/lib/btrfs/btrfsvol/lvm.go
+++ b/lib/btrfs/btrfsvol/lvm.go
@@ -22,8 +22,8 @@ type LogicalVolume[PhysicalVolume diskio.File[PhysicalAddr]] struct {
id2pv map[DeviceID]PhysicalVolume
- logical2physical *containers.RBTree[containers.NativeOrdered[LogicalAddr], chunkMapping]
- physical2logical map[DeviceID]*containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]
+ logical2physical *containers.RBTree[chunkMapping]
+ physical2logical map[DeviceID]*containers.RBTree[devextMapping]
}
var _ diskio.File[LogicalAddr] = (*LogicalVolume[diskio.File[PhysicalAddr]])(nil)
@@ -33,22 +33,14 @@ func (lv *LogicalVolume[PhysicalVolume]) init() {
lv.id2pv = make(map[DeviceID]PhysicalVolume)
}
if lv.logical2physical == nil {
- lv.logical2physical = &containers.RBTree[containers.NativeOrdered[LogicalAddr], chunkMapping]{
- KeyFn: func(chunk chunkMapping) containers.NativeOrdered[LogicalAddr] {
- return containers.NativeOrdered[LogicalAddr]{Val: chunk.LAddr}
- },
- }
+ lv.logical2physical = new(containers.RBTree[chunkMapping])
}
if lv.physical2logical == nil {
- lv.physical2logical = make(map[DeviceID]*containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping], len(lv.id2pv))
+ lv.physical2logical = make(map[DeviceID]*containers.RBTree[devextMapping], len(lv.id2pv))
}
for devid := range lv.id2pv {
if _, ok := lv.physical2logical[devid]; !ok {
- lv.physical2logical[devid] = &containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]{
- KeyFn: func(ext devextMapping) containers.NativeOrdered[PhysicalAddr] {
- return containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr}
- },
- }
+ lv.physical2logical[devid] = new(containers.RBTree[devextMapping])
}
}
}
@@ -90,11 +82,7 @@ func (lv *LogicalVolume[PhysicalVolume]) AddPhysicalVolume(id DeviceID, dev Phys
lv, dev.Name(), other.Name(), id)
}
lv.id2pv[id] = dev
- lv.physical2logical[id] = &containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]{
- KeyFn: func(ext devextMapping) containers.NativeOrdered[PhysicalAddr] {
- return containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr}
- },
- }
+ lv.physical2logical[id] = new(containers.RBTree[devextMapping])
return nil
}
@@ -143,7 +131,13 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro
SizeLocked: m.SizeLocked,
Flags: m.Flags,
}
- logicalOverlaps := lv.logical2physical.SearchRange(newChunk.cmpRange)
+ var logicalOverlaps []chunkMapping
+ numOverlappingStripes := 0
+ lv.logical2physical.Subrange(newChunk.compareRange, func(node *containers.RBNode[chunkMapping]) bool {
+ logicalOverlaps = append(logicalOverlaps, node.Value)
+ numOverlappingStripes += len(node.Value.PAddrs)
+ return true
+ })
var err error
newChunk, err = newChunk.union(logicalOverlaps...)
if err != nil {
@@ -158,12 +152,38 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro
SizeLocked: m.SizeLocked,
Flags: m.Flags,
}
- physicalOverlaps := lv.physical2logical[m.PAddr.Dev].SearchRange(newExt.cmpRange)
+ var physicalOverlaps []devextMapping
+ lv.physical2logical[m.PAddr.Dev].Subrange(newExt.compareRange, func(node *containers.RBNode[devextMapping]) bool {
+ physicalOverlaps = append(physicalOverlaps, node.Value)
+ return true
+ })
newExt, err = newExt.union(physicalOverlaps...)
if err != nil {
return fmt.Errorf("(%p).AddMapping: %w", lv, err)
}
+ if newChunk.Flags != newExt.Flags {
+ // If these don't match up, it's a bug in this code.
+ panic(fmt.Errorf("should not happen: newChunk.Flags:%+v != newExt.Flags:%+v",
+ newChunk.Flags, newExt.Flags))
+ }
+ switch {
+ case len(physicalOverlaps) == numOverlappingStripes:
+ // normal case
+ case len(physicalOverlaps) < numOverlappingStripes:
+ // .Flags = DUP or RAID{X}
+ if newChunk.Flags.OK && newChunk.Flags.Val&BLOCK_GROUP_RAID_MASK == 0 {
+ return fmt.Errorf("multiple stripes but flags=%v does not allow multiple stripes",
+ newChunk.Flags.Val)
+ }
+ case len(physicalOverlaps) > numOverlappingStripes:
+ // This should not happen because calling .AddMapping
+ // should update the two in lockstep; if these don't
+ // match up, it's a bug in this code.
+ panic(fmt.Errorf("should not happen: len(physicalOverlaps):%d != numOverlappingStripes:%d",
+ len(physicalOverlaps), numOverlappingStripes))
+ }
+
if dryRun {
return nil
}
@@ -176,13 +196,13 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro
// logical2physical
for _, chunk := range logicalOverlaps {
- lv.logical2physical.Delete(containers.NativeOrdered[LogicalAddr]{Val: chunk.LAddr})
+ lv.logical2physical.Delete(lv.logical2physical.Search(chunk.Compare))
}
lv.logical2physical.Insert(newChunk)
// physical2logical
for _, ext := range physicalOverlaps {
- lv.physical2logical[m.PAddr.Dev].Delete(containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr})
+ lv.physical2logical[m.PAddr.Dev].Delete(lv.physical2logical[m.PAddr.Dev].Search(ext.Compare))
}
lv.physical2logical[m.PAddr.Dev].Insert(newExt)
@@ -201,20 +221,18 @@ func (lv *LogicalVolume[PhysicalVolume]) addMapping(m Mapping, dryRun bool) erro
}
func (lv *LogicalVolume[PhysicalVolume]) fsck() error {
- physical2logical := make(map[DeviceID]*containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping])
- if err := lv.logical2physical.Walk(func(node *containers.RBNode[chunkMapping]) error {
+ physical2logical := make(map[DeviceID]*containers.RBTree[devextMapping])
+ var err error
+ lv.logical2physical.Range(func(node *containers.RBNode[chunkMapping]) bool {
chunk := node.Value
for _, stripe := range chunk.PAddrs {
if _, devOK := lv.id2pv[stripe.Dev]; !devOK {
- return fmt.Errorf("(%p).fsck: chunk references physical volume %v which does not exist",
+ err = fmt.Errorf("(%p).fsck: chunk references physical volume %v which does not exist",
lv, stripe.Dev)
+ return false
}
if _, exists := physical2logical[stripe.Dev]; !exists {
- physical2logical[stripe.Dev] = &containers.RBTree[containers.NativeOrdered[PhysicalAddr], devextMapping]{
- KeyFn: func(ext devextMapping) containers.NativeOrdered[PhysicalAddr] {
- return containers.NativeOrdered[PhysicalAddr]{Val: ext.PAddr}
- },
- }
+ physical2logical[stripe.Dev] = new(containers.RBTree[devextMapping])
}
physical2logical[stripe.Dev].Insert(devextMapping{
PAddr: stripe.Addr,
@@ -223,8 +241,9 @@ func (lv *LogicalVolume[PhysicalVolume]) fsck() error {
Flags: chunk.Flags,
})
}
- return nil
- }); err != nil {
+ return true
+ })
+ if err != nil {
return err
}
@@ -244,7 +263,7 @@ func (lv *LogicalVolume[PhysicalVolume]) fsck() error {
func (lv *LogicalVolume[PhysicalVolume]) Mappings() []Mapping {
var ret []Mapping
- _ = lv.logical2physical.Walk(func(node *containers.RBNode[chunkMapping]) error {
+ lv.logical2physical.Range(func(node *containers.RBNode[chunkMapping]) bool {
chunk := node.Value
for _, slice := range chunk.PAddrs {
ret = append(ret, Mapping{
@@ -254,14 +273,14 @@ func (lv *LogicalVolume[PhysicalVolume]) Mappings() []Mapping {
Flags: chunk.Flags,
})
}
- return nil
+ return true
})
return ret
}
func (lv *LogicalVolume[PhysicalVolume]) Resolve(laddr LogicalAddr) (paddrs containers.Set[QualifiedPhysicalAddr], maxlen AddrDelta) {
node := lv.logical2physical.Search(func(chunk chunkMapping) int {
- return chunkMapping{LAddr: laddr, Size: 1}.cmpRange(chunk)
+ return chunkMapping{LAddr: laddr, Size: 1}.compareRange(chunk)
})
if node == nil {
return nil, 0
@@ -281,7 +300,7 @@ func (lv *LogicalVolume[PhysicalVolume]) Resolve(laddr LogicalAddr) (paddrs cont
func (lv *LogicalVolume[PhysicalVolume]) ResolveAny(laddr LogicalAddr, size AddrDelta) (LogicalAddr, QualifiedPhysicalAddr) {
node := lv.logical2physical.Search(func(chunk chunkMapping) int {
- return chunkMapping{LAddr: laddr, Size: size}.cmpRange(chunk)
+ return chunkMapping{LAddr: laddr, Size: size}.compareRange(chunk)
})
if node == nil {
return -1, QualifiedPhysicalAddr{0, -1}
@@ -291,7 +310,7 @@ func (lv *LogicalVolume[PhysicalVolume]) ResolveAny(laddr LogicalAddr, size Addr
func (lv *LogicalVolume[PhysicalVolume]) UnResolve(paddr QualifiedPhysicalAddr) LogicalAddr {
node := lv.physical2logical[paddr.Dev].Search(func(ext devextMapping) int {
- return devextMapping{PAddr: paddr.Addr, Size: 1}.cmpRange(ext)
+ return devextMapping{PAddr: paddr.Addr, Size: 1}.compareRange(ext)
})
if node == nil {
return -1
diff --git a/lib/btrfs/io4_fs.go b/lib/btrfs/io4_fs.go
index 799c865..fce9c76 100644
--- a/lib/btrfs/io4_fs.go
+++ b/lib/btrfs/io4_fs.go
@@ -152,7 +152,7 @@ func (sv *Subvolume) LoadFullInode(inode btrfsprim.ObjID) (*FullInode, error) {
XAttrs: make(map[string]string),
}
items, err := sv.FS.TreeSearchAll(sv.TreeID, func(key btrfsprim.Key, _ uint32) int {
- return containers.NativeCmp(inode, key.ObjectID)
+ return containers.NativeCompare(inode, key.ObjectID)
})
if err != nil {
val.Errs = append(val.Errs, err)
diff --git a/lib/btrfsprogs/btrfsinspect/print_addrspace.go b/lib/btrfsprogs/btrfsinspect/print_addrspace.go
index a8b992e..e85e055 100644
--- a/lib/btrfsprogs/btrfsinspect/print_addrspace.go
+++ b/lib/btrfsprogs/btrfsinspect/print_addrspace.go
@@ -1,4 +1,4 @@
-// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later
@@ -46,7 +46,7 @@ func PrintLogicalSpace(out io.Writer, fs *btrfs.FS) {
func PrintPhysicalSpace(out io.Writer, fs *btrfs.FS) {
mappings := fs.LV.Mappings()
sort.Slice(mappings, func(i, j int) bool {
- return mappings[i].PAddr.Cmp(mappings[j].PAddr) < 0
+ return mappings[i].PAddr.Compare(mappings[j].PAddr) < 0
})
var prevDev btrfsvol.DeviceID = 0
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go
index 6b75d84..4724c12 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go
@@ -26,7 +26,7 @@ type fuzzyRecord struct {
N int
}
-func (a fuzzyRecord) Cmp(b fuzzyRecord) int {
+func (a fuzzyRecord) Compare(b fuzzyRecord) int {
switch {
case a.N < b.N:
return -1
@@ -148,12 +148,12 @@ func (l *lowestN[T]) Insert(v T) {
switch {
case len(l.Dat) < l.N:
l.Dat = append(l.Dat, v)
- case v.Cmp(l.Dat[0]) < 0:
+ case v.Compare(l.Dat[0]) < 0:
l.Dat[0] = v
default:
return
}
sort.Slice(l.Dat, func(i, j int) bool {
- return l.Dat[i].Cmp(l.Dat[j]) < 0
+ return l.Dat[i].Compare(l.Dat[j]) < 0
})
}
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go
index 69d14c7..7c02d05 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/logicalsums.go
@@ -53,11 +53,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice
// "AAAAAAA" shouldn't be present, and if we just discard "BBBBBBBB"
// because it conflicts with "CCCCCCC", then we would erroneously
// include "AAAAAAA".
- addrspace := &containers.RBTree[containers.NativeOrdered[btrfsvol.LogicalAddr], btrfsinspect.SysExtentCSum]{
- KeyFn: func(item btrfsinspect.SysExtentCSum) containers.NativeOrdered[btrfsvol.LogicalAddr] {
- return containers.NativeOrdered[btrfsvol.LogicalAddr]{Val: item.Sums.Addr}
- },
- }
+ addrspace := new(containers.RBTree[btrfsinspect.SysExtentCSum])
for _, newRecord := range records {
for {
conflict := addrspace.Search(func(oldRecord btrfsinspect.SysExtentCSum) int {
@@ -85,7 +81,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice
}
if oldRecord.Generation < newRecord.Generation {
// Newer generation wins.
- addrspace.Delete(containers.NativeOrdered[btrfsvol.LogicalAddr]{Val: oldRecord.Sums.Addr})
+ addrspace.Delete(conflict)
// loop around to check for more conflicts
continue
}
@@ -142,7 +138,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice
},
},
}
- addrspace.Delete(containers.NativeOrdered[btrfsvol.LogicalAddr]{Val: oldRecord.Sums.Addr})
+ addrspace.Delete(conflict)
newRecord = unionRecord
// loop around to check for more conflicts
}
@@ -152,7 +148,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice
var flattened SumRunWithGaps[btrfsvol.LogicalAddr]
var curAddr btrfsvol.LogicalAddr
var curSums strings.Builder
- _ = addrspace.Walk(func(node *containers.RBNode[btrfsinspect.SysExtentCSum]) error {
+ addrspace.Range(func(node *containers.RBNode[btrfsinspect.SysExtentCSum]) bool {
curEnd := curAddr + (btrfsvol.LogicalAddr(curSums.Len()/sumSize) * btrfssum.BlockSize)
if node.Value.Sums.Addr != curEnd {
if curSums.Len() > 0 {
@@ -166,7 +162,7 @@ func ExtractLogicalSums(ctx context.Context, scanResults btrfsinspect.ScanDevice
curSums.Reset()
}
curSums.WriteString(string(node.Value.Sums.Sums))
- return nil
+ return true
})
if curSums.Len() > 0 {
flattened.Runs = append(flattened.Runs, btrfssum.SumRun[btrfsvol.LogicalAddr]{
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/physicalsums.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/physicalsums.go
index 0806a63..da22fbf 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/physicalsums.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/physicalsums.go
@@ -1,4 +1,4 @@
-// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later
@@ -34,7 +34,7 @@ func ListUnmappedPhysicalRegions(fs *btrfs.FS) map[btrfsvol.DeviceID][]PhysicalR
pos := make(map[btrfsvol.DeviceID]btrfsvol.PhysicalAddr)
mappings := fs.LV.Mappings()
sort.Slice(mappings, func(i, j int) bool {
- return mappings[i].PAddr.Cmp(mappings[j].PAddr) < 0
+ return mappings[i].PAddr.Compare(mappings[j].PAddr) < 0
})
for _, mapping := range mappings {
if pos[mapping.PAddr.Dev] < mapping.PAddr.Addr {
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go b/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go
index ebca2bd..bd29278 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildnodes/rebuild.go
@@ -36,11 +36,11 @@ type keyAndTree struct {
TreeID btrfsprim.ObjID
}
-func (a keyAndTree) Cmp(b keyAndTree) int {
- if d := containers.NativeCmp(a.TreeID, b.TreeID); d != 0 {
+func (a keyAndTree) Compare(b keyAndTree) int {
+ if d := containers.NativeCompare(a.TreeID, b.TreeID); d != 0 {
return d
}
- return a.Key.Cmp(b.Key)
+ return a.Key.Compare(b.Key)
}
func (o keyAndTree) String() string {
@@ -155,7 +155,7 @@ func (o *rebuilder) Rebuild(_ctx context.Context) error {
itemQueue := maps.Keys(o.itemQueue)
o.itemQueue = make(containers.Set[keyAndTree])
sort.Slice(itemQueue, func(i, j int) bool {
- return itemQueue[i].Cmp(itemQueue[j]) < 0
+ return itemQueue[i].Compare(itemQueue[j]) < 0
})
var progress itemStats
progress.D = len(itemQueue)
@@ -596,7 +596,7 @@ func (o *rebuilder) _want(ctx context.Context, treeID btrfsprim.ObjID, wantKey s
}
if key, _, ok := o.rebuilt.Tree(ctx, treeID).Items(ctx).Search(func(key btrfsprim.Key, _ keyio.ItemPtr) int {
key.Offset = 0
- return tgt.Cmp(key)
+ return tgt.Compare(key)
}); ok {
return key, true
}
@@ -608,7 +608,7 @@ func (o *rebuilder) _want(ctx context.Context, treeID btrfsprim.ObjID, wantKey s
}
wants := make(containers.Set[btrfsvol.LogicalAddr])
o.rebuilt.Tree(ctx, treeID).PotentialItems(ctx).Subrange(
- func(k btrfsprim.Key, _ keyio.ItemPtr) int { k.Offset = 0; return tgt.Cmp(k) },
+ func(k btrfsprim.Key, _ keyio.ItemPtr) int { k.Offset = 0; return tgt.Compare(k) },
func(_ btrfsprim.Key, v keyio.ItemPtr) bool {
wants.InsertFrom(o.rebuilt.Tree(ctx, treeID).LeafToRoots(ctx, v.Node))
return true
@@ -649,7 +649,7 @@ func (o *rebuilder) _wantOff(ctx context.Context, treeID btrfsprim.ObjID, wantKe
}
wants := make(containers.Set[btrfsvol.LogicalAddr])
o.rebuilt.Tree(ctx, treeID).PotentialItems(ctx).Subrange(
- func(k btrfsprim.Key, _ keyio.ItemPtr) int { return tgt.Cmp(k) },
+ func(k btrfsprim.Key, _ keyio.ItemPtr) int { return tgt.Compare(k) },
func(_ btrfsprim.Key, v keyio.ItemPtr) bool {
wants.InsertFrom(o.rebuilt.Tree(ctx, treeID).LeafToRoots(ctx, v.Node))
return true
@@ -674,7 +674,7 @@ func (o *rebuilder) _wantFunc(ctx context.Context, treeID btrfsprim.ObjID, wantK
o.rebuilt.Tree(ctx, treeID).Items(ctx).Subrange(
func(key btrfsprim.Key, _ keyio.ItemPtr) int {
key.Offset = 0
- return tgt.Cmp(key)
+ return tgt.Compare(key)
},
func(_ btrfsprim.Key, itemPtr keyio.ItemPtr) bool {
if fn(itemPtr) {
@@ -693,7 +693,7 @@ func (o *rebuilder) _wantFunc(ctx context.Context, treeID btrfsprim.ObjID, wantK
}
wants := make(containers.Set[btrfsvol.LogicalAddr])
o.rebuilt.Tree(ctx, treeID).PotentialItems(ctx).Subrange(
- func(k btrfsprim.Key, _ keyio.ItemPtr) int { k.Offset = 0; return tgt.Cmp(k) },
+ func(k btrfsprim.Key, _ keyio.ItemPtr) int { k.Offset = 0; return tgt.Compare(k) },
func(k btrfsprim.Key, v keyio.ItemPtr) bool {
if fn(v) {
wants.InsertFrom(o.rebuilt.Tree(ctx, treeID).LeafToRoots(ctx, v.Node))
@@ -735,9 +735,9 @@ func (o *rebuilder) _walkRange(
items.Subrange(
func(runKey btrfsprim.Key, _ keyio.ItemPtr) int {
switch {
- case min.Cmp(runKey) < 0:
+ case min.Compare(runKey) < 0:
return 1
- case max.Cmp(runKey) > 0:
+ case max.Compare(runKey) > 0:
return -1
default:
return 0
@@ -770,6 +770,16 @@ func (o *rebuilder) _walkRange(
})
}
+type gap struct {
+ // range is [Beg,End)
+ Beg, End uint64
+}
+
+// Compare implements containers.Ordered.
+func (a gap) Compare(b gap) int {
+ return containers.NativeCompare(a.Beg, b.Beg)
+}
+
func (o *rebuilder) _wantRange(
ctx context.Context,
treeID btrfsprim.ObjID, objID btrfsprim.ObjID, typ btrfsprim.ItemType,
@@ -787,15 +797,7 @@ func (o *rebuilder) _wantRange(
//
// Start with a gap of the whole range, then subtract each run
// from it.
- type gap struct {
- // range is [Beg,End)
- Beg, End uint64
- }
- gaps := &containers.RBTree[containers.NativeOrdered[uint64], gap]{
- KeyFn: func(gap gap) containers.NativeOrdered[uint64] {
- return containers.NativeOrdered[uint64]{Val: gap.Beg}
- },
- }
+ gaps := new(containers.RBTree[gap])
gaps.Insert(gap{
Beg: beg,
End: end,
@@ -805,23 +807,29 @@ func (o *rebuilder) _wantRange(
o.rebuilt.Tree(ctx, treeID).Items(ctx),
treeID, objID, typ, beg, end,
func(runKey btrfsprim.Key, _ keyio.ItemPtr, runBeg, runEnd uint64) {
- overlappingGaps := gaps.SearchRange(func(gap gap) int {
- switch {
- case gap.End <= runBeg:
- return 1
- case runEnd <= gap.Beg:
- return -1
- default:
- return 0
- }
- })
+ var overlappingGaps []*containers.RBNode[gap]
+ gaps.Subrange(
+ func(gap gap) int {
+ switch {
+ case gap.End <= runBeg:
+ return 1
+ case runEnd <= gap.Beg:
+ return -1
+ default:
+ return 0
+ }
+ },
+ func(node *containers.RBNode[gap]) bool {
+ overlappingGaps = append(overlappingGaps, node)
+ return true
+ })
if len(overlappingGaps) == 0 {
return
}
- gapsBeg := overlappingGaps[0].Beg
- gapsEnd := overlappingGaps[len(overlappingGaps)-1].End
+ gapsBeg := overlappingGaps[0].Value.Beg
+ gapsEnd := overlappingGaps[len(overlappingGaps)-1].Value.End
for _, gap := range overlappingGaps {
- gaps.Delete(containers.NativeOrdered[uint64]{Val: gap.Beg})
+ gaps.Delete(gap)
}
if gapsBeg < runBeg {
gaps.Insert(gap{
@@ -842,7 +850,7 @@ func (o *rebuilder) _wantRange(
return
}
potentialItems := o.rebuilt.Tree(ctx, treeID).PotentialItems(ctx)
- _ = gaps.Walk(func(rbNode *containers.RBNode[gap]) error {
+ gaps.Range(func(rbNode *containers.RBNode[gap]) bool {
gap := rbNode.Value
last := gap.Beg
o._walkRange(
@@ -874,7 +882,7 @@ func (o *rebuilder) _wantRange(
o.wantAugment(wantCtx, treeID, wantKey, nil)
}
}
- return nil
+ return true
})
}
diff --git a/lib/btrfsprogs/btrfsinspect/scandevices.go b/lib/btrfsprogs/btrfsinspect/scandevices.go
index 7668a83..9b8360c 100644
--- a/lib/btrfsprogs/btrfsinspect/scandevices.go
+++ b/lib/btrfsprogs/btrfsinspect/scandevices.go
@@ -22,6 +22,7 @@ import (
"git.lukeshu.com/btrfs-progs-ng/lib/btrfs/btrfssum"
"git.lukeshu.com/btrfs-progs-ng/lib/btrfs/btrfstree"
"git.lukeshu.com/btrfs-progs-ng/lib/btrfs/btrfsvol"
+ "git.lukeshu.com/btrfs-progs-ng/lib/containers"
"git.lukeshu.com/btrfs-progs-ng/lib/textui"
)
@@ -79,6 +80,11 @@ type SysExtentCSum struct {
Sums btrfsitem.ExtentCSum
}
+// Compare implements containers.Ordered.
+func (a SysExtentCSum) Compare(b SysExtentCSum) int {
+ return containers.NativeCompare(a.Sums.Addr, b.Sums.Addr)
+}
+
type scanStats struct {
textui.Portion[btrfsvol.PhysicalAddr]
diff --git a/lib/btrfsprogs/btrfsutil/broken_btree.go b/lib/btrfsprogs/btrfsutil/broken_btree.go
index f0b298e..7ea31ce 100644
--- a/lib/btrfsprogs/btrfsutil/broken_btree.go
+++ b/lib/btrfsprogs/btrfsutil/broken_btree.go
@@ -23,7 +23,7 @@ import (
type treeIndex struct {
TreeRootErr error
- Items *containers.RBTree[btrfsprim.Key, treeIndexValue]
+ Items *containers.RBTree[treeIndexValue]
Errors *containers.IntervalTree[btrfsprim.Key, treeIndexError]
}
@@ -38,13 +38,14 @@ type treeIndexValue struct {
ItemSize uint32
}
+// Compare implements containers.Ordered.
+func (a treeIndexValue) Compare(b treeIndexValue) int {
+ return a.Key.Compare(b.Key)
+}
+
func newTreeIndex(arena *SkinnyPathArena) treeIndex {
return treeIndex{
- Items: &containers.RBTree[btrfsprim.Key, treeIndexValue]{
- KeyFn: func(iv treeIndexValue) btrfsprim.Key {
- return iv.Key
- },
- },
+ Items: new(containers.RBTree[treeIndexValue]),
Errors: &containers.IntervalTree[btrfsprim.Key, treeIndexError]{
MinFn: func(err treeIndexError) btrfsprim.Key {
return arena.Inflate(err.Path).Node(-1).ToKey
@@ -173,7 +174,7 @@ func (bt *brokenTrees) rawTreeWalk(root btrfstree.TreeRoot, cacheEntry treeIndex
},
btrfstree.TreeWalkHandler{
Item: func(path btrfstree.TreePath, item btrfstree.Item) error {
- if cacheEntry.Items.Lookup(item.Key) != nil {
+ if cacheEntry.Items.Search(func(v treeIndexValue) int { return item.Key.Compare(v.Key) }) != nil {
// This is a panic because I'm not really sure what the best way to
// handle this is, and so if this happens I want the program to crash
// and force me to figure out how to handle it.
@@ -194,7 +195,7 @@ func (bt *brokenTrees) rawTreeWalk(root btrfstree.TreeRoot, cacheEntry treeIndex
}
func (bt *brokenTrees) TreeLookup(treeID btrfsprim.ObjID, key btrfsprim.Key) (btrfstree.Item, error) {
- item, err := bt.TreeSearch(treeID, btrfstree.KeySearch(key.Cmp))
+ item, err := bt.TreeSearch(treeID, btrfstree.KeySearch(key.Compare))
if err != nil {
err = fmt.Errorf("item with key=%v: %w", key, err)
}
@@ -203,15 +204,15 @@ func (bt *brokenTrees) TreeLookup(treeID btrfsprim.ObjID, key btrfsprim.Key) (bt
func (bt *brokenTrees) addErrs(index treeIndex, fn func(btrfsprim.Key, uint32) int, err error) error {
var errs derror.MultiError
- if _errs := index.Errors.SearchAll(func(k btrfsprim.Key) int { return fn(k, 0) }); len(_errs) > 0 {
- errs = make(derror.MultiError, len(_errs))
- for i := range _errs {
- errs[i] = &btrfstree.TreeError{
- Path: bt.arena.Inflate(_errs[i].Path),
- Err: _errs[i].Err,
- }
- }
- }
+ index.Errors.Subrange(
+ func(k btrfsprim.Key) int { return fn(k, 0) },
+ func(v treeIndexError) bool {
+ errs = append(errs, &btrfstree.TreeError{
+ Path: bt.arena.Inflate(v.Path),
+ Err: v.Err,
+ })
+ return true
+ })
if len(errs) == 0 {
return err
}
@@ -253,9 +254,13 @@ func (bt *brokenTrees) TreeSearchAll(treeID btrfsprim.ObjID, fn func(btrfsprim.K
return nil, index.TreeRootErr
}
- indexItems := index.Items.SearchRange(func(indexItem treeIndexValue) int {
- return fn(indexItem.Key, indexItem.ItemSize)
- })
+ var indexItems []treeIndexValue
+ index.Items.Subrange(
+ func(indexItem treeIndexValue) int { return fn(indexItem.Key, indexItem.ItemSize) },
+ func(node *containers.RBNode[treeIndexValue]) bool {
+ indexItems = append(indexItems, node.Value)
+ return true
+ })
if len(indexItems) == 0 {
return nil, bt.addErrs(index, fn, iofs.ErrNotExist)
}
@@ -290,12 +295,12 @@ func (bt *brokenTrees) TreeWalk(ctx context.Context, treeID btrfsprim.ObjID, err
return
}
var node *diskio.Ref[btrfsvol.LogicalAddr, btrfstree.Node]
- _ = index.Items.Walk(func(indexItem *containers.RBNode[treeIndexValue]) error {
+ index.Items.Range(func(indexItem *containers.RBNode[treeIndexValue]) bool {
if ctx.Err() != nil {
- return ctx.Err()
+ return false
}
if bt.ctx.Err() != nil {
- return bt.ctx.Err()
+ return false
}
if cbs.Item != nil {
itemPath := bt.arena.Inflate(indexItem.Value.Path)
@@ -304,7 +309,7 @@ func (bt *brokenTrees) TreeWalk(ctx context.Context, treeID btrfsprim.ObjID, err
node, err = bt.inner.ReadNode(itemPath.Parent())
if err != nil {
errHandle(&btrfstree.TreeError{Path: itemPath, Err: err})
- return nil //nolint:nilerr // We already called errHandle().
+ return true
}
}
item := node.Data.BodyLeaf[itemPath.Node(-1).FromItemIdx]
@@ -312,7 +317,7 @@ func (bt *brokenTrees) TreeWalk(ctx context.Context, treeID btrfsprim.ObjID, err
errHandle(&btrfstree.TreeError{Path: itemPath, Err: err})
}
}
- return nil
+ return true
})
}
diff --git a/lib/containers/intervaltree.go b/lib/containers/intervaltree.go
index 16bc916..7b96526 100644
--- a/lib/containers/intervaltree.go
+++ b/lib/containers/intervaltree.go
@@ -4,81 +4,80 @@
package containers
-type intervalKey[K Ordered[K]] struct {
+type interval[K Ordered[K]] struct {
Min, Max K
}
-func (ival intervalKey[K]) ContainsFn(fn func(K) int) bool {
- return fn(ival.Min) >= 0 && fn(ival.Max) <= 0
-}
-
-func (a intervalKey[K]) Cmp(b intervalKey[K]) int {
- if d := a.Min.Cmp(b.Min); d != 0 {
+// Compare implements Ordered.
+func (a interval[K]) Compare(b interval[K]) int {
+ if d := a.Min.Compare(b.Min); d != 0 {
return d
}
- return a.Max.Cmp(b.Max)
+ return a.Max.Compare(b.Max)
+}
+
+// ContainsFn returns whether this interval contains the range matched
+// by the given function.
+func (ival interval[K]) ContainsFn(fn func(K) int) bool {
+ return fn(ival.Min) >= 0 && fn(ival.Max) <= 0
}
type intervalValue[K Ordered[K], V any] struct {
- Val V
- SpanOfChildren intervalKey[K]
+ Val V
+ ValSpan interval[K]
+ ChildSpan interval[K]
+}
+
+// Compare implements Ordered.
+func (a intervalValue[K, V]) Compare(b intervalValue[K, V]) int {
+ return a.ValSpan.Compare(b.ValSpan)
}
type IntervalTree[K Ordered[K], V any] struct {
MinFn func(V) K
MaxFn func(V) K
- inner RBTree[intervalKey[K], intervalValue[K, V]]
-}
-
-func (t *IntervalTree[K, V]) keyFn(v intervalValue[K, V]) intervalKey[K] {
- return intervalKey[K]{
- Min: t.MinFn(v.Val),
- Max: t.MaxFn(v.Val),
- }
+ inner RBTree[intervalValue[K, V]]
}
func (t *IntervalTree[K, V]) attrFn(node *RBNode[intervalValue[K, V]]) {
- max := t.MaxFn(node.Value.Val)
- if node.Left != nil && node.Left.Value.SpanOfChildren.Max.Cmp(max) > 0 {
- max = node.Left.Value.SpanOfChildren.Max
+ max := node.Value.ValSpan.Max
+ if node.Left != nil && node.Left.Value.ChildSpan.Max.Compare(max) > 0 {
+ max = node.Left.Value.ChildSpan.Max
}
- if node.Right != nil && node.Right.Value.SpanOfChildren.Max.Cmp(max) > 0 {
- max = node.Right.Value.SpanOfChildren.Max
+ if node.Right != nil && node.Right.Value.ChildSpan.Max.Compare(max) > 0 {
+ max = node.Right.Value.ChildSpan.Max
}
- node.Value.SpanOfChildren.Max = max
+ node.Value.ChildSpan.Max = max
- min := t.MinFn(node.Value.Val)
- if node.Left != nil && node.Left.Value.SpanOfChildren.Min.Cmp(min) < 0 {
- min = node.Left.Value.SpanOfChildren.Min
+ min := node.Value.ValSpan.Min
+ if node.Left != nil && node.Left.Value.ChildSpan.Min.Compare(min) < 0 {
+ min = node.Left.Value.ChildSpan.Min
}
- if node.Right != nil && node.Right.Value.SpanOfChildren.Min.Cmp(min) < 0 {
- min = node.Right.Value.SpanOfChildren.Min
+ if node.Right != nil && node.Right.Value.ChildSpan.Min.Compare(min) < 0 {
+ min = node.Right.Value.ChildSpan.Min
}
- node.Value.SpanOfChildren.Min = min
+ node.Value.ChildSpan.Min = min
}
func (t *IntervalTree[K, V]) init() {
- if t.inner.KeyFn == nil {
- t.inner.KeyFn = t.keyFn
+ if t.inner.AttrFn == nil {
t.inner.AttrFn = t.attrFn
}
}
-func (t *IntervalTree[K, V]) Delete(min, max K) {
- t.init()
- t.inner.Delete(intervalKey[K]{
- Min: min,
- Max: max,
- })
-}
-
func (t *IntervalTree[K, V]) Equal(u *IntervalTree[K, V]) bool {
return t.inner.Equal(&u.inner)
}
func (t *IntervalTree[K, V]) Insert(val V) {
t.init()
- t.inner.Insert(intervalValue[K, V]{Val: val})
+ t.inner.Insert(intervalValue[K, V]{
+ Val: val,
+ ValSpan: interval[K]{
+ Min: t.MinFn(val),
+ Max: t.MaxFn(val),
+ },
+ })
}
func (t *IntervalTree[K, V]) Min() (K, bool) {
@@ -86,7 +85,7 @@ func (t *IntervalTree[K, V]) Min() (K, bool) {
var zero K
return zero, false
}
- return t.inner.root.Value.SpanOfChildren.Min, true
+ return t.inner.root.Value.ChildSpan.Min, true
}
func (t *IntervalTree[K, V]) Max() (K, bool) {
@@ -94,22 +93,18 @@ func (t *IntervalTree[K, V]) Max() (K, bool) {
var zero K
return zero, false
}
- return t.inner.root.Value.SpanOfChildren.Max, true
-}
-
-func (t *IntervalTree[K, V]) Lookup(k K) (V, bool) {
- return t.Search(k.Cmp)
+ return t.inner.root.Value.ChildSpan.Max, true
}
func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) {
node := t.inner.root
for node != nil {
switch {
- case t.keyFn(node.Value).ContainsFn(fn):
+ case node.Value.ValSpan.ContainsFn(fn):
return node.Value.Val, true
- case node.Left != nil && node.Left.Value.SpanOfChildren.ContainsFn(fn):
+ case node.Left != nil && node.Left.Value.ChildSpan.ContainsFn(fn):
node = node.Left
- case node.Right != nil && node.Right.Value.SpanOfChildren.ContainsFn(fn):
+ case node.Right != nil && node.Right.Value.ChildSpan.ContainsFn(fn):
node = node.Right
default:
node = nil
@@ -119,24 +114,33 @@ func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) {
return zero, false
}
-func (t *IntervalTree[K, V]) searchAll(fn func(K) int, node *RBNode[intervalValue[K, V]], ret *[]V) {
+func (t *IntervalTree[K, V]) Range(fn func(V) bool) {
+ t.inner.Range(func(node *RBNode[intervalValue[K, V]]) bool {
+ return fn(node.Value.Val)
+ })
+}
+
+func (t *IntervalTree[K, V]) Subrange(rangeFn func(K) int, handleFn func(V) bool) {
+ t.subrange(t.inner.root, rangeFn, handleFn)
+}
+
+func (t *IntervalTree[K, V]) subrange(node *RBNode[intervalValue[K, V]], rangeFn func(K) int, handleFn func(V) bool) bool {
if node == nil {
- return
+ return true
}
- if !node.Value.SpanOfChildren.ContainsFn(fn) {
- return
+ if !node.Value.ChildSpan.ContainsFn(rangeFn) {
+ return true
}
- t.searchAll(fn, node.Left, ret)
- if t.keyFn(node.Value).ContainsFn(fn) {
- *ret = append(*ret, node.Value.Val)
+ if !t.subrange(node.Left, rangeFn, handleFn) {
+ return false
}
- t.searchAll(fn, node.Right, ret)
-}
-
-func (t *IntervalTree[K, V]) SearchAll(fn func(K) int) []V {
- var ret []V
- t.searchAll(fn, t.inner.root, &ret)
- return ret
+ if node.Value.ValSpan.ContainsFn(rangeFn) {
+ if !handleFn(node.Value.Val) {
+ return false
+ }
+ }
+ if !t.subrange(node.Right, rangeFn, handleFn) {
+ return false
+ }
+ return true
}
-
-// TODO: func (t *IntervalTree[K, V]) Walk(fn func(*RBNode[V]) error) error
diff --git a/lib/containers/intervaltree_test.go b/lib/containers/intervaltree_test.go
index 45743a2..30885bd 100644
--- a/lib/containers/intervaltree_test.go
+++ b/lib/containers/intervaltree_test.go
@@ -18,8 +18,8 @@ func (t *IntervalTree[K, V]) ASCIIArt() string {
func (v intervalValue[K, V]) String() string {
return fmt.Sprintf("%v) ([%v,%v]",
v.Val,
- v.SpanOfChildren.Min,
- v.SpanOfChildren.Max)
+ v.ChildSpan.Min,
+ v.ChildSpan.Max)
}
func (v NativeOrdered[T]) String() string {
@@ -60,15 +60,21 @@ func TestIntervalTree(t *testing.T) {
t.Log("\n" + tree.ASCIIArt())
// find intervals that touch [9,20]
- intervals := tree.SearchAll(func(k NativeOrdered[int]) int {
- if k.Val < 9 {
- return 1
- }
- if k.Val > 20 {
- return -1
- }
- return 0
- })
+ var intervals []SimpleInterval
+ tree.Subrange(
+ func(k NativeOrdered[int]) int {
+ if k.Val < 9 {
+ return 1
+ }
+ if k.Val > 20 {
+ return -1
+ }
+ return 0
+ },
+ func(v SimpleInterval) bool {
+ intervals = append(intervals, v)
+ return true
+ })
assert.Equal(t,
[]SimpleInterval{
{6, 10},
diff --git a/lib/containers/ordered.go b/lib/containers/ordered.go
index d918149..1ebc17e 100644
--- a/lib/containers/ordered.go
+++ b/lib/containers/ordered.go
@@ -1,4 +1,4 @@
-// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later
@@ -9,16 +9,38 @@ import (
)
type _Ordered[T any] interface {
- Cmp(T) int
+ Compare(T) int
}
+// An Ordered[T] is a type that has a
+//
+// func (a T) Compare(b T) int
+//
+// method that returns a value <0 if a is "less than" b, >0 if a is
+// "greater than" b, or 0 if a is "equal to" b.
+//
+// You can conceptualize it as subtraction:
+//
+// func (a T) Compare(b T) int {
+// return a - b
+// }
+//
+// Be careful to avoid integer overflow if actually implementing it as
+// subtraction.
type Ordered[T _Ordered[T]] _Ordered[T]
+// NativeOrdered takes a type that is natively-ordered (integer types,
+// float types, and string types), and wraps them such that they
+// implement the Ordered interface.
type NativeOrdered[T constraints.Ordered] struct {
Val T
}
-func NativeCmp[T constraints.Ordered](a, b T) int {
+// NativeCompare[T] implements the Ordered[T] Compare operation for
+// natively-ordered (integer types, float types, and string types).
+// While this operation can be conceptualized as subtraction,
+// NativeCompare[T] is careful to avoid integer overflow.
+func NativeCompare[T constraints.Ordered](a, b T) int {
switch {
case a < b:
return -1
@@ -29,8 +51,9 @@ func NativeCmp[T constraints.Ordered](a, b T) int {
}
}
-func (a NativeOrdered[T]) Cmp(b NativeOrdered[T]) int {
- return NativeCmp(a.Val, b.Val)
+// Compare implements Ordered[T].
+func (a NativeOrdered[T]) Compare(b NativeOrdered[T]) int {
+ return NativeCompare(a.Val, b.Val)
}
var _ Ordered[NativeOrdered[int]] = NativeOrdered[int]{}
diff --git a/lib/containers/ordered_test.go b/lib/containers/ordered_test.go
new file mode 100644
index 0000000..4f0b2e6
--- /dev/null
+++ b/lib/containers/ordered_test.go
@@ -0,0 +1,13 @@
+// Copyright (C) 2023 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package containers_test
+
+import (
+ "net/netip"
+
+ "git.lukeshu.com/btrfs-progs-ng/lib/containers"
+)
+
+var _ containers.Ordered[netip.Addr] = netip.Addr{}
diff --git a/lib/containers/rbtree.go b/lib/containers/rbtree.go
index 0430123..4125847 100644
--- a/lib/containers/rbtree.go
+++ b/lib/containers/rbtree.go
@@ -7,8 +7,6 @@ package containers
import (
"fmt"
"reflect"
-
- "git.lukeshu.com/btrfs-progs-ng/lib/slices"
)
type Color bool
@@ -18,50 +16,49 @@ const (
Red = Color(true)
)
-type RBNode[V any] struct {
- Parent, Left, Right *RBNode[V]
+type RBNode[T Ordered[T]] struct {
+ Parent, Left, Right *RBNode[T]
Color Color
- Value V
+ Value T
}
-func (node *RBNode[V]) getColor() Color {
+func (node *RBNode[T]) getColor() Color {
if node == nil {
return Black
}
return node.Color
}
-type RBTree[K Ordered[K], V any] struct {
- KeyFn func(V) K
- AttrFn func(*RBNode[V])
- root *RBNode[V]
+type RBTree[T Ordered[T]] struct {
+ AttrFn func(*RBNode[T])
+ root *RBNode[T]
len int
}
-func (t *RBTree[K, V]) Len() int {
+func (t *RBTree[T]) Len() int {
return t.len
}
-func (t *RBTree[K, V]) Walk(fn func(*RBNode[V]) error) error {
- return t.root.walk(fn)
+func (t *RBTree[T]) Range(fn func(*RBNode[T]) bool) {
+ t.root._range(fn)
}
-func (node *RBNode[V]) walk(fn func(*RBNode[V]) error) error {
+func (node *RBNode[T]) _range(fn func(*RBNode[T]) bool) bool {
if node == nil {
- return nil
+ return true
}
- if err := node.Left.walk(fn); err != nil {
- return err
+ if !node.Left._range(fn) {
+ return false
}
- if err := fn(node); err != nil {
- return err
+ if !fn(node) {
+ return false
}
- if err := node.Right.walk(fn); err != nil {
- return err
+ if !node.Right._range(fn) {
+ return false
}
- return nil
+ return true
}
// Search the tree for a value that satisfied the given callbackk
@@ -81,16 +78,13 @@ func (node *RBNode[V]) walk(fn func(*RBNode[V]) error) error {
// +---+ +---+
//
// Returns nil if no such value is found.
-//
-// Search is good for advanced lookup, like when a range of values is
-// acceptable. For simple exact-value lookup, use Lookup.
-func (t *RBTree[K, V]) Search(fn func(V) int) *RBNode[V] {
+func (t *RBTree[T]) Search(fn func(T) int) *RBNode[T] {
ret, _ := t.root.search(fn)
return ret
}
-func (node *RBNode[V]) search(fn func(V) int) (exact, nearest *RBNode[V]) {
- var prev *RBNode[V]
+func (node *RBNode[T]) search(fn func(T) int) (exact, nearest *RBNode[T]) {
+ var prev *RBNode[T]
for {
if node == nil {
return nil, prev
@@ -108,26 +102,13 @@ func (node *RBNode[V]) search(fn func(V) int) (exact, nearest *RBNode[V]) {
}
}
-func (t *RBTree[K, V]) exactKey(key K) func(V) int {
- return func(val V) int {
- valKey := t.KeyFn(val)
- return key.Cmp(valKey)
- }
-}
-
-// Lookup looks up the value for an exact key. If no such value
-// exists, nil is returned.
-func (t *RBTree[K, V]) Lookup(key K) *RBNode[V] {
- return t.Search(t.exactKey(key))
-}
-
// Min returns the minimum value stored in the tree, or nil if the
// tree is empty.
-func (t *RBTree[K, V]) Min() *RBNode[V] {
+func (t *RBTree[T]) Min() *RBNode[T] {
return t.root.min()
}
-func (node *RBNode[V]) min() *RBNode[V] {
+func (node *RBNode[T]) min() *RBNode[T] {
if node == nil {
return nil
}
@@ -141,11 +122,11 @@ func (node *RBNode[V]) min() *RBNode[V] {
// Max returns the maximum value stored in the tree, or nil if the
// tree is empty.
-func (t *RBTree[K, V]) Max() *RBNode[V] {
+func (t *RBTree[T]) Max() *RBNode[T] {
return t.root.max()
}
-func (node *RBNode[V]) max() *RBNode[V] {
+func (node *RBNode[T]) max() *RBNode[T] {
if node == nil {
return nil
}
@@ -157,11 +138,7 @@ func (node *RBNode[V]) max() *RBNode[V] {
}
}
-func (t *RBTree[K, V]) Next(cur *RBNode[V]) *RBNode[V] {
- return cur.next()
-}
-
-func (cur *RBNode[V]) next() *RBNode[V] {
+func (cur *RBNode[T]) Next() *RBNode[T] {
if cur.Right != nil {
return cur.Right.min()
}
@@ -172,11 +149,7 @@ func (cur *RBNode[V]) next() *RBNode[V] {
return parent
}
-func (t *RBTree[K, V]) Prev(cur *RBNode[V]) *RBNode[V] {
- return cur.prev()
-}
-
-func (cur *RBNode[V]) prev() *RBNode[V] {
+func (cur *RBNode[T]) Prev() *RBNode[T] {
if cur.Left != nil {
return cur.Left.max()
}
@@ -187,48 +160,56 @@ func (cur *RBNode[V]) prev() *RBNode[V] {
return parent
}
-// SearchRange is like Search, but returns all nodes that match the
-// function; assuming that they are contiguous.
-func (t *RBTree[K, V]) SearchRange(fn func(V) int) []V {
- middle := t.Search(fn)
- if middle == nil {
- return nil
- }
- ret := []V{middle.Value}
- for node := t.Prev(middle); node != nil && fn(node.Value) == 0; node = t.Prev(node) {
- ret = append(ret, node.Value)
+// Subrange is like Search, but for when there may be more than one
+// result.
+func (t *RBTree[T]) Subrange(rangeFn func(T) int, handleFn func(*RBNode[T]) bool) {
+ // Find the left-most acceptable node.
+ _, node := t.root.search(func(v T) int {
+ if rangeFn(v) <= 0 {
+ return -1
+ } else {
+ return 1
+ }
+ })
+ for node != nil && rangeFn(node.Value) > 0 {
+ node = node.Next()
}
- slices.Reverse(ret)
- for node := t.Next(middle); node != nil && fn(node.Value) == 0; node = t.Next(node) {
- ret = append(ret, node.Value)
+ // Now walk forward until we hit the end.
+ for node != nil && rangeFn(node.Value) == 0 {
+ if keepGoing := handleFn(node); !keepGoing {
+ return
+ }
+ node = node.Next()
}
- return ret
}
-func (t *RBTree[K, V]) Equal(u *RBTree[K, V]) bool {
+func (t *RBTree[T]) Equal(u *RBTree[T]) bool {
if (t == nil) != (u == nil) {
return false
}
if t == nil {
return true
}
+ if t.len != u.len {
+ return false
+ }
- var tSlice []V
- _ = t.Walk(func(node *RBNode[V]) error {
+ tSlice := make([]T, 0, t.len)
+ t.Range(func(node *RBNode[T]) bool {
tSlice = append(tSlice, node.Value)
- return nil
+ return true
})
- var uSlice []V
- _ = u.Walk(func(node *RBNode[V]) error {
+ uSlice := make([]T, 0, u.len)
+ u.Range(func(node *RBNode[T]) bool {
uSlice = append(uSlice, node.Value)
- return nil
+ return true
})
return reflect.DeepEqual(tSlice, uSlice)
}
-func (t *RBTree[K, V]) parentChild(node *RBNode[V]) **RBNode[V] {
+func (t *RBTree[T]) parentChild(node *RBNode[T]) **RBNode[T] {
switch {
case node.Parent == nil:
return &t.root
@@ -241,7 +222,7 @@ func (t *RBTree[K, V]) parentChild(node *RBNode[V]) **RBNode[V] {
}
}
-func (t *RBTree[K, V]) updateAttr(node *RBNode[V]) {
+func (t *RBTree[T]) updateAttr(node *RBNode[T]) {
if t.AttrFn == nil {
return
}
@@ -251,7 +232,7 @@ func (t *RBTree[K, V]) updateAttr(node *RBNode[V]) {
}
}
-func (t *RBTree[K, V]) leftRotate(x *RBNode[V]) {
+func (t *RBTree[T]) leftRotate(x *RBNode[T]) {
// p p
// | |
// +---+ +---+
@@ -286,7 +267,7 @@ func (t *RBTree[K, V]) leftRotate(x *RBNode[V]) {
t.updateAttr(x)
}
-func (t *RBTree[K, V]) rightRotate(y *RBNode[V]) {
+func (t *RBTree[T]) rightRotate(y *RBNode[T]) {
//nolint:dupword
//
// | |
@@ -322,18 +303,17 @@ func (t *RBTree[K, V]) rightRotate(y *RBNode[V]) {
t.updateAttr(y)
}
-func (t *RBTree[K, V]) Insert(val V) {
+func (t *RBTree[T]) Insert(val T) {
// Naive-insert
- key := t.KeyFn(val)
- exact, parent := t.root.search(t.exactKey(key))
+ exact, parent := t.root.search(val.Compare)
if exact != nil {
exact.Value = val
return
}
t.len++
- node := &RBNode[V]{
+ node := &RBNode[T]{
Color: Red,
Parent: parent,
Value: val,
@@ -341,7 +321,7 @@ func (t *RBTree[K, V]) Insert(val V) {
switch {
case parent == nil:
t.root = node
- case key.Cmp(t.KeyFn(parent.Value)) < 0:
+ case val.Compare(parent.Value) < 0:
parent.Left = node
default:
parent.Right = node
@@ -391,15 +371,14 @@ func (t *RBTree[K, V]) Insert(val V) {
t.root.Color = Black
}
-func (t *RBTree[K, V]) transplant(oldNode, newNode *RBNode[V]) {
+func (t *RBTree[T]) transplant(oldNode, newNode *RBNode[T]) {
*t.parentChild(oldNode) = newNode
if newNode != nil {
newNode.Parent = oldNode.Parent
}
}
-func (t *RBTree[K, V]) Delete(key K) {
- nodeToDelete := t.Lookup(key)
+func (t *RBTree[T]) Delete(nodeToDelete *RBNode[T]) {
if nodeToDelete == nil {
return
}
@@ -410,8 +389,8 @@ func (t *RBTree[K, V]) Delete(key K) {
// phase 1
- var nodeToRebalance *RBNode[V]
- var nodeToRebalanceParent *RBNode[V] // in case 'nodeToRebalance' is nil, which it can be
+ var nodeToRebalance *RBNode[T]
+ var nodeToRebalanceParent *RBNode[T] // in case 'nodeToRebalance' is nil, which it can be
needsRebalance := nodeToDelete.Color == Black
switch {
@@ -427,7 +406,7 @@ func (t *RBTree[K, V]) Delete(key K) {
// The node being deleted has a child on both sides,
// so we've go to reshuffle the parents a bit to make
// room for those children.
- next := nodeToDelete.next()
+ next := nodeToDelete.Next()
if next.Parent == nodeToDelete {
// p p
// | |
diff --git a/lib/containers/rbtree_test.go b/lib/containers/rbtree_test.go
index e42410e..d2fe931 100644
--- a/lib/containers/rbtree_test.go
+++ b/lib/containers/rbtree_test.go
@@ -7,21 +7,22 @@ package containers
import (
"fmt"
"io"
- "sort"
"strings"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/exp/constraints"
+
+ "git.lukeshu.com/btrfs-progs-ng/lib/slices"
)
-func (t *RBTree[K, V]) ASCIIArt() string {
+func (t *RBTree[T]) ASCIIArt() string {
var out strings.Builder
t.root.asciiArt(&out, "", "", "")
return out.String()
}
-func (node *RBNode[V]) String() string {
+func (node *RBNode[T]) String() string {
switch {
case node == nil:
return "nil"
@@ -32,7 +33,7 @@ func (node *RBNode[V]) String() string {
}
}
-func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) {
+func (node *RBNode[T]) asciiArt(w io.Writer, u, m, l string) {
if node == nil {
fmt.Fprintf(w, "%snil\n", m)
return
@@ -43,7 +44,7 @@ func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) {
node.Left.asciiArt(w, l+" | ", l+" `--", l+" ")
}
-func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], tree *RBTree[NativeOrdered[K], V]) {
+func checkRBTree[T constraints.Ordered](t *testing.T, expectedSet Set[T], tree *RBTree[NativeOrdered[T]]) {
// 1. Every node is either red or black
// 2. The root is black.
@@ -52,19 +53,19 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K],
// 3. Every nil is black.
// 4. If a node is red, then both its children are black.
- require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
+ tree.Range(func(node *RBNode[NativeOrdered[T]]) bool {
if node.getColor() == Red {
require.Equal(t, Black, node.Left.getColor())
require.Equal(t, Black, node.Right.getColor())
}
- return nil
- }))
+ return true
+ })
// 5. For each node, all simple paths from the node to
// descendent leaves contain the same number of black
// nodes.
- var walkCnt func(node *RBNode[V], cnt int, leafFn func(int))
- walkCnt = func(node *RBNode[V], cnt int, leafFn func(int)) {
+ var walkCnt func(node *RBNode[NativeOrdered[T]], cnt int, leafFn func(int))
+ walkCnt = func(node *RBNode[NativeOrdered[T]], cnt int, leafFn func(int)) {
if node.getColor() == Black {
cnt++
}
@@ -75,7 +76,7 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K],
walkCnt(node.Left, cnt, leafFn)
walkCnt(node.Right, cnt, leafFn)
}
- require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
+ tree.Range(func(node *RBNode[NativeOrdered[T]]) bool {
var cnts []int
walkCnt(node, 0, func(cnt int) {
cnts = append(cnts, cnt)
@@ -86,27 +87,24 @@ func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K],
break
}
}
- return nil
- }))
+ return true
+ })
// expected contents
- expectedOrder := make([]K, 0, len(expectedSet))
- for k := range expectedSet {
- expectedOrder = append(expectedOrder, k)
- node := tree.Lookup(NativeOrdered[K]{Val: k})
+ expectedOrder := make([]T, 0, len(expectedSet))
+ for v := range expectedSet {
+ expectedOrder = append(expectedOrder, v)
+ node := tree.Search(NativeOrdered[T]{Val: v}.Compare)
require.NotNil(t, tree, node)
- require.Equal(t, k, tree.KeyFn(node.Value).Val)
}
- sort.Slice(expectedOrder, func(i, j int) bool {
- return expectedOrder[i] < expectedOrder[j]
+ slices.Sort(expectedOrder)
+ actOrder := make([]T, 0, len(expectedSet))
+ tree.Range(func(node *RBNode[NativeOrdered[T]]) bool {
+ actOrder = append(actOrder, node.Value.Val)
+ return true
})
- actOrder := make([]K, 0, len(expectedSet))
- require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
- actOrder = append(actOrder, tree.KeyFn(node.Value).Val)
- return nil
- }))
require.Equal(t, expectedOrder, actOrder)
- require.Equal(t, tree.Len(), len(expectedSet))
+ require.Equal(t, len(expectedSet), tree.Len())
}
func FuzzRBTree(f *testing.F) {
@@ -132,11 +130,7 @@ func FuzzRBTree(f *testing.F) {
})
f.Fuzz(func(t *testing.T, dat []uint8) {
- tree := &RBTree[NativeOrdered[uint8], uint8]{
- KeyFn: func(x uint8) NativeOrdered[uint8] {
- return NativeOrdered[uint8]{Val: x}
- },
- }
+ tree := new(RBTree[NativeOrdered[uint8]])
set := make(Set[uint8])
checkRBTree(t, set, tree)
t.Logf("\n%s\n", tree.ASCIIArt())
@@ -145,18 +139,18 @@ func FuzzRBTree(f *testing.F) {
val := (b & 0b0011_1111)
if ins {
t.Logf("Insert(%v)", val)
- tree.Insert(val)
+ tree.Insert(NativeOrdered[uint8]{Val: val})
set.Insert(val)
t.Logf("\n%s\n", tree.ASCIIArt())
- node := tree.Lookup(NativeOrdered[uint8]{Val: val})
+ node := tree.Search(NativeOrdered[uint8]{Val: val}.Compare)
require.NotNil(t, node)
- require.Equal(t, val, node.Value)
+ require.Equal(t, val, node.Value.Val)
} else {
t.Logf("Delete(%v)", val)
- tree.Delete(NativeOrdered[uint8]{Val: val})
+ tree.Delete(tree.Search(NativeOrdered[uint8]{Val: val}.Compare))
delete(set, val)
t.Logf("\n%s\n", tree.ASCIIArt())
- require.Nil(t, tree.Lookup(NativeOrdered[uint8]{Val: val}))
+ require.Nil(t, tree.Search(NativeOrdered[uint8]{Val: val}.Compare))
}
checkRBTree(t, set, tree)
}
diff --git a/lib/containers/set.go b/lib/containers/set.go
index b2af494..0d9202c 100644
--- a/lib/containers/set.go
+++ b/lib/containers/set.go
@@ -32,7 +32,7 @@ func (o Set[T]) EncodeJSON(w io.Writer) error {
var zero T
switch (any(zero)).(type) {
case _Ordered[T]:
- less = func(a, b T) bool { return cast[_Ordered[T]](a).Cmp(b) < 0 }
+ less = func(a, b T) bool { return cast[_Ordered[T]](a).Compare(b) < 0 }
// This is the constraints.Ordered list
case string:
less = func(a, b T) bool { return cast[string](a) < cast[string](b) }
diff --git a/lib/containers/sortedmap.go b/lib/containers/sortedmap.go
index 52308c9..d104274 100644
--- a/lib/containers/sortedmap.go
+++ b/lib/containers/sortedmap.go
@@ -1,40 +1,32 @@
-// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later
package containers
-import (
- "errors"
-)
-
-type OrderedKV[K Ordered[K], V any] struct {
+type orderedKV[K Ordered[K], V any] struct {
K K
V V
}
-type SortedMap[K Ordered[K], V any] struct {
- inner RBTree[K, OrderedKV[K, V]]
-}
-
-func (m *SortedMap[K, V]) init() {
- if m.inner.KeyFn == nil {
- m.inner.KeyFn = m.keyFn
- }
+func (a orderedKV[K, V]) Compare(b orderedKV[K, V]) int {
+ return a.K.Compare(b.K)
}
-func (m *SortedMap[K, V]) keyFn(kv OrderedKV[K, V]) K {
- return kv.K
+type SortedMap[K Ordered[K], V any] struct {
+ inner RBTree[orderedKV[K, V]]
}
func (m *SortedMap[K, V]) Delete(key K) {
- m.init()
- m.inner.Delete(key)
+ m.inner.Delete(m.inner.Search(func(kv orderedKV[K, V]) int {
+ return key.Compare(kv.K)
+ }))
}
func (m *SortedMap[K, V]) Load(key K) (value V, ok bool) {
- m.init()
- node := m.inner.Lookup(key)
+ node := m.inner.Search(func(kv orderedKV[K, V]) int {
+ return key.Compare(kv.K)
+ })
if node == nil {
var zero V
return zero, false
@@ -42,41 +34,27 @@ func (m *SortedMap[K, V]) Load(key K) (value V, ok bool) {
return node.Value.V, true
}
-var errStop = errors.New("stop")
-
-func (m *SortedMap[K, V]) Range(f func(key K, value V) bool) {
- m.init()
- _ = m.inner.Walk(func(node *RBNode[OrderedKV[K, V]]) error {
- if f(node.Value.K, node.Value.V) {
- return nil
- } else {
- return errStop
- }
+func (m *SortedMap[K, V]) Store(key K, value V) {
+ m.inner.Insert(orderedKV[K, V]{
+ K: key,
+ V: value,
})
}
-func (m *SortedMap[K, V]) Subrange(rangeFn func(K, V) int, handleFn func(K, V) bool) {
- m.init()
- kvs := m.inner.SearchRange(func(kv OrderedKV[K, V]) int {
- return rangeFn(kv.K, kv.V)
+func (m *SortedMap[K, V]) Range(fn func(key K, value V) bool) {
+ m.inner.Range(func(node *RBNode[orderedKV[K, V]]) bool {
+ return fn(node.Value.K, node.Value.V)
})
- for _, kv := range kvs {
- if !handleFn(kv.K, kv.V) {
- break
- }
- }
}
-func (m *SortedMap[K, V]) Store(key K, value V) {
- m.init()
- m.inner.Insert(OrderedKV[K, V]{
- K: key,
- V: value,
- })
+func (m *SortedMap[K, V]) Subrange(rangeFn func(K, V) int, handleFn func(K, V) bool) {
+ m.inner.Subrange(
+ func(kv orderedKV[K, V]) int { return rangeFn(kv.K, kv.V) },
+ func(node *RBNode[orderedKV[K, V]]) bool { return handleFn(node.Value.K, node.Value.V) })
}
func (m *SortedMap[K, V]) Search(fn func(K, V) int) (K, V, bool) {
- node := m.inner.Search(func(kv OrderedKV[K, V]) int {
+ node := m.inner.Search(func(kv orderedKV[K, V]) int {
return fn(kv.K, kv.V)
})
if node == nil {
@@ -87,12 +65,6 @@ func (m *SortedMap[K, V]) Search(fn func(K, V) int) (K, V, bool) {
return node.Value.K, node.Value.V, true
}
-func (m *SortedMap[K, V]) SearchAll(fn func(K, V) int) []OrderedKV[K, V] {
- return m.inner.SearchRange(func(kv OrderedKV[K, V]) int {
- return fn(kv.K, kv.V)
- })
-}
-
func (m *SortedMap[K, V]) Len() int {
return m.inner.Len()
}