diff options
author | Luke Shumaker <lukeshu@lukeshu.com> | 2022-06-30 01:25:00 -0600 |
---|---|---|
committer | Luke Shumaker <lukeshu@lukeshu.com> | 2022-06-30 01:25:00 -0600 |
commit | 0f196fa9c2908a7ea9b2114a9119707df8880328 (patch) | |
tree | 3f08c31d935a03c71429b6d73527f1da3149b53a /pkg | |
parent | 2e229edfc6c28b3947d4175a6126167203c0f644 (diff) |
fix1
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/rbtree/rbtree.go | 143 | ||||
-rw-r--r-- | pkg/rbtree/rbtree_test.go | 9 |
2 files changed, 102 insertions, 50 deletions
diff --git a/pkg/rbtree/rbtree.go b/pkg/rbtree/rbtree.go index d9249c8..34c73d4 100644 --- a/pkg/rbtree/rbtree.go +++ b/pkg/rbtree/rbtree.go @@ -317,100 +317,155 @@ func (t *Tree[K, V]) Insert(val V) { t.root.Color = Black } +func (t *Tree[K, V]) transplant(old, new *Node[V]) { + *t.parentChild(old) = new + if new != nil { + new.Parent = old.Parent + } +} + func (t *Tree[K, V]) Delete(key K) { nodeToDelete := t.Lookup(key) if nodeToDelete == nil { return } - needsFixup := nodeToDelete.getColor() == Black + // A pointer to node that now resides at the place in the tree + // where the deleted node was. + nodeToRebalance := t.parentChild(nodeToDelete) + nodeToRebalanceParent := nodeToDelete.Parent + needsRebalance := nodeToDelete.Color == Black - var nodeToFixup *Node[V] switch { case nodeToDelete.Left == nil: - nodeToFixup = nodeToDelete.Right - if nodeToDelete.Right != nil { - nodeToDelete.Right.Parent = nodeToDelete.Parent - } - *t.parentChild(nodeToDelete) = nodeToDelete.Right + t.transplant(nodeToDelete, nodeToDelete.Right) case nodeToDelete.Right == nil: - nodeToFixup = nodeToDelete.Left - if nodeToDelete.Left != nil { - nodeToDelete.Left.Parent = nodeToDelete.Parent - } - *t.parentChild(nodeToDelete) = nodeToDelete.Left + t.transplant(nodeToDelete, nodeToDelete.Left) default: // The node being deleted has a child on both sides, // so we've go to reshuffle the parents a bit to make // room for those children. + // next := nodeToDelete.next() - needsFixup = next.getColor() == Black + // If nodeToDelete.next().Parent == nodeToDelete, then + // this is pretty easy... if next.Parent != nodeToDelete { - if next.Right != nil { - next.Right.Parent = next.Parent + // but if it's not, then we have another step + // ('mid' might actually be a chain of nodes) + // to get it there + // + // p p + // | | + // +-----+ +-----+ + // | ntd | | ntd | + // +-----+ +-----+ + // / \ / \ + // a x a +-----+ + // / \ => | nxt | + // y z +-----+ + // / \ / \ + // +-----+ c nil x + // | nxt | / \ + // +-----+ y z + // / \ / \ + // nil b b c + // + // This looks an aweful lot like a + // t.rightRotate(next.Parent), but spans + // multiple nodes + x := nodeToDelete.Right + y := next.Parent + b := next.Right + + next.Parent = nodeToDelete + nodeToDelete.Right = next + + x.Parent = next + next.Right = x + + if b != nil { + b.Parent = y } - *t.parentChild(next) = next.Right - - next.Right = nodeToDelete.Right - next.Right.Parent = next + y.Left = b } - next.Parent = nodeToDelete.Parent + // ... OK, back to the easy case: + // + // p p + // | | + // +-----+ +-----+ + // | ntd | | nxt | + // +-----+ +-----+ + // / \ => / \ + // a +-----+ a b + // | nxt | + // +-----+ + // / \ + // nil b + // *t.parentChild(nodeToDelete) = next + next.Parent = nodeToDelete.Parent + next.Left = nodeToDelete.Left next.Left.Parent = next + + // idk + nodeToRebalance = &next.Right + nodeToRebalanceParent = next + needsRebalance = next.Color == Black next.Color = nodeToDelete.Color } - if needsFixup { - node := nodeToFixup - for node != nil && node != t.root && node.getColor() == Black { - if node == node.Parent.Left { - sibling := node.Parent.Right + if needsRebalance { + node := *nodeToRebalance + nodeParent := nodeToRebalanceParent // in case 'node' is nil, which it can be + for node != t.root && node.getColor() == Black { + if node == nodeParent.Left { + sibling := nodeParent.Right if sibling.getColor() == Red { sibling.Color = Black - node.Parent.Color = Red - t.leftRotate(node.Parent) - sibling = node.Parent.Right + nodeParent.Color = Red + t.leftRotate(nodeParent) + sibling = nodeParent.Right } if sibling.Left.getColor() == Black && sibling.Right.getColor() == Black { sibling.Color = Red - node = node.Parent + node, nodeParent = nodeParent, nodeParent.Parent } else { if sibling.Right.getColor() == Black { sibling.Left.Color = Black sibling.Color = Red t.rightRotate(sibling) - sibling = node.Parent.Right + sibling = nodeParent.Right } - sibling.Color = node.Parent.Color - node.Parent.Color = Black + sibling.Color = nodeParent.Color + nodeParent.Color = Black sibling.Right.Color = Black - t.leftRotate(node.Parent) - node = t.root + t.leftRotate(nodeParent) + node, nodeParent = t.root, nil } } else { - sibling := node.Parent.Left + sibling := nodeParent.Left if sibling.getColor() == Red { sibling.Color = Black - node.Parent.Color = Red - t.rightRotate(node.Parent) - sibling = node.Parent.Left + nodeParent.Color = Red + t.rightRotate(nodeParent) + sibling = nodeParent.Left } if sibling.Right.getColor() == Black && sibling.Left.getColor() == Black { sibling.Color = Red - node = node.Parent + node, nodeParent = nodeParent, nodeParent.Parent } else { if sibling.Left.getColor() == Black { sibling.Right.Color = Black sibling.Color = Red t.leftRotate(sibling) - sibling = node.Parent.Left + sibling = nodeParent.Left } - sibling.Color = node.Parent.Color - node.Parent.Color = Black + sibling.Color = nodeParent.Color + nodeParent.Color = Black sibling.Left.Color = Black - t.rightRotate(node.Parent) - node = t.root + t.rightRotate(nodeParent) + node, nodeParent = t.root, nil } } } diff --git a/pkg/rbtree/rbtree_test.go b/pkg/rbtree/rbtree_test.go index 7b3987f..4f4a12c 100644 --- a/pkg/rbtree/rbtree_test.go +++ b/pkg/rbtree/rbtree_test.go @@ -80,24 +80,21 @@ func FuzzTree(f *testing.F) { KeyFn: func(x uint8) uint8 { return x }, } checkTree(t, tree) + t.Logf("\n%s\n", tree.ASCIIArt()) for _, b := range dat { ins := (b & 0b0100_0000) != 0 val := (b & 0b0011_1111) if ins { t.Logf("Insert(%v)", val) tree.Insert(val) + t.Logf("\n%s\n", tree.ASCIIArt()) node := tree.Lookup(val) require.NotNil(t, node) assert.Equal(t, val, node.Value) } else { t.Logf("Delete(%v)", val) - if val == 25 { - t.Logf("before:\n\n%s\n", tree.ASCIIArt()) - } tree.Delete(val) - if val == 25 { - t.Logf("after:\n\n%s\n", tree.ASCIIArt()) - } + t.Logf("\n%s\n", tree.ASCIIArt()) assert.Nil(t, tree.Lookup(val)) } checkTree(t, tree) |