diff options
author | Luke Shumaker <lukeshu@lukeshu.com> | 2022-07-17 11:54:49 -0600 |
---|---|---|
committer | Luke Shumaker <lukeshu@lukeshu.com> | 2022-07-17 11:54:49 -0600 |
commit | 4952fc1880bf0f4286b17cbfbe0c49a132d09ebc (patch) | |
tree | 38052a3f6119d70216d796bcce65e8f1fb984ab2 | |
parent | 987a4b238e047238bd83384c87b8317afdd45ad8 (diff) |
implement wildcards in the KMP IndexAll
-rw-r--r-- | lib/diskio/kmp.go | 53 | ||||
-rw-r--r-- | lib/diskio/kmp_test.go | 54 |
2 files changed, 97 insertions, 10 deletions
diff --git a/lib/diskio/kmp.go b/lib/diskio/kmp.go index da19e81..15537de 100644 --- a/lib/diskio/kmp.go +++ b/lib/diskio/kmp.go @@ -9,12 +9,42 @@ import ( "io" ) -func mustGet[K ~int64, V any](seq Sequence[K, V], i K) V { - val, err := seq.Get(i) - if err != nil { - panic(err) +var ErrWildcard = errors.New("wildcard") + +func kmpEq2[K ~int64, V comparable](aS Sequence[K, V], aI K, bS Sequence[K, V], bI K) bool { + aV, aErr := aS.Get(aI) + bV, bErr := bS.Get(bI) + if aErr != nil { + if errors.Is(aErr, ErrWildcard) { + aV = bV + aErr = nil + } else { + panic(aErr) + } + } + if bErr != nil { + if errors.Is(bErr, ErrWildcard) { + bV = aV + bErr = nil + } else { + panic(bErr) + } + } + if aErr != nil || bErr != nil { + return false } - return val + return aV == bV +} + +func kmpEq1[K ~int64, V comparable](aV V, bS Sequence[K, V], bI K) bool { + bV, bErr := bS.Get(bI) + if bErr != nil { + if errors.Is(bErr, ErrWildcard) { + return true + } + panic(bErr) + } + return aV == bV } // buildKMPTable takes the string 'substr', and returns a table such @@ -23,7 +53,7 @@ func mustGet[K ~int64, V any](seq Sequence[K, V], i K) V { func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) { var substrLen K for { - if _, err := substr.Get(substrLen); err != nil { + if _, err := substr.Get(substrLen); err != nil && !errors.Is(err, ErrWildcard) { if errors.Is(err, io.EOF) { break } @@ -41,11 +71,11 @@ func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) { } val := table[j-1] // not a match; go back - for val > 0 && mustGet(substr, j) != mustGet(substr, val) { + for val > 0 && !kmpEq2(substr, j, substr, val) { val = table[val-1] } // is a match; go forward - if mustGet(substr, val) == mustGet(substr, j) { + if kmpEq2(substr, val, substr, j) { val++ } table[j] = val @@ -64,6 +94,9 @@ func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) { // // Will panic if the length of 'substr' is 0. // +// The 'substr' may include wildcard characters by returning +// ErrWildcard for a position. +// // Uses the Knuth-Morris-Pratt algorithm. func IndexAll[K ~int64, V comparable](str, substr Sequence[K, V]) ([]K, error) { table, err := buildKMPTable(substr) @@ -89,12 +122,12 @@ func IndexAll[K ~int64, V comparable](str, substr Sequence[K, V]) ([]K, error) { } // Consider 'chr' - for curMatchLen > 0 && chr != mustGet(substr, curMatchLen) { // shorten the match + for curMatchLen > 0 && !kmpEq1(chr, substr, curMatchLen) { // shorten the match overlap := table[curMatchLen-1] curMatchBeg += curMatchLen - overlap curMatchLen = overlap } - if chr == mustGet(substr, curMatchLen) { // lengthen the match + if kmpEq1(chr, substr, curMatchLen) { // lengthen the match if curMatchLen == 0 { curMatchBeg = pos } diff --git a/lib/diskio/kmp_test.go b/lib/diskio/kmp_test.go index 51c7b5e..59b6224 100644 --- a/lib/diskio/kmp_test.go +++ b/lib/diskio/kmp_test.go @@ -6,6 +6,7 @@ package diskio import ( "bytes" + "io" "testing" "github.com/stretchr/testify/assert" @@ -65,3 +66,56 @@ func FuzzIndexAll(f *testing.F) { assert.Equal(t, exp, act) }) } + +type RESeq string + +func (re RESeq) Get(i int64) (byte, error) { + if i < 0 || i >= int64(len(re)) { + return 0, io.EOF + } + chr := re[int(i)] + if chr == '.' { + return 0, ErrWildcard + } + return chr, nil +} + +func TestKMPWildcard(t *testing.T) { + type testcase struct { + InStr string + InSubstr string + ExpMatches []int64 + } + testcases := map[string]testcase{ + "trivial-bar": { + InStr: "foo_bar", + InSubstr: "foo.ba.", + ExpMatches: []int64{0}, + }, + "trival-baz": { + InStr: "foo-baz", + InSubstr: "foo.ba.", + ExpMatches: []int64{0}, + }, + "suffix": { + InStr: "foobarbaz", + InSubstr: "...baz", + ExpMatches: []int64{3}, + }, + "overlap": { + InStr: "foobarbar", + InSubstr: "...bar", + ExpMatches: []int64{0, 3}, + }, + } + for tcName, tc := range testcases { + tc := tc + t.Run(tcName, func(t *testing.T) { + matches, err := IndexAll[int64, byte]( + StringSequence[int64](tc.InStr), + RESeq(tc.InSubstr)) + assert.NoError(t, err) + assert.Equal(t, tc.ExpMatches, matches) + }) + } +} |