summaryrefslogtreecommitdiff
path: root/lib/containers
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2022-07-13 20:36:27 -0600
committerLuke Shumaker <lukeshu@lukeshu.com>2022-07-13 20:36:27 -0600
commit4e29bb393ec774f0a79c70d9d69c54fe4e8ecb72 (patch)
tree3382a06206d00b27a756a9376e11de1126febd1b /lib/containers
parent09cc146211148a3b2568261c41a804a802c31d4c (diff)
Move lib/rbtree to lib/containers
Diffstat (limited to 'lib/containers')
-rw-r--r--lib/containers/rbtree.go522
-rw-r--r--lib/containers/rbtree_test.go171
-rw-r--r--lib/containers/testdata/fuzz/FuzzTree/be408ce7760bc8ced841300ea7e6bac1a1e9505b1535810083d18db95d86f4892
-rw-r--r--lib/containers/testdata/fuzz/FuzzTree/f9e6421dacf921f7bb25d402bffbfdce114baad0b1c8b9a9189b5a97fda27e412
4 files changed, 697 insertions, 0 deletions
diff --git a/lib/containers/rbtree.go b/lib/containers/rbtree.go
new file mode 100644
index 0000000..6bffc02
--- /dev/null
+++ b/lib/containers/rbtree.go
@@ -0,0 +1,522 @@
+// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package containers
+
+import (
+ "fmt"
+ "reflect"
+
+ "git.lukeshu.com/btrfs-progs-ng/lib/util"
+)
+
+type Color bool
+
+const (
+ Black = Color(false)
+ Red = Color(true)
+)
+
+type RBNode[V any] struct {
+ Parent, Left, Right *RBNode[V]
+
+ Color Color
+
+ Value V
+}
+
+func (node *RBNode[V]) getColor() Color {
+ if node == nil {
+ return Black
+ }
+ return node.Color
+}
+
+type RBTree[K util.Ordered[K], V any] struct {
+ KeyFn func(V) K
+ root *RBNode[V]
+}
+
+func (t *RBTree[K, V]) Walk(fn func(*RBNode[V]) error) error {
+ return t.root.walk(fn)
+}
+
+func (node *RBNode[V]) walk(fn func(*RBNode[V]) error) error {
+ if node == nil {
+ return nil
+ }
+ if err := node.Left.walk(fn); err != nil {
+ return err
+ }
+ if err := fn(node); err != nil {
+ return err
+ }
+ if err := node.Right.walk(fn); err != nil {
+ return err
+ }
+ return nil
+}
+
+// Search the tree for a value that satisfied the given callbackk
+// function. A return value of 0 means to to return this value; <0
+// means to go left on the tree (the value is too high), >0 means to
+// go right on th etree (the value is too low).
+//
+// +-----+
+// | v=8 | == 0 : this is it
+// +-----+
+// / \
+// / \
+// <0 : go left >0 : go right
+// / \
+// +---+ +---+
+// | 7 | | 9 |
+// +---+ +---+
+//
+// Returns nil if no such value is found.
+//
+// Search is good for advanced lookup, like when a range of values is
+// acceptable. For simple exact-value lookup, use Lookup.
+func (t *RBTree[K, V]) Search(fn func(V) int) *RBNode[V] {
+ ret, _ := t.root.search(fn)
+ return ret
+}
+
+func (node *RBNode[V]) search(fn func(V) int) (exact, nearest *RBNode[V]) {
+ var prev *RBNode[V]
+ for {
+ if node == nil {
+ return nil, prev
+ }
+ direction := fn(node.Value)
+ prev = node
+ switch {
+ case direction < 0:
+ node = node.Left
+ case direction == 0:
+ return node, nil
+ case direction > 0:
+ node = node.Right
+ }
+ }
+}
+
+func (t *RBTree[K, V]) exactKey(key K) func(V) int {
+ return func(val V) int {
+ valKey := t.KeyFn(val)
+ return key.Cmp(valKey)
+ }
+}
+
+// Lookup looks up the value for an exact key. If no such value
+// exists, nil is returned.
+func (t *RBTree[K, V]) Lookup(key K) *RBNode[V] {
+ return t.Search(t.exactKey(key))
+}
+
+// Min returns the minimum value stored in the tree, or nil if the
+// tree is empty.
+func (t *RBTree[K, V]) Min() *RBNode[V] {
+ return t.root.min()
+}
+
+func (node *RBNode[V]) min() *RBNode[V] {
+ if node == nil {
+ return nil
+ }
+ for {
+ if node.Left == nil {
+ return node
+ }
+ node = node.Left
+ }
+}
+
+// Max returns the maximum value stored in the tree, or nil if the
+// tree is empty.
+func (t *RBTree[K, V]) Max() *RBNode[V] {
+ return t.root.max()
+}
+
+func (node *RBNode[V]) max() *RBNode[V] {
+ if node == nil {
+ return nil
+ }
+ for {
+ if node.Right == nil {
+ return node
+ }
+ node = node.Right
+ }
+}
+
+func (t *RBTree[K, V]) Next(cur *RBNode[V]) *RBNode[V] {
+ return cur.next()
+}
+
+func (cur *RBNode[V]) next() *RBNode[V] {
+ if cur.Right != nil {
+ return cur.Right.min()
+ }
+ child, parent := cur, cur.Parent
+ for parent != nil && child == parent.Right {
+ child, parent = parent, parent.Parent
+ }
+ return parent
+}
+
+func (t *RBTree[K, V]) Prev(cur *RBNode[V]) *RBNode[V] {
+ return cur.prev()
+}
+
+func (cur *RBNode[V]) prev() *RBNode[V] {
+ if cur.Left != nil {
+ return cur.Left.max()
+ }
+ child, parent := cur, cur.Parent
+ for parent != nil && child == parent.Left {
+ child, parent = parent, parent.Parent
+ }
+ return parent
+}
+
+// SearchRange is like Search, but returns all nodes that match the
+// function; assuming that they are contiguous.
+func (t *RBTree[K, V]) SearchRange(fn func(V) int) []V {
+ middle := t.Search(fn)
+ if middle == nil {
+ return nil
+ }
+ ret := []V{middle.Value}
+ for node := t.Prev(middle); node != nil && fn(node.Value) == 0; node = t.Prev(node) {
+ ret = append(ret, node.Value)
+ }
+ util.ReverseSlice(ret)
+ for node := t.Next(middle); node != nil && fn(node.Value) == 0; node = t.Next(node) {
+ ret = append(ret, node.Value)
+ }
+ return ret
+}
+
+func (t *RBTree[K, V]) Equal(u *RBTree[K, V]) bool {
+ if (t == nil) != (u == nil) {
+ return false
+ }
+ if t == nil {
+ return true
+ }
+
+ var tSlice []V
+ _ = t.Walk(func(node *RBNode[V]) error {
+ tSlice = append(tSlice, node.Value)
+ return nil
+ })
+
+ var uSlice []V
+ _ = u.Walk(func(node *RBNode[V]) error {
+ uSlice = append(uSlice, node.Value)
+ return nil
+ })
+
+ return reflect.DeepEqual(tSlice, uSlice)
+}
+
+func (t *RBTree[K, V]) parentChild(node *RBNode[V]) **RBNode[V] {
+ switch {
+ case node.Parent == nil:
+ return &t.root
+ case node.Parent.Left == node:
+ return &node.Parent.Left
+ case node.Parent.Right == node:
+ return &node.Parent.Right
+ default:
+ panic(fmt.Errorf("node %p is not a child of its parent %p", node, node.Parent))
+ }
+}
+
+func (t *RBTree[K, V]) leftRotate(x *RBNode[V]) {
+ // p p
+ // | |
+ // +---+ +---+
+ // | x | | y |
+ // +---+ +---+
+ // / \ => / \
+ // a +---+ +---+ c
+ // | y | | x |
+ // +---+ +---+
+ // / \ / \
+ // b c a b
+
+ // Define 'p', 'x', 'y', and 'b' per the above diagram.
+ p := x.Parent
+ pChild := t.parentChild(x)
+ y := x.Right
+ b := y.Left
+
+ // Move things around
+
+ y.Parent = p
+ *pChild = y
+
+ x.Parent = y
+ y.Left = x
+
+ if b != nil {
+ b.Parent = x
+ }
+ x.Right = b
+}
+
+func (t *RBTree[K, V]) rightRotate(y *RBNode[V]) {
+ // | |
+ // +---+ +---+
+ // | y | | x |
+ // +---+ +---+
+ // / \ => / \
+ // +---+ c a +---+
+ // | x | | y |
+ // +---+ +---+
+ // / \ / \
+ // a b b c
+
+ // Define 'p', 'x', 'y', and 'b' per the above diagram.
+ p := y.Parent
+ pChild := t.parentChild(y)
+ x := y.Left
+ b := x.Right
+
+ // Move things around
+
+ x.Parent = p
+ *pChild = x
+
+ y.Parent = x
+ x.Right = y
+
+ if b != nil {
+ b.Parent = y
+ }
+ y.Left = b
+}
+
+func (t *RBTree[K, V]) Insert(val V) {
+ // Naive-insert
+
+ key := t.KeyFn(val)
+ exact, parent := t.root.search(t.exactKey(key))
+ if exact != nil {
+ exact.Value = val
+ return
+ }
+
+ node := &RBNode[V]{
+ Color: Red,
+ Parent: parent,
+ Value: val,
+ }
+ if parent == nil {
+ t.root = node
+ } else if key.Cmp(t.KeyFn(parent.Value)) < 0 {
+ parent.Left = node
+ } else {
+ parent.Right = node
+ }
+
+ // Re-balance
+ //
+ // This is closely based on the algorithm presented in CLRS
+ // 3e.
+
+ for node.Parent.getColor() == Red {
+ if node.Parent == node.Parent.Parent.Left {
+ uncle := node.Parent.Parent.Right
+ if uncle.getColor() == Red {
+ node.Parent.Color = Black
+ uncle.Color = Black
+ node.Parent.Parent.Color = Red
+ node = node.Parent.Parent
+ } else {
+ if node == node.Parent.Right {
+ node = node.Parent
+ t.leftRotate(node)
+ }
+ node.Parent.Color = Black
+ node.Parent.Parent.Color = Red
+ t.rightRotate(node.Parent.Parent)
+ }
+ } else {
+ uncle := node.Parent.Parent.Left
+ if uncle.getColor() == Red {
+ node.Parent.Color = Black
+ uncle.Color = Black
+ node.Parent.Parent.Color = Red
+ node = node.Parent.Parent
+ } else {
+ if node == node.Parent.Left {
+ node = node.Parent
+ t.rightRotate(node)
+ }
+ node.Parent.Color = Black
+ node.Parent.Parent.Color = Red
+ t.leftRotate(node.Parent.Parent)
+ }
+ }
+ }
+ t.root.Color = Black
+}
+
+func (t *RBTree[K, V]) transplant(old, new *RBNode[V]) {
+ *t.parentChild(old) = new
+ if new != nil {
+ new.Parent = old.Parent
+ }
+}
+
+func (t *RBTree[K, V]) Delete(key K) {
+ nodeToDelete := t.Lookup(key)
+ if nodeToDelete == nil {
+ return
+ }
+
+ // This is closely based on the algorithm presented in CLRS
+ // 3e.
+
+ var nodeToRebalance *RBNode[V]
+ var nodeToRebalanceParent *RBNode[V] // in case 'nodeToRebalance' is nil, which it can be
+ needsRebalance := nodeToDelete.Color == Black
+
+ switch {
+ case nodeToDelete.Left == nil:
+ nodeToRebalance = nodeToDelete.Right
+ nodeToRebalanceParent = nodeToDelete.Parent
+ t.transplant(nodeToDelete, nodeToDelete.Right)
+ case nodeToDelete.Right == nil:
+ nodeToRebalance = nodeToDelete.Left
+ nodeToRebalanceParent = nodeToDelete.Parent
+ t.transplant(nodeToDelete, nodeToDelete.Left)
+ default:
+ // The node being deleted has a child on both sides,
+ // so we've go to reshuffle the parents a bit to make
+ // room for those children.
+ next := nodeToDelete.next()
+ if next.Parent == nodeToDelete {
+ // p p
+ // | |
+ // +-----+ +-----+
+ // | ntd | | nxt |
+ // +-----+ +-----+
+ // / \ => / \
+ // a +-----+ a b
+ // | nxt |
+ // +-----+
+ // / \
+ // nil b
+ nodeToRebalance = next.Right
+ nodeToRebalanceParent = next
+
+ *t.parentChild(nodeToDelete) = next
+ next.Parent = nodeToDelete.Parent
+
+ next.Left = nodeToDelete.Left
+ next.Left.Parent = next
+ } else {
+ // p p
+ // | |
+ // +-----+ +-----+
+ // | ntd | | nxt |
+ // +-----+ +-----+
+ // / \ / \
+ // a x a x
+ // / \ => / \
+ // y z y z
+ // / \ / \
+ // +-----+ c b c
+ // | nxt |
+ // +-----+
+ // / \
+ // nil b
+ y := next.Parent
+ b := next.Right
+ nodeToRebalance = b
+ nodeToRebalanceParent = y
+
+ *t.parentChild(nodeToDelete) = next
+ next.Parent = nodeToDelete.Parent
+
+ next.Left = nodeToDelete.Left
+ next.Left.Parent = next
+
+ next.Right = nodeToDelete.Right
+ next.Right.Parent = next
+
+ y.Left = b
+ if b != nil {
+ b.Parent = y
+ }
+ }
+
+ // idk
+ needsRebalance = next.Color == Black
+ next.Color = nodeToDelete.Color
+ }
+
+ if needsRebalance {
+ node := nodeToRebalance
+ nodeParent := nodeToRebalanceParent
+ for node != t.root && node.getColor() == Black {
+ if node == nodeParent.Left {
+ sibling := nodeParent.Right
+ if sibling.getColor() == Red {
+ sibling.Color = Black
+ nodeParent.Color = Red
+ t.leftRotate(nodeParent)
+ sibling = nodeParent.Right
+ }
+ if sibling.Left.getColor() == Black && sibling.Right.getColor() == Black {
+ sibling.Color = Red
+ node, nodeParent = nodeParent, nodeParent.Parent
+ } else {
+ if sibling.Right.getColor() == Black {
+ sibling.Left.Color = Black
+ sibling.Color = Red
+ t.rightRotate(sibling)
+ sibling = nodeParent.Right
+ }
+ sibling.Color = nodeParent.Color
+ nodeParent.Color = Black
+ sibling.Right.Color = Black
+ t.leftRotate(nodeParent)
+ node, nodeParent = t.root, nil
+ }
+ } else {
+ sibling := nodeParent.Left
+ if sibling.getColor() == Red {
+ sibling.Color = Black
+ nodeParent.Color = Red
+ t.rightRotate(nodeParent)
+ sibling = nodeParent.Left
+ }
+ if sibling.Right.getColor() == Black && sibling.Left.getColor() == Black {
+ sibling.Color = Red
+ node, nodeParent = nodeParent, nodeParent.Parent
+ } else {
+ if sibling.Left.getColor() == Black {
+ sibling.Right.Color = Black
+ sibling.Color = Red
+ t.leftRotate(sibling)
+ sibling = nodeParent.Left
+ }
+ sibling.Color = nodeParent.Color
+ nodeParent.Color = Black
+ sibling.Left.Color = Black
+ t.rightRotate(nodeParent)
+ node, nodeParent = t.root, nil
+ }
+ }
+ }
+ if node != nil {
+ node.Color = Black
+ }
+ }
+}
diff --git a/lib/containers/rbtree_test.go b/lib/containers/rbtree_test.go
new file mode 100644
index 0000000..3360bc0
--- /dev/null
+++ b/lib/containers/rbtree_test.go
@@ -0,0 +1,171 @@
+// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package containers
+
+import (
+ "fmt"
+ "io"
+ "sort"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/exp/constraints"
+
+ "git.lukeshu.com/btrfs-progs-ng/lib/util"
+)
+
+func (t *RBTree[K, V]) ASCIIArt() string {
+ var out strings.Builder
+ t.root.asciiArt(&out, "", "", "")
+ return out.String()
+}
+
+func (node *RBNode[V]) String() string {
+ switch {
+ case node == nil:
+ return "nil"
+ case node.Color == Red:
+ return fmt.Sprintf("R(%v)", node.Value)
+ default:
+ return fmt.Sprintf("B(%v)", node.Value)
+ }
+}
+
+func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) {
+ if node == nil {
+ fmt.Fprintf(w, "%snil\n", m)
+ return
+ }
+
+ node.Right.asciiArt(w, u+" ", u+" ,--", u+" | ")
+
+ if node.Color == Red {
+ fmt.Fprintf(w, "%s%v\n", m, node)
+ } else {
+ fmt.Fprintf(w, "%s%v\n", m, node)
+ }
+
+ node.Left.asciiArt(w, l+" | ", l+" `--", l+" ")
+}
+
+func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet map[K]struct{}, tree *RBTree[util.NativeOrdered[K], V]) {
+ // 1. Every node is either red or black
+
+ // 2. The root is black.
+ require.Equal(t, Black, tree.root.getColor())
+
+ // 3. Every nil is black.
+
+ // 4. If a node is red, then both its children are black.
+ require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
+ if node.getColor() == Red {
+ require.Equal(t, Black, node.Left.getColor())
+ require.Equal(t, Black, node.Right.getColor())
+ }
+ return nil
+ }))
+
+ // 5. For each node, all simple paths from the node to
+ // descendent leaves contain the same number of black
+ // nodes.
+ var walkCnt func(node *RBNode[V], cnt int, leafFn func(int))
+ walkCnt = func(node *RBNode[V], cnt int, leafFn func(int)) {
+ if node.getColor() == Black {
+ cnt++
+ }
+ if node == nil {
+ leafFn(cnt)
+ return
+ }
+ walkCnt(node.Left, cnt, leafFn)
+ walkCnt(node.Right, cnt, leafFn)
+ }
+ require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
+ var cnts []int
+ walkCnt(node, 0, func(cnt int) {
+ cnts = append(cnts, cnt)
+ })
+ for i := range cnts {
+ if cnts[0] != cnts[i] {
+ require.Truef(t, false, "node %v: not all leafs have same black-count: %v", node.Value, cnts)
+ break
+ }
+ }
+ return nil
+ }))
+
+ // expected contents
+ expectedOrder := make([]K, 0, len(expectedSet))
+ for k := range expectedSet {
+ expectedOrder = append(expectedOrder, k)
+ node := tree.Lookup(util.NativeOrdered[K]{Val: k})
+ require.NotNil(t, tree, node)
+ require.Equal(t, k, tree.KeyFn(node.Value).Val)
+ }
+ sort.Slice(expectedOrder, func(i, j int) bool {
+ return expectedOrder[i] < expectedOrder[j]
+ })
+ actOrder := make([]K, 0, len(expectedSet))
+ require.NoError(t, tree.Walk(func(node *RBNode[V]) error {
+ actOrder = append(actOrder, tree.KeyFn(node.Value).Val)
+ return nil
+ }))
+ require.Equal(t, expectedOrder, actOrder)
+}
+
+func FuzzRBTree(f *testing.F) {
+ Ins := uint8(0b0100_0000)
+ Del := uint8(0)
+
+ f.Add([]uint8{})
+ f.Add([]uint8{Ins | 5, Del | 5})
+ f.Add([]uint8{Ins | 5, Del | 6})
+ f.Add([]uint8{Del | 6})
+
+ f.Add([]uint8{ // CLRS Figure 14.4
+ Ins | 1,
+ Ins | 2,
+ Ins | 5,
+ Ins | 7,
+ Ins | 8,
+ Ins | 11,
+ Ins | 14,
+ Ins | 15,
+
+ Ins | 4,
+ })
+
+ f.Fuzz(func(t *testing.T, dat []uint8) {
+ tree := &RBTree[util.NativeOrdered[uint8], uint8]{
+ KeyFn: func(x uint8) util.NativeOrdered[uint8] {
+ return util.NativeOrdered[uint8]{Val: x}
+ },
+ }
+ set := make(map[uint8]struct{})
+ checkRBTree(t, set, tree)
+ t.Logf("\n%s\n", tree.ASCIIArt())
+ for _, b := range dat {
+ ins := (b & 0b0100_0000) != 0
+ val := (b & 0b0011_1111)
+ if ins {
+ t.Logf("Insert(%v)", val)
+ tree.Insert(val)
+ set[val] = struct{}{}
+ t.Logf("\n%s\n", tree.ASCIIArt())
+ node := tree.Lookup(util.NativeOrdered[uint8]{Val: val})
+ require.NotNil(t, node)
+ require.Equal(t, val, node.Value)
+ } else {
+ t.Logf("Delete(%v)", val)
+ tree.Delete(util.NativeOrdered[uint8]{Val: val})
+ delete(set, val)
+ t.Logf("\n%s\n", tree.ASCIIArt())
+ require.Nil(t, tree.Lookup(util.NativeOrdered[uint8]{Val: val}))
+ }
+ checkRBTree(t, set, tree)
+ }
+ })
+}
diff --git a/lib/containers/testdata/fuzz/FuzzTree/be408ce7760bc8ced841300ea7e6bac1a1e9505b1535810083d18db95d86f489 b/lib/containers/testdata/fuzz/FuzzTree/be408ce7760bc8ced841300ea7e6bac1a1e9505b1535810083d18db95d86f489
new file mode 100644
index 0000000..40318c6
--- /dev/null
+++ b/lib/containers/testdata/fuzz/FuzzTree/be408ce7760bc8ced841300ea7e6bac1a1e9505b1535810083d18db95d86f489
@@ -0,0 +1,2 @@
+go test fuzz v1
+[]byte("aAm0BCrb0x00!0000000000000")
diff --git a/lib/containers/testdata/fuzz/FuzzTree/f9e6421dacf921f7bb25d402bffbfdce114baad0b1c8b9a9189b5a97fda27e41 b/lib/containers/testdata/fuzz/FuzzTree/f9e6421dacf921f7bb25d402bffbfdce114baad0b1c8b9a9189b5a97fda27e41
new file mode 100644
index 0000000..238e44f
--- /dev/null
+++ b/lib/containers/testdata/fuzz/FuzzTree/f9e6421dacf921f7bb25d402bffbfdce114baad0b1c8b9a9189b5a97fda27e41
@@ -0,0 +1,2 @@
+go test fuzz v1
+[]byte("YZAB\x990")