summaryrefslogtreecommitdiff
path: root/pkg/rbtree
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2022-06-30 01:25:00 -0600
committerLuke Shumaker <lukeshu@lukeshu.com>2022-06-30 01:25:00 -0600
commit0f196fa9c2908a7ea9b2114a9119707df8880328 (patch)
tree3f08c31d935a03c71429b6d73527f1da3149b53a /pkg/rbtree
parent2e229edfc6c28b3947d4175a6126167203c0f644 (diff)
fix1
Diffstat (limited to 'pkg/rbtree')
-rw-r--r--pkg/rbtree/rbtree.go143
-rw-r--r--pkg/rbtree/rbtree_test.go9
2 files changed, 102 insertions, 50 deletions
diff --git a/pkg/rbtree/rbtree.go b/pkg/rbtree/rbtree.go
index d9249c8..34c73d4 100644
--- a/pkg/rbtree/rbtree.go
+++ b/pkg/rbtree/rbtree.go
@@ -317,100 +317,155 @@ func (t *Tree[K, V]) Insert(val V) {
t.root.Color = Black
}
+func (t *Tree[K, V]) transplant(old, new *Node[V]) {
+ *t.parentChild(old) = new
+ if new != nil {
+ new.Parent = old.Parent
+ }
+}
+
func (t *Tree[K, V]) Delete(key K) {
nodeToDelete := t.Lookup(key)
if nodeToDelete == nil {
return
}
- needsFixup := nodeToDelete.getColor() == Black
+ // A pointer to node that now resides at the place in the tree
+ // where the deleted node was.
+ nodeToRebalance := t.parentChild(nodeToDelete)
+ nodeToRebalanceParent := nodeToDelete.Parent
+ needsRebalance := nodeToDelete.Color == Black
- var nodeToFixup *Node[V]
switch {
case nodeToDelete.Left == nil:
- nodeToFixup = nodeToDelete.Right
- if nodeToDelete.Right != nil {
- nodeToDelete.Right.Parent = nodeToDelete.Parent
- }
- *t.parentChild(nodeToDelete) = nodeToDelete.Right
+ t.transplant(nodeToDelete, nodeToDelete.Right)
case nodeToDelete.Right == nil:
- nodeToFixup = nodeToDelete.Left
- if nodeToDelete.Left != nil {
- nodeToDelete.Left.Parent = nodeToDelete.Parent
- }
- *t.parentChild(nodeToDelete) = nodeToDelete.Left
+ t.transplant(nodeToDelete, nodeToDelete.Left)
default:
// The node being deleted has a child on both sides,
// so we've go to reshuffle the parents a bit to make
// room for those children.
+ //
next := nodeToDelete.next()
- needsFixup = next.getColor() == Black
+ // If nodeToDelete.next().Parent == nodeToDelete, then
+ // this is pretty easy...
if next.Parent != nodeToDelete {
- if next.Right != nil {
- next.Right.Parent = next.Parent
+ // but if it's not, then we have another step
+ // ('mid' might actually be a chain of nodes)
+ // to get it there
+ //
+ // p p
+ // | |
+ // +-----+ +-----+
+ // | ntd | | ntd |
+ // +-----+ +-----+
+ // / \ / \
+ // a x a +-----+
+ // / \ => | nxt |
+ // y z +-----+
+ // / \ / \
+ // +-----+ c nil x
+ // | nxt | / \
+ // +-----+ y z
+ // / \ / \
+ // nil b b c
+ //
+ // This looks an aweful lot like a
+ // t.rightRotate(next.Parent), but spans
+ // multiple nodes
+ x := nodeToDelete.Right
+ y := next.Parent
+ b := next.Right
+
+ next.Parent = nodeToDelete
+ nodeToDelete.Right = next
+
+ x.Parent = next
+ next.Right = x
+
+ if b != nil {
+ b.Parent = y
}
- *t.parentChild(next) = next.Right
-
- next.Right = nodeToDelete.Right
- next.Right.Parent = next
+ y.Left = b
}
- next.Parent = nodeToDelete.Parent
+ // ... OK, back to the easy case:
+ //
+ // p p
+ // | |
+ // +-----+ +-----+
+ // | ntd | | nxt |
+ // +-----+ +-----+
+ // / \ => / \
+ // a +-----+ a b
+ // | nxt |
+ // +-----+
+ // / \
+ // nil b
+ //
*t.parentChild(nodeToDelete) = next
+ next.Parent = nodeToDelete.Parent
+
next.Left = nodeToDelete.Left
next.Left.Parent = next
+
+ // idk
+ nodeToRebalance = &next.Right
+ nodeToRebalanceParent = next
+ needsRebalance = next.Color == Black
next.Color = nodeToDelete.Color
}
- if needsFixup {
- node := nodeToFixup
- for node != nil && node != t.root && node.getColor() == Black {
- if node == node.Parent.Left {
- sibling := node.Parent.Right
+ if needsRebalance {
+ node := *nodeToRebalance
+ nodeParent := nodeToRebalanceParent // in case 'node' is nil, which it can be
+ for node != t.root && node.getColor() == Black {
+ if node == nodeParent.Left {
+ sibling := nodeParent.Right
if sibling.getColor() == Red {
sibling.Color = Black
- node.Parent.Color = Red
- t.leftRotate(node.Parent)
- sibling = node.Parent.Right
+ nodeParent.Color = Red
+ t.leftRotate(nodeParent)
+ sibling = nodeParent.Right
}
if sibling.Left.getColor() == Black && sibling.Right.getColor() == Black {
sibling.Color = Red
- node = node.Parent
+ node, nodeParent = nodeParent, nodeParent.Parent
} else {
if sibling.Right.getColor() == Black {
sibling.Left.Color = Black
sibling.Color = Red
t.rightRotate(sibling)
- sibling = node.Parent.Right
+ sibling = nodeParent.Right
}
- sibling.Color = node.Parent.Color
- node.Parent.Color = Black
+ sibling.Color = nodeParent.Color
+ nodeParent.Color = Black
sibling.Right.Color = Black
- t.leftRotate(node.Parent)
- node = t.root
+ t.leftRotate(nodeParent)
+ node, nodeParent = t.root, nil
}
} else {
- sibling := node.Parent.Left
+ sibling := nodeParent.Left
if sibling.getColor() == Red {
sibling.Color = Black
- node.Parent.Color = Red
- t.rightRotate(node.Parent)
- sibling = node.Parent.Left
+ nodeParent.Color = Red
+ t.rightRotate(nodeParent)
+ sibling = nodeParent.Left
}
if sibling.Right.getColor() == Black && sibling.Left.getColor() == Black {
sibling.Color = Red
- node = node.Parent
+ node, nodeParent = nodeParent, nodeParent.Parent
} else {
if sibling.Left.getColor() == Black {
sibling.Right.Color = Black
sibling.Color = Red
t.leftRotate(sibling)
- sibling = node.Parent.Left
+ sibling = nodeParent.Left
}
- sibling.Color = node.Parent.Color
- node.Parent.Color = Black
+ sibling.Color = nodeParent.Color
+ nodeParent.Color = Black
sibling.Left.Color = Black
- t.rightRotate(node.Parent)
- node = t.root
+ t.rightRotate(nodeParent)
+ node, nodeParent = t.root, nil
}
}
}
diff --git a/pkg/rbtree/rbtree_test.go b/pkg/rbtree/rbtree_test.go
index 7b3987f..4f4a12c 100644
--- a/pkg/rbtree/rbtree_test.go
+++ b/pkg/rbtree/rbtree_test.go
@@ -80,24 +80,21 @@ func FuzzTree(f *testing.F) {
KeyFn: func(x uint8) uint8 { return x },
}
checkTree(t, tree)
+ t.Logf("\n%s\n", tree.ASCIIArt())
for _, b := range dat {
ins := (b & 0b0100_0000) != 0
val := (b & 0b0011_1111)
if ins {
t.Logf("Insert(%v)", val)
tree.Insert(val)
+ t.Logf("\n%s\n", tree.ASCIIArt())
node := tree.Lookup(val)
require.NotNil(t, node)
assert.Equal(t, val, node.Value)
} else {
t.Logf("Delete(%v)", val)
- if val == 25 {
- t.Logf("before:\n\n%s\n", tree.ASCIIArt())
- }
tree.Delete(val)
- if val == 25 {
- t.Logf("after:\n\n%s\n", tree.ASCIIArt())
- }
+ t.Logf("\n%s\n", tree.ASCIIArt())
assert.Nil(t, tree.Lookup(val))
}
checkTree(t, tree)