diff options
Diffstat (limited to 'lib/diskio/kmp.go')
-rw-r--r-- | lib/diskio/kmp.go | 53 |
1 files changed, 43 insertions, 10 deletions
diff --git a/lib/diskio/kmp.go b/lib/diskio/kmp.go index da19e81..15537de 100644 --- a/lib/diskio/kmp.go +++ b/lib/diskio/kmp.go @@ -9,12 +9,42 @@ import ( "io" ) -func mustGet[K ~int64, V any](seq Sequence[K, V], i K) V { - val, err := seq.Get(i) - if err != nil { - panic(err) +var ErrWildcard = errors.New("wildcard") + +func kmpEq2[K ~int64, V comparable](aS Sequence[K, V], aI K, bS Sequence[K, V], bI K) bool { + aV, aErr := aS.Get(aI) + bV, bErr := bS.Get(bI) + if aErr != nil { + if errors.Is(aErr, ErrWildcard) { + aV = bV + aErr = nil + } else { + panic(aErr) + } + } + if bErr != nil { + if errors.Is(bErr, ErrWildcard) { + bV = aV + bErr = nil + } else { + panic(bErr) + } + } + if aErr != nil || bErr != nil { + return false } - return val + return aV == bV +} + +func kmpEq1[K ~int64, V comparable](aV V, bS Sequence[K, V], bI K) bool { + bV, bErr := bS.Get(bI) + if bErr != nil { + if errors.Is(bErr, ErrWildcard) { + return true + } + panic(bErr) + } + return aV == bV } // buildKMPTable takes the string 'substr', and returns a table such @@ -23,7 +53,7 @@ func mustGet[K ~int64, V any](seq Sequence[K, V], i K) V { func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) { var substrLen K for { - if _, err := substr.Get(substrLen); err != nil { + if _, err := substr.Get(substrLen); err != nil && !errors.Is(err, ErrWildcard) { if errors.Is(err, io.EOF) { break } @@ -41,11 +71,11 @@ func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) { } val := table[j-1] // not a match; go back - for val > 0 && mustGet(substr, j) != mustGet(substr, val) { + for val > 0 && !kmpEq2(substr, j, substr, val) { val = table[val-1] } // is a match; go forward - if mustGet(substr, val) == mustGet(substr, j) { + if kmpEq2(substr, val, substr, j) { val++ } table[j] = val @@ -64,6 +94,9 @@ func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) { // // Will panic if the length of 'substr' is 0. // +// The 'substr' may include wildcard characters by returning +// ErrWildcard for a position. +// // Uses the Knuth-Morris-Pratt algorithm. func IndexAll[K ~int64, V comparable](str, substr Sequence[K, V]) ([]K, error) { table, err := buildKMPTable(substr) @@ -89,12 +122,12 @@ func IndexAll[K ~int64, V comparable](str, substr Sequence[K, V]) ([]K, error) { } // Consider 'chr' - for curMatchLen > 0 && chr != mustGet(substr, curMatchLen) { // shorten the match + for curMatchLen > 0 && !kmpEq1(chr, substr, curMatchLen) { // shorten the match overlap := table[curMatchLen-1] curMatchBeg += curMatchLen - overlap curMatchLen = overlap } - if chr == mustGet(substr, curMatchLen) { // lengthen the match + if kmpEq1(chr, substr, curMatchLen) { // lengthen the match if curMatchLen == 0 { curMatchBeg = pos } |