From 1328996a17776d74b0f1604a428826b6a761dbe4 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Sun, 5 Feb 2023 14:25:59 -0700 Subject: syncutil: MapOnce: Add a TryLoadOrDo method --- syncutil/maponce.go | 46 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/syncutil/maponce.go b/syncutil/maponce.go index 150c6ea..875988f 100644 --- a/syncutil/maponce.go +++ b/syncutil/maponce.go @@ -5,6 +5,8 @@ package syncutil import ( + "context" + "git.lukeshu.com/go/containers/typedsync" ) @@ -49,8 +51,8 @@ type Map[K mapkey, V any] interface { // A MapOnceVal is a values that MapOnce stores in to its underlying // Map. type MapOnceVal[V any] struct { - V V - wg typedsync.WaitGroup + V V + c chan struct{} } // Delete removes the value for a key. If the value for that key is @@ -71,17 +73,47 @@ func (m *MapOnce[K, V, M]) Delete(key K) { func (m *MapOnce[K, V, M]) LoadOrDo(key K, do func(K) V) (actual V, loaded bool) { _value, _ := m.pool.Get() if _value == nil { - _value = new(MapOnceVal[V]) + _value = &MapOnceVal[V]{ + c: make(chan struct{}), + } } - _value.wg.Add(1) _actual, loaded := m.Inner.LoadOrStore(key, _value) if loaded { - *_value = MapOnceVal[V]{} m.pool.Put(_value) - _actual.wg.Wait() + <-_actual.c } else { _actual.V = do(key) - _actual.wg.Done() + close(_actual.c) } return _actual.V, loaded } + +// TryLoadOrDo is like LoadOrDo, but obeys context cancellation. If a +// call is cancelled, the call to "do" continues running in a separate +// goroutine, in case other LoadOrDo calls are waiting on it. If a +// call is cancelled, the error from ctx.Err() is returned, otherwise +// err is nil. +func (m *MapOnce[K, V, M]) TryLoadOrDo(ctx context.Context, key K, do func(K) V) (actual V, loaded bool, err error) { + _value, _ := m.pool.Get() + if _value == nil { + _value = &MapOnceVal[V]{ + c: make(chan struct{}), + } + } + _actual, loaded := m.Inner.LoadOrStore(key, _value) + if loaded { + m.pool.Put(_value) + } else { + go func() { + _actual.V = do(key) + close(_actual.c) + }() + } + select { + case <-ctx.Done(): + var zero V + return zero, false, ctx.Err() //nolint:wrapcheck // We're too low level for that to be useful. + case <-_actual.c: + return _actual.V, loaded, nil + } +} -- cgit v1.2.3