diff options
| -rw-r--r-- | lib/diskio/kmp.go | 53 | ||||
| -rw-r--r-- | lib/diskio/kmp_test.go | 54 | 
2 files changed, 97 insertions, 10 deletions
| diff --git a/lib/diskio/kmp.go b/lib/diskio/kmp.go index da19e81..15537de 100644 --- a/lib/diskio/kmp.go +++ b/lib/diskio/kmp.go @@ -9,12 +9,42 @@ import (  	"io"  ) -func mustGet[K ~int64, V any](seq Sequence[K, V], i K) V { -	val, err := seq.Get(i) -	if err != nil { -		panic(err) +var ErrWildcard = errors.New("wildcard") + +func kmpEq2[K ~int64, V comparable](aS Sequence[K, V], aI K, bS Sequence[K, V], bI K) bool { +	aV, aErr := aS.Get(aI) +	bV, bErr := bS.Get(bI) +	if aErr != nil { +		if errors.Is(aErr, ErrWildcard) { +			aV = bV +			aErr = nil +		} else { +			panic(aErr) +		} +	} +	if bErr != nil { +		if errors.Is(bErr, ErrWildcard) { +			bV = aV +			bErr = nil +		} else { +			panic(bErr) +		} +	} +	if aErr != nil || bErr != nil { +		return false  	} -	return val +	return aV == bV +} + +func kmpEq1[K ~int64, V comparable](aV V, bS Sequence[K, V], bI K) bool { +	bV, bErr := bS.Get(bI) +	if bErr != nil { +		if errors.Is(bErr, ErrWildcard) { +			return true +		} +		panic(bErr) +	} +	return aV == bV  }  // buildKMPTable takes the string 'substr', and returns a table such @@ -23,7 +53,7 @@ func mustGet[K ~int64, V any](seq Sequence[K, V], i K) V {  func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) {  	var substrLen K  	for { -		if _, err := substr.Get(substrLen); err != nil { +		if _, err := substr.Get(substrLen); err != nil && !errors.Is(err, ErrWildcard) {  			if errors.Is(err, io.EOF) {  				break  			} @@ -41,11 +71,11 @@ func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) {  		}  		val := table[j-1]  		// not a match; go back -		for val > 0 && mustGet(substr, j) != mustGet(substr, val) { +		for val > 0 && !kmpEq2(substr, j, substr, val) {  			val = table[val-1]  		}  		// is a match; go forward -		if mustGet(substr, val) == mustGet(substr, j) { +		if kmpEq2(substr, val, substr, j) {  			val++  		}  		table[j] = val @@ -64,6 +94,9 @@ func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) {  //  // Will panic if the length of 'substr' is 0.  // +// The 'substr' may include wildcard characters by returning +// ErrWildcard for a position. +//  // Uses the Knuth-Morris-Pratt algorithm.  func IndexAll[K ~int64, V comparable](str, substr Sequence[K, V]) ([]K, error) {  	table, err := buildKMPTable(substr) @@ -89,12 +122,12 @@ func IndexAll[K ~int64, V comparable](str, substr Sequence[K, V]) ([]K, error) {  		}  		// Consider 'chr' -		for curMatchLen > 0 && chr != mustGet(substr, curMatchLen) { // shorten the match +		for curMatchLen > 0 && !kmpEq1(chr, substr, curMatchLen) { // shorten the match  			overlap := table[curMatchLen-1]  			curMatchBeg += curMatchLen - overlap  			curMatchLen = overlap  		} -		if chr == mustGet(substr, curMatchLen) { // lengthen the match +		if kmpEq1(chr, substr, curMatchLen) { // lengthen the match  			if curMatchLen == 0 {  				curMatchBeg = pos  			} diff --git a/lib/diskio/kmp_test.go b/lib/diskio/kmp_test.go index 51c7b5e..59b6224 100644 --- a/lib/diskio/kmp_test.go +++ b/lib/diskio/kmp_test.go @@ -6,6 +6,7 @@ package diskio  import (  	"bytes" +	"io"  	"testing"  	"github.com/stretchr/testify/assert" @@ -65,3 +66,56 @@ func FuzzIndexAll(f *testing.F) {  		assert.Equal(t, exp, act)  	})  } + +type RESeq string + +func (re RESeq) Get(i int64) (byte, error) { +	if i < 0 || i >= int64(len(re)) { +		return 0, io.EOF +	} +	chr := re[int(i)] +	if chr == '.' { +		return 0, ErrWildcard +	} +	return chr, nil +} + +func TestKMPWildcard(t *testing.T) { +	type testcase struct { +		InStr      string +		InSubstr   string +		ExpMatches []int64 +	} +	testcases := map[string]testcase{ +		"trivial-bar": { +			InStr:      "foo_bar", +			InSubstr:   "foo.ba.", +			ExpMatches: []int64{0}, +		}, +		"trival-baz": { +			InStr:      "foo-baz", +			InSubstr:   "foo.ba.", +			ExpMatches: []int64{0}, +		}, +		"suffix": { +			InStr:      "foobarbaz", +			InSubstr:   "...baz", +			ExpMatches: []int64{3}, +		}, +		"overlap": { +			InStr:      "foobarbar", +			InSubstr:   "...bar", +			ExpMatches: []int64{0, 3}, +		}, +	} +	for tcName, tc := range testcases { +		tc := tc +		t.Run(tcName, func(t *testing.T) { +			matches, err := IndexAll[int64, byte]( +				StringSequence[int64](tc.InStr), +				RESeq(tc.InSubstr)) +			assert.NoError(t, err) +			assert.Equal(t, tc.ExpMatches, matches) +		}) +	} +} | 
