From 7fba10e5be51a3fe565a6f69a946ece9f0e59a67 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Mon, 5 Sep 2022 12:43:46 -0600 Subject: Try to uniformly use containers.Set --- lib/containers/rbtree_test.go | 6 +++--- lib/containers/set.go | 19 ++++++++++++++++++- 2 files changed, 21 insertions(+), 4 deletions(-) (limited to 'lib/containers') diff --git a/lib/containers/rbtree_test.go b/lib/containers/rbtree_test.go index 65cc78d..9841d26 100644 --- a/lib/containers/rbtree_test.go +++ b/lib/containers/rbtree_test.go @@ -49,7 +49,7 @@ func (node *RBNode[V]) asciiArt(w io.Writer, u, m, l string) { node.Left.asciiArt(w, l+" | ", l+" `--", l+" ") } -func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet map[K]struct{}, tree *RBTree[NativeOrdered[K], V]) { +func checkRBTree[K constraints.Ordered, V any](t *testing.T, expectedSet Set[K], tree *RBTree[NativeOrdered[K], V]) { // 1. Every node is either red or black // 2. The root is black. @@ -142,7 +142,7 @@ func FuzzRBTree(f *testing.F) { return NativeOrdered[uint8]{Val: x} }, } - set := make(map[uint8]struct{}) + set := make(Set[uint8]) checkRBTree(t, set, tree) t.Logf("\n%s\n", tree.ASCIIArt()) for _, b := range dat { @@ -151,7 +151,7 @@ func FuzzRBTree(f *testing.F) { if ins { t.Logf("Insert(%v)", val) tree.Insert(val) - set[val] = struct{}{} + set.Insert(val) t.Logf("\n%s\n", tree.ASCIIArt()) node := tree.Lookup(NativeOrdered[uint8]{Val: val}) require.NotNil(t, node) diff --git a/lib/containers/set.go b/lib/containers/set.go index 42e5ad2..1c525ca 100644 --- a/lib/containers/set.go +++ b/lib/containers/set.go @@ -13,6 +13,10 @@ import ( "git.lukeshu.com/btrfs-progs-ng/lib/maps" ) +// Set[T] is an unordered set of T. +// +// Despite Set[T] being unordered, T is required to be an ordered type +// in order that a Set[T] have a deterministic JSON representation. type Set[T constraints.Ordered] map[T]struct{} var ( @@ -45,6 +49,14 @@ func (o *Set[T]) DecodeJSON(r io.RuneScanner) error { }) } +func NewSet[T constraints.Ordered](values ...T) Set[T] { + ret := make(Set[T], len(values)) + for _, value := range values { + ret.Insert(value) + } + return ret +} + func (o Set[T]) Insert(v T) { o[v] = struct{}{} } @@ -79,7 +91,12 @@ func (o Set[T]) TakeOne() T { return zero } -func (small Set[T]) HasIntersection(big Set[T]) bool { +func (o Set[T]) Has(v T) bool { + _, has := o[v] + return has +} + +func (small Set[T]) HasAny(big Set[T]) bool { if len(big) < len(small) { small, big = big, small } -- cgit v1.2.3-54-g00ecf