summaryrefslogtreecommitdiff
path: root/syncutil/maponce.go
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2023-02-05 14:25:59 -0700
committerLuke Shumaker <lukeshu@lukeshu.com>2023-02-05 14:25:59 -0700
commit1328996a17776d74b0f1604a428826b6a761dbe4 (patch)
tree9b1fa93c73cf8cbcd8312567d99ae1d4e54915f0 /syncutil/maponce.go
parent2c99c1e26340941ef6a266354bce058b66526cce (diff)
syncutil: MapOnce: Add a TryLoadOrDo methodv0.2.0
Diffstat (limited to 'syncutil/maponce.go')
-rw-r--r--syncutil/maponce.go46
1 files changed, 39 insertions, 7 deletions
diff --git a/syncutil/maponce.go b/syncutil/maponce.go
index 150c6ea..875988f 100644
--- a/syncutil/maponce.go
+++ b/syncutil/maponce.go
@@ -5,6 +5,8 @@
package syncutil
import (
+ "context"
+
"git.lukeshu.com/go/containers/typedsync"
)
@@ -49,8 +51,8 @@ type Map[K mapkey, V any] interface {
// A MapOnceVal is a values that MapOnce stores in to its underlying
// Map.
type MapOnceVal[V any] struct {
- V V
- wg typedsync.WaitGroup
+ V V
+ c chan struct{}
}
// Delete removes the value for a key. If the value for that key is
@@ -71,17 +73,47 @@ func (m *MapOnce[K, V, M]) Delete(key K) {
func (m *MapOnce[K, V, M]) LoadOrDo(key K, do func(K) V) (actual V, loaded bool) {
_value, _ := m.pool.Get()
if _value == nil {
- _value = new(MapOnceVal[V])
+ _value = &MapOnceVal[V]{
+ c: make(chan struct{}),
+ }
}
- _value.wg.Add(1)
_actual, loaded := m.Inner.LoadOrStore(key, _value)
if loaded {
- *_value = MapOnceVal[V]{}
m.pool.Put(_value)
- _actual.wg.Wait()
+ <-_actual.c
} else {
_actual.V = do(key)
- _actual.wg.Done()
+ close(_actual.c)
}
return _actual.V, loaded
}
+
+// TryLoadOrDo is like LoadOrDo, but obeys context cancellation. If a
+// call is cancelled, the call to "do" continues running in a separate
+// goroutine, in case other LoadOrDo calls are waiting on it. If a
+// call is cancelled, the error from ctx.Err() is returned, otherwise
+// err is nil.
+func (m *MapOnce[K, V, M]) TryLoadOrDo(ctx context.Context, key K, do func(K) V) (actual V, loaded bool, err error) {
+ _value, _ := m.pool.Get()
+ if _value == nil {
+ _value = &MapOnceVal[V]{
+ c: make(chan struct{}),
+ }
+ }
+ _actual, loaded := m.Inner.LoadOrStore(key, _value)
+ if loaded {
+ m.pool.Put(_value)
+ } else {
+ go func() {
+ _actual.V = do(key)
+ close(_actual.c)
+ }()
+ }
+ select {
+ case <-ctx.Done():
+ var zero V
+ return zero, false, ctx.Err() //nolint:wrapcheck // We're too low level for that to be useful.
+ case <-_actual.c:
+ return _actual.V, loaded, nil
+ }
+}