1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
|
package rbtree
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/constraints"
)
// checkTree asserts that tree satisfies the red-black tree properties (CLRS):
//
//  1. Every node is either red or black (enforced by the color type itself).
//  2. The root is black.
//  3. Every nil leaf is black (getColor is called on nil children below,
//     so it is expected to report Black for nil).
//  4. If a node is red, both of its children are black.
//  5. For each node, every simple path from the node down to a descendant
//     nil leaf contains the same number of black nodes.
//
// Violations are reported with non-fatal testify assertions, so all broken
// properties are reported rather than just the first.
func checkTree[K constraints.Ordered, V any](t *testing.T, tree *Tree[K, V]) {
	// Property 2.
	assert.Equal(t, Black, tree.root.getColor())

	// Property 4.
	tree.Walk(func(node *Node[V]) {
		if node.getColor() != Red {
			return
		}
		assert.Equal(t, Black, node.Left.getColor())
		assert.Equal(t, Black, node.Right.getColor())
	})

	// Property 5. blackCounts appends to out, for every nil leaf under n (in
	// left-to-right order), the number of black nodes on the path from n down
	// to that leaf, counting the nil leaf itself.
	var blackCounts func(n *Node[V], sofar int, out []int) []int
	blackCounts = func(n *Node[V], sofar int, out []int) []int {
		if n.getColor() == Black {
			sofar++
		}
		if n == nil {
			return append(out, sofar)
		}
		out = blackCounts(n.Left, sofar, out)
		return blackCounts(n.Right, sofar, out)
	}
	tree.Walk(func(node *Node[V]) {
		// Never empty: the recursion always bottoms out at a nil leaf.
		counts := blackCounts(node, 0, nil)
		for _, c := range counts[1:] {
			if c != counts[0] {
				assert.Truef(t, false, "node %v: not all leafs have same black-count: %v", node.Value, counts)
				break
			}
		}
	})
}
// FuzzTree drives a Tree through a fuzzer-chosen sequence of inserts and
// deletes. After every operation it checks the red-black invariants (via
// checkTree) and the tree's contents.
//
// Each input byte encodes one operation:
//
//	bit 7     ignored
//	bit 6     1 = Insert, 0 = Delete
//	bits 0-5  the value, which is also the key since KeyFn is the identity
//
// Contents are checked against a simple set model. After every operation,
// Lookup(k) must find a node exactly for the keys that have been inserted
// and not deleted since. This catches operations that disturb nodes other
// than the one being inserted or deleted. On the first failure the
// offending tree is dumped and the run stops.
func FuzzTree(f *testing.F) {
	const (
		Ins     = uint8(0b0100_0000) // opcode bit: set = Insert
		Del     = uint8(0)           // opcode bit clear = Delete
		valMask = uint8(0b0011_1111) // low six bits carry the value
	)
	f.Add([]uint8{})
	f.Add([]uint8{Ins | 5, Del | 5})
	f.Add([]uint8{Ins | 5, Del | 6})
	f.Add([]uint8{Del | 6})
	f.Add([]uint8{ // CLRS Figure 14.4
		Ins | 1,
		Ins | 2,
		Ins | 5,
		Ins | 7,
		Ins | 8,
		Ins | 11,
		Ins | 14,
		Ins | 15,
		Ins | 4,
	})
	f.Fuzz(func(t *testing.T, dat []uint8) {
		tree := &Tree[uint8, uint8]{
			KeyFn: func(x uint8) uint8 { return x },
		}
		// present[k] reports whether key k should currently be in the tree.
		var present [valMask + 1]bool

		checkTree(t, tree)
		for _, b := range dat {
			val := b & valMask
			if b&Ins != 0 {
				t.Logf("Insert(%v)", val)
				tree.Insert(val)
				node := tree.Lookup(val)
				require.NotNil(t, node)
				assert.Equal(t, val, node.Value)
				present[val] = true
			} else {
				t.Logf("Delete(%v)", val)
				tree.Delete(val)
				assert.Nil(t, tree.Lookup(val))
				present[val] = false
			}

			checkTree(t, tree)
			// Every key's membership must match the model, not just the key
			// touched by this operation.
			for k := range present {
				key := uint8(k)
				assert.Equalf(t, present[k], tree.Lookup(key) != nil, "Lookup(%v)", key)
			}

			if t.Failed() {
				// Dump the broken tree once and stop. Later operations would
				// only pile more failures onto an already-invalid tree.
				t.Logf("tree after last operation:\n\n%s\n", tree.ASCIIArt())
				t.FailNow()
			}
		}
	})
}
|