diff --git a/cmd/restic/cmd_restore.go b/cmd/restic/cmd_restore.go index 618dfadf8..e51eee6cb 100644 --- a/cmd/restic/cmd_restore.go +++ b/cmd/restic/cmd_restore.go @@ -140,13 +140,15 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error { return nil } + excludePatterns := filter.ParsePatterns(opts.Exclude) + insensitiveExcludePatterns := filter.ParsePatterns(opts.InsensitiveExclude) selectExcludeFilter := func(item string, dstpath string, node *restic.Node) (selectedForRestore bool, childMayBeSelected bool) { - matched, _, err := filter.List(opts.Exclude, item) + matched, err := filter.List(excludePatterns, item) if err != nil { Warnf("error for exclude pattern: %v", err) } - matchedInsensitive, _, err := filter.List(opts.InsensitiveExclude, strings.ToLower(item)) + matchedInsensitive, err := filter.List(insensitiveExcludePatterns, strings.ToLower(item)) if err != nil { Warnf("error for iexclude pattern: %v", err) } @@ -161,13 +163,15 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error { return selectedForRestore, childMayBeSelected } + includePatterns := filter.ParsePatterns(opts.Include) + insensitiveIncludePatterns := filter.ParsePatterns(opts.InsensitiveInclude) selectIncludeFilter := func(item string, dstpath string, node *restic.Node) (selectedForRestore bool, childMayBeSelected bool) { - matched, childMayMatch, err := filter.List(opts.Include, item) + matched, childMayMatch, err := filter.ListWithChild(includePatterns, item) if err != nil { Warnf("error for include pattern: %v", err) } - matchedInsensitive, childMayMatchInsensitive, err := filter.List(opts.InsensitiveInclude, strings.ToLower(item)) + matchedInsensitive, childMayMatchInsensitive, err := filter.ListWithChild(insensitiveIncludePatterns, strings.ToLower(item)) if err != nil { Warnf("error for iexclude pattern: %v", err) } diff --git a/cmd/restic/exclude.go b/cmd/restic/exclude.go index 8d5585cfc..db603c04e 100644 --- a/cmd/restic/exclude.go +++ b/cmd/restic/exclude.go @@ -74,8 +74,9 @@ type RejectFunc func(path string, fi os.FileInfo) bool // rejectByPattern returns a RejectByNameFunc which rejects files that match // one of the patterns. func rejectByPattern(patterns []string) RejectByNameFunc { + parsedPatterns := filter.ParsePatterns(patterns) return func(item string) bool { - matched, _, err := filter.List(patterns, item) + matched, err := filter.List(parsedPatterns, item) if err != nil { Warnf("error for exclude pattern: %v", err) } diff --git a/internal/filter/filter.go b/internal/filter/filter.go index 74deddb03..1f6f04133 100644 --- a/internal/filter/filter.go +++ b/internal/filter/filter.go @@ -11,6 +11,47 @@ import ( // second argument. var ErrBadString = errors.New("filter.Match: string is empty") +type patternPart struct { + pattern string // First is "/" for absolute pattern; "" for "**". + isSimple bool +} + +// Pattern represents a preparsed filter pattern +type Pattern []patternPart + +func prepareStr(str string) ([]string, error) { + if str == "" { + return nil, ErrBadString + } + return splitPath(str), nil +} + +func preparePattern(pattern string) Pattern { + parts := splitPath(filepath.Clean(pattern)) + patterns := make([]patternPart, len(parts)) + for i, part := range parts { + isSimple := !strings.ContainsAny(part, "\\[]*?") + // Replace "**" with the empty string to get faster comparisons + // (length-check only) in hasDoubleWildcard. + if part == "**" { + part = "" + } + patterns[i] = patternPart{part, isSimple} + } + + return patterns +} + +// Split p into path components. Assuming p has been Cleaned, no component +// will be empty. For absolute paths, the first component is "/". +func splitPath(p string) []string { + parts := strings.Split(filepath.ToSlash(p), "/") + if parts[0] == "" { + parts[0] = "/" + } + return parts +} + // Match returns true if str matches the pattern. When the pattern is // malformed, filepath.ErrBadPattern is returned. The empty pattern matches // everything, when str is the empty string ErrBadString is returned. @@ -26,21 +67,13 @@ func Match(pattern, str string) (matched bool, err error) { return true, nil } - pattern = filepath.Clean(pattern) + patterns := preparePattern(pattern) + strs, err := prepareStr(str) - if str == "" { - return false, ErrBadString + if err != nil { + return false, err } - // convert file path separator to '/' - if filepath.Separator != '/' { - pattern = strings.Replace(pattern, string(filepath.Separator), "/", -1) - str = strings.Replace(str, string(filepath.Separator), "/", -1) - } - - patterns := strings.Split(pattern, "/") - strs := strings.Split(str, "/") - return match(patterns, strs) } @@ -59,26 +92,18 @@ func ChildMatch(pattern, str string) (matched bool, err error) { return true, nil } - pattern = filepath.Clean(pattern) + patterns := preparePattern(pattern) + strs, err := prepareStr(str) - if str == "" { - return false, ErrBadString + if err != nil { + return false, err } - // convert file path separator to '/' - if filepath.Separator != '/' { - pattern = strings.Replace(pattern, string(filepath.Separator), "/", -1) - str = strings.Replace(str, string(filepath.Separator), "/", -1) - } - - patterns := strings.Split(pattern, "/") - strs := strings.Split(str, "/") - return childMatch(patterns, strs) } -func childMatch(patterns, strs []string) (matched bool, err error) { - if patterns[0] != "" { +func childMatch(patterns Pattern, strs []string) (matched bool, err error) { + if patterns[0].pattern != "/" { // relative pattern can always be nested down return true, nil } @@ -99,9 +124,9 @@ func childMatch(patterns, strs []string) (matched bool, err error) { return match(patterns[0:l], strs) } -func hasDoubleWildcard(list []string) (ok bool, pos int) { +func hasDoubleWildcard(list Pattern) (ok bool, pos int) { for i, item := range list { - if item == "**" { + if item.pattern == "" { return true, i } } @@ -109,14 +134,18 @@ func hasDoubleWildcard(list []string) (ok bool, pos int) { return false, 0 } -func match(patterns, strs []string) (matched bool, err error) { +func match(patterns Pattern, strs []string) (matched bool, err error) { if ok, pos := hasDoubleWildcard(patterns); ok { // gradually expand '**' into separate wildcards + newPat := make(Pattern, len(strs)) + // copy static prefix once + copy(newPat, patterns[:pos]) for i := 0; i <= len(strs)-len(patterns)+1; i++ { - newPat := make([]string, pos) - copy(newPat, patterns[:pos]) - for k := 0; k < i; k++ { - newPat = append(newPat, "*") + // limit to static prefix and already appended '*' + newPat := newPat[:pos+i] + // in the first iteration the wildcard expands to nothing + if i > 0 { + newPat[pos+i-1] = patternPart{"*", false} } newPat = append(newPat, patterns[pos+1:]...) @@ -138,13 +167,27 @@ func match(patterns, strs []string) (matched bool, err error) { } if len(patterns) <= len(strs) { + minOffset := 0 + maxOffset := len(strs) - len(patterns) + // special case absolute patterns + if patterns[0].pattern == "/" { + maxOffset = 0 + } else if strs[0] == "/" { + // skip absolute path marker if pattern is not rooted + minOffset = 1 + } outer: - for offset := len(strs) - len(patterns); offset >= 0; offset-- { + for offset := maxOffset; offset >= minOffset; offset-- { for i := len(patterns) - 1; i >= 0; i-- { - ok, err := filepath.Match(patterns[i], strs[offset+i]) - if err != nil { - return false, errors.Wrap(err, "Match") + var ok bool + if patterns[i].isSimple { + ok = patterns[i].pattern == strs[offset+i] + } else { + ok, err = filepath.Match(patterns[i].pattern, strs[offset+i]) + if err != nil { + return false, errors.Wrap(err, "Match") + } } if !ok { @@ -159,22 +202,55 @@ func match(patterns, strs []string) (matched bool, err error) { return false, nil } -// List returns true if str matches one of the patterns. Empty patterns are -// ignored. -func List(patterns []string, str string) (matched bool, childMayMatch bool, err error) { +// ParsePatterns prepares a list of patterns for use with List. +func ParsePatterns(patterns []string) []Pattern { + patpat := make([]Pattern, 0) for _, pat := range patterns { if pat == "" { continue } - m, err := Match(pat, str) + pats := preparePattern(pat) + patpat = append(patpat, pats) + } + return patpat +} + +// List returns true if str matches one of the patterns. Empty patterns are ignored. +func List(patterns []Pattern, str string) (matched bool, err error) { + matched, _, err = list(patterns, false, str) + return matched, err +} + +// ListWithChild returns true if str matches one of the patterns. Empty patterns are ignored. +func ListWithChild(patterns []Pattern, str string) (matched bool, childMayMatch bool, err error) { + return list(patterns, true, str) +} + +// List returns true if str matches one of the patterns. Empty patterns are ignored. +func list(patterns []Pattern, checkChildMatches bool, str string) (matched bool, childMayMatch bool, err error) { + if len(patterns) == 0 { + return false, false, nil + } + + strs, err := prepareStr(str) + if err != nil { + return false, false, err + } + for _, pat := range patterns { + m, err := match(pat, strs) if err != nil { return false, false, err } - c, err := ChildMatch(pat, str) - if err != nil { - return false, false, err + var c bool + if checkChildMatches { + c, err = childMatch(pat, strs) + if err != nil { + return false, false, err + } + } else { + c = true } matched = matched || m diff --git a/internal/filter/filter_test.go b/internal/filter/filter_test.go index 97df452fb..ee7f05a32 100644 --- a/internal/filter/filter_test.go +++ b/internal/filter/filter_test.go @@ -240,25 +240,28 @@ func ExampleMatch_wildcards() { } var filterListTests = []struct { - patterns []string - path string - match bool + patterns []string + path string + match bool + childMatch bool }{ - {[]string{"*.go"}, "/foo/bar/test.go", true}, - {[]string{"*.c"}, "/foo/bar/test.go", false}, - {[]string{"*.go", "*.c"}, "/foo/bar/test.go", true}, - {[]string{"*"}, "/foo/bar/test.go", true}, - {[]string{"x"}, "/foo/bar/test.go", false}, - {[]string{"?"}, "/foo/bar/test.go", false}, - {[]string{"?", "x"}, "/foo/bar/x", true}, - {[]string{"/*/*/bar/test.*"}, "/foo/bar/test.go", false}, - {[]string{"/*/*/bar/test.*", "*.go"}, "/foo/bar/test.go", true}, - {[]string{"", "*.c"}, "/foo/bar/test.go", false}, + {[]string{}, "/foo/bar/test.go", false, false}, + {[]string{"*.go"}, "/foo/bar/test.go", true, true}, + {[]string{"*.c"}, "/foo/bar/test.go", false, true}, + {[]string{"*.go", "*.c"}, "/foo/bar/test.go", true, true}, + {[]string{"*"}, "/foo/bar/test.go", true, true}, + {[]string{"x"}, "/foo/bar/test.go", false, true}, + {[]string{"?"}, "/foo/bar/test.go", false, true}, + {[]string{"?", "x"}, "/foo/bar/x", true, true}, + {[]string{"/*/*/bar/test.*"}, "/foo/bar/test.go", false, false}, + {[]string{"/*/*/bar/test.*", "*.go"}, "/foo/bar/test.go", true, true}, + {[]string{"", "*.c"}, "/foo/bar/test.go", false, true}, } func TestList(t *testing.T) { for i, test := range filterListTests { - match, _, err := filter.List(test.patterns, test.path) + patterns := filter.ParsePatterns(test.patterns) + match, err := filter.List(patterns, test.path) if err != nil { t.Errorf("test %d failed: expected no error for patterns %q, but error returned: %v", i, test.patterns, err) @@ -266,19 +269,64 @@ func TestList(t *testing.T) { } if match != test.match { - t.Errorf("test %d: filter.MatchList(%q, %q): expected %v, got %v", + t.Errorf("test %d: filter.List(%q, %q): expected %v, got %v", i, test.patterns, test.path, test.match, match) } + + match, childMatch, err := filter.ListWithChild(patterns, test.path) + if err != nil { + t.Errorf("test %d failed: expected no error for patterns %q, but error returned: %v", + i, test.patterns, err) + continue + } + + if match != test.match || childMatch != test.childMatch { + t.Errorf("test %d: filter.ListWithChild(%q, %q): expected %v, %v, got %v, %v", + i, test.patterns, test.path, test.match, test.childMatch, match, childMatch) + } } } func ExampleList() { - match, _, _ := filter.List([]string{"*.c", "*.go"}, "/home/user/file.go") + patterns := filter.ParsePatterns([]string{"*.c", "*.go"}) + match, _ := filter.List(patterns, "/home/user/file.go") fmt.Printf("match: %v\n", match) // Output: // match: true } +func TestInvalidStrs(t *testing.T) { + _, err := filter.Match("test", "") + if err == nil { + t.Error("Match accepted invalid path") + } + + _, err = filter.ChildMatch("test", "") + if err == nil { + t.Error("ChildMatch accepted invalid path") + } + + patterns := []string{"test"} + _, err = filter.List(filter.ParsePatterns(patterns), "") + if err == nil { + t.Error("List accepted invalid path") + } +} + +func TestInvalidPattern(t *testing.T) { + patterns := []string{"test/["} + _, err := filter.List(filter.ParsePatterns(patterns), "test/example") + if err == nil { + t.Error("List accepted invalid pattern") + } + + patterns = []string{"test/**/["} + _, err = filter.List(filter.ParsePatterns(patterns), "test/example") + if err == nil { + t.Error("List accepted invalid pattern") + } +} + func extractTestLines(t testing.TB) (lines []string) { f, err := os.Open("testdata/libreoffice.txt.bz2") if err != nil { @@ -360,30 +408,60 @@ func BenchmarkFilterLines(b *testing.B) { } func BenchmarkFilterPatterns(b *testing.B) { - patterns := []string{ - "sdk/*", - "*.html", - } lines := extractTestLines(b) - var c uint - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - c = 0 - for _, line := range lines { - match, _, err := filter.List(patterns, line) - if err != nil { - b.Fatal(err) - } - - if match { - c++ - } + modlines := make([]string, 200) + for i, line := range lines { + if i >= len(modlines) { + break } + modlines[i] = line + "-does-not-match" + } + tests := []struct { + name string + patterns []filter.Pattern + matches uint + }{ + {"Relative", filter.ParsePatterns([]string{ + "does-not-match", + "sdk/*", + "*.html", + }), 22185}, + {"Absolute", filter.ParsePatterns([]string{ + "/etc", + "/home/*/test", + "/usr/share/doc/libreoffice/sdk/docs/java", + }), 150}, + {"Wildcard", filter.ParsePatterns([]string{ + "/etc/**/example", + "/home/**/test", + "/usr/**/java", + }), 150}, + {"ManyNoMatch", filter.ParsePatterns(modlines), 0}, + } - if c != 22185 { - b.Fatalf("wrong number of matches: expected 22185, got %d", c) - } + for _, test := range tests { + b.Run(test.name, func(b *testing.B) { + var c uint + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c = 0 + for _, line := range lines { + match, err := filter.List(test.patterns, line) + if err != nil { + b.Fatal(err) + } + + if match { + c++ + } + } + + if c != test.matches { + b.Fatalf("wrong number of matches: expected %d, got %d", test.matches, c) + } + } + }) } }