summaryrefslogtreecommitdiff
path: root/lib/diskio/seq.go
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2022-07-16 14:53:11 -0600
committerLuke Shumaker <lukeshu@lukeshu.com>2022-07-16 16:11:23 -0600
commit5889f1fa2818f34025ca6e2feecb26928c6e6341 (patch)
treed1f11679288826bd1d02093f61c30fd918ee0bf3 /lib/diskio/seq.go
parent259aecbdc2e22836a6e75011549503795a22536f (diff)
Re-jigger the KMP search to be more independent of the underlying data types
Diffstat (limited to 'lib/diskio/seq.go')
-rw-r--r--lib/diskio/seq.go68
1 files changed, 68 insertions, 0 deletions
diff --git a/lib/diskio/seq.go b/lib/diskio/seq.go
new file mode 100644
index 0000000..3c5f4ae
--- /dev/null
+++ b/lib/diskio/seq.go
@@ -0,0 +1,68 @@
+// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package diskio
+
+import (
+ "fmt"
+ "io"
+)
+
+// interface /////////////////////////////////////////////////////////
+
+type Sequence[K ~int64, V any] interface {
+ // Get the value at 'pos' in the sequence. Positions start at
+ // 0 and increment naturally. Return an error that is io.EOF
+ // if 'pos' is past the end of the sequence'.
+ Get(pos K) (V, error)
+}
+
+// implementation: slice /////////////////////////////////////////////
+
+type SliceSequence[K ~int64, V any] []V
+
+var _ Sequence[assertAddr, byte] = SliceSequence[assertAddr, byte]([]byte(nil))
+
+func (s SliceSequence[K, V]) Get(i K) (V, error) {
+ if i >= K(len(s)) {
+ var v V
+ return v, io.EOF
+ }
+ return s[int(i)], nil
+}
+
+// implementation: string ////////////////////////////////////////////
+
+type StringSequence[K ~int64] string
+
+var _ Sequence[assertAddr, byte] = StringSequence[assertAddr]("")
+
+func (s StringSequence[K]) Get(i K) (byte, error) {
+ if i >= K(len(s)) {
+ return 0, io.EOF
+ }
+ return s[int(i)], nil
+}
+
+// implementation: io.ByteReader /////////////////////////////////////
+
+type ByteReaderSequence[K ~int64] struct {
+ R io.ByteReader
+ pos K
+}
+
+var _ Sequence[assertAddr, byte] = &ByteReaderSequence[assertAddr]{R: nil}
+
+func (s *ByteReaderSequence[K]) Get(i K) (byte, error) {
+ if i != s.pos {
+ return 0, fmt.Errorf("%T.Get(%v): can only call .Get(%v)",
+ s, i, s.pos)
+ }
+ chr, err := s.R.ReadByte()
+ if err != nil {
+ return chr, err
+ }
+ s.pos++
+ return chr, nil
+}