summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2022-07-13 23:59:04 -0600
committerLuke Shumaker <lukeshu@lukeshu.com>2022-07-14 00:50:21 -0600
commitb8c5940165399f9dc404c912aa455822347bb367 (patch)
tree1a4504f16c321ea2b460d9fe0b0b900c9db1a822
parente784de8a66c3645fdd3a54939b5b844f3bacd82d (diff)
diskio.FindAll: Have the return type be a type parameter
-rw-r--r--lib/diskio/kmp.go12
-rw-r--r--lib/diskio/kmp_test.go2
2 files changed, 7 insertions, 7 deletions
diff --git a/lib/diskio/kmp.go b/lib/diskio/kmp.go
index 4c0f531..69a3a51 100644
--- a/lib/diskio/kmp.go
+++ b/lib/diskio/kmp.go
@@ -40,17 +40,17 @@ func buildKMPTable(substr []byte) []int {
// Will panic if len(substr)==0.
//
// Uses the Knuth-Morris-Pratt algorithm.
-func FindAll(r io.ByteReader, substr []byte) ([]int64, error) {
+func FindAll[A ~int64](r io.ByteReader, substr []byte) ([]A, error) {
if len(substr) == 0 {
panic(errors.New("diskio.FindAll: empty substring"))
}
table := buildKMPTable(substr)
- var matches []int64
- var curMatchBeg int64
+ var matches []A
+ var curMatchBeg A
var curMatchLen int
- pos := int64(-1) // if 'r' were a slice; define 'pos' such that 'chr=r[pos]'
+ pos := A(-1) // if 'r' were a slice; define 'pos' such that 'chr=r[pos]'
for {
// I/O
var chr byte
@@ -66,7 +66,7 @@ func FindAll(r io.ByteReader, substr []byte) ([]int64, error) {
// Consider 'chr'
for curMatchLen > 0 && chr != substr[curMatchLen] { // shorten the match
overlap := table[curMatchLen-1]
- curMatchBeg += int64(curMatchLen - overlap)
+ curMatchBeg += A(curMatchLen - overlap)
curMatchLen = overlap
}
if chr == substr[curMatchLen] { // lengthen the match
@@ -77,7 +77,7 @@ func FindAll(r io.ByteReader, substr []byte) ([]int64, error) {
if curMatchLen == len(substr) {
matches = append(matches, curMatchBeg)
overlap := table[curMatchLen-1]
- curMatchBeg += int64(curMatchLen - overlap)
+ curMatchBeg += A(curMatchLen - overlap)
curMatchLen = overlap
}
}
diff --git a/lib/diskio/kmp_test.go b/lib/diskio/kmp_test.go
index 6c6ef78..836605a 100644
--- a/lib/diskio/kmp_test.go
+++ b/lib/diskio/kmp_test.go
@@ -55,7 +55,7 @@ func FuzzFindAll(f *testing.F) {
t.Logf("str =%q", str)
t.Logf("substr=%q", substr)
exp := NaiveFindAll(str, substr)
- act, err := FindAll(bytes.NewReader(str), substr)
+ act, err := FindAll[int64](bytes.NewReader(str), substr)
assert.NoError(t, err)
assert.Equal(t, exp, act)
})