summaryrefslogtreecommitdiff
path: root/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go
diff options
context:
space:
mode:
Diffstat (limited to 'lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go')
-rw-r--r--lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go149
1 files changed, 149 insertions, 0 deletions
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go
new file mode 100644
index 0000000..eeaab0c
--- /dev/null
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go
@@ -0,0 +1,149 @@
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package rebuildmappings
+
+import (
+ "errors"
+ "io"
+
+ "git.lukeshu.com/btrfs-progs-ng/lib/diskio"
+)
+
+var ErrWildcard = errors.New("wildcard")
+
+func kmpEq2[K ~int64, V comparable](aS diskio.Sequence[K, V], aI K, bS diskio.Sequence[K, V], bI K) bool {
+ aV, aErr := aS.Get(aI)
+ bV, bErr := bS.Get(bI)
+ if aErr != nil {
+ //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is.
+ if aErr == ErrWildcard || errors.Is(aErr, ErrWildcard) {
+ aV = bV
+ aErr = nil
+ } else {
+ panic(aErr)
+ }
+ }
+ if bErr != nil {
+ //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is.
+ if bErr == ErrWildcard || errors.Is(bErr, ErrWildcard) {
+ bV = aV
+ bErr = nil
+ } else {
+ panic(bErr)
+ }
+ }
+ if aErr != nil || bErr != nil {
+ return false
+ }
+ return aV == bV
+}
+
+func kmpEq1[K ~int64, V comparable](aV V, bS diskio.Sequence[K, V], bI K) bool {
+ bV, bErr := bS.Get(bI)
+ if bErr != nil {
+ //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is.
+ if bErr == ErrWildcard || errors.Is(bErr, ErrWildcard) {
+ return true
+ }
+ panic(bErr)
+ }
+ return aV == bV
+}
+
+// buildKMPTable takes the string 'substr', and returns a table such
+// that 'table[matchLen-1]' is the largest value 'val' for which 'val < matchLen' and
+// 'substr[:val] == substr[matchLen-val:matchLen]'.
+func buildKMPTable[K ~int64, V comparable](substr diskio.Sequence[K, V]) ([]K, error) {
+ var substrLen K
+ for {
+ //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is.
+ if _, err := substr.Get(substrLen); err != nil && !(err == ErrWildcard || errors.Is(err, ErrWildcard)) {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ return nil, err
+ }
+ substrLen++
+ }
+
+ table := make([]K, substrLen)
+ for j := K(0); j < substrLen; j++ {
+ if j == 0 {
+ // First entry must always be 0 (in order to
+ // satisfy 'val < matchLen').
+ continue
+ }
+ val := table[j-1]
+ // not a match; go back
+ for val > 0 && !kmpEq2(substr, j, substr, val) {
+ val = table[val-1]
+ }
+ // is a match; go forward
+ if kmpEq2(substr, val, substr, j) {
+ val++
+ }
+ table[j] = val
+ }
+ return table, nil
+}
+
+// IndexAll returns the starting-position of all possibly-overlapping
+// occurrences of 'substr' in the 'str' sequence.
+//
+// Will hop around in 'substr', but will only get the natural sequence
+// [0...) in order from 'str'. When hopping around in 'substr' it
+// assumes that once it has gotten a given index without error, it can
+// continue to do so without error; errors appearing later will cause
+// panics.
+//
+// Will panic if the length of 'substr' is 0.
+//
+// The 'substr' may include wildcard characters by returning
+// ErrWildcard for a position.
+//
+// Uses the Knuth-Morris-Pratt algorithm.
+func IndexAll[K ~int64, V comparable](str, substr diskio.Sequence[K, V]) ([]K, error) {
+ table, err := buildKMPTable(substr)
+ if err != nil {
+ return nil, err
+ }
+ substrLen := K(len(table))
+ if substrLen == 0 {
+ panic(errors.New("rebuildmappings.IndexAll: empty substring"))
+ }
+
+ var matches []K
+ var curMatchBeg K
+ var curMatchLen K
+
+ for pos := K(0); ; pos++ {
+ chr, err := str.Get(pos)
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ err = nil
+ }
+ return matches, err
+ }
+
+ // Consider 'chr'
+ for curMatchLen > 0 && !kmpEq1(chr, substr, curMatchLen) { // shorten the match
+ overlap := table[curMatchLen-1]
+ curMatchBeg += curMatchLen - overlap
+ curMatchLen = overlap
+ }
+ if kmpEq1(chr, substr, curMatchLen) { // lengthen the match
+ if curMatchLen == 0 {
+ curMatchBeg = pos
+ }
+ curMatchLen++
+ if curMatchLen == substrLen {
+ matches = append(matches, curMatchBeg)
+ overlap := table[curMatchLen-1]
+ curMatchBeg += curMatchLen - overlap
+ curMatchLen = overlap
+ }
+ }
+ }
+}