diff --git a/cmd/restic/cmd_backup.go b/cmd/restic/cmd_backup.go index 0b33f2263..67af4397a 100644 --- a/cmd/restic/cmd_backup.go +++ b/cmd/restic/cmd_backup.go @@ -307,8 +307,8 @@ func collectRejectByNameFuncs(opts BackupOptions, repo *repository.Repository, t return nil, err } - if valid, invalidPatterns := filter.ValidatePatterns(excludes); !valid { - return nil, errors.Fatalf("--exclude-file: invalid pattern(s) provided:\n%s", strings.Join(invalidPatterns, "\n")) + if err := filter.ValidatePatterns(excludes); err != nil { + return nil, errors.Fatalf("--exclude-file: %s", err) } opts.Excludes = append(opts.Excludes, excludes...) @@ -320,24 +320,24 @@ func collectRejectByNameFuncs(opts BackupOptions, repo *repository.Repository, t return nil, err } - if valid, invalidPatterns := filter.ValidatePatterns(excludes); !valid { - return nil, errors.Fatalf("--iexclude-file: invalid pattern(s) provided:\n%s", strings.Join(invalidPatterns, "\n")) + if err := filter.ValidatePatterns(excludes); err != nil { + return nil, errors.Fatalf("--iexclude-file: %s", err) } opts.InsensitiveExcludes = append(opts.InsensitiveExcludes, excludes...) } if len(opts.InsensitiveExcludes) > 0 { - if valid, invalidPatterns := filter.ValidatePatterns(opts.InsensitiveExcludes); !valid { - return nil, errors.Fatalf("--iexclude: invalid pattern(s) provided:\n%s", strings.Join(invalidPatterns, "\n")) + if err := filter.ValidatePatterns(opts.InsensitiveExcludes); err != nil { + return nil, errors.Fatalf("--iexclude: %s", err) } fs = append(fs, rejectByInsensitivePattern(opts.InsensitiveExcludes)) } if len(opts.Excludes) > 0 { - if valid, invalidPatterns := filter.ValidatePatterns(opts.Excludes); !valid { - return nil, errors.Fatalf("--exclude: invalid pattern(s) provided:\n%s", strings.Join(invalidPatterns, "\n")) + if err := filter.ValidatePatterns(opts.Excludes); err != nil { + return nil, errors.Fatalf("--exclude: %s", err) } fs = append(fs, rejectByPattern(opts.Excludes)) diff --git a/cmd/restic/cmd_restore.go b/cmd/restic/cmd_restore.go index addd36661..dc625188b 100644 --- a/cmd/restic/cmd_restore.go +++ b/cmd/restic/cmd_restore.go @@ -72,23 +72,23 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error { // Validate provided patterns if len(opts.Exclude) > 0 { - if valid, invalidPatterns := filter.ValidatePatterns(opts.Exclude); !valid { - return errors.Fatalf("--exclude: invalid pattern(s) provided:\n%s", strings.Join(invalidPatterns, "\n")) + if err := filter.ValidatePatterns(opts.Exclude); err != nil { + return errors.Fatalf("--exclude: %s", err) } } if len(opts.InsensitiveExclude) > 0 { - if valid, invalidPatterns := filter.ValidatePatterns(opts.InsensitiveExclude); !valid { - return errors.Fatalf("--iexclude: invalid pattern(s) provided:\n%s", strings.Join(invalidPatterns, "\n")) + if err := filter.ValidatePatterns(opts.InsensitiveExclude); err != nil { + return errors.Fatalf("--iexclude: %s", err) } } if len(opts.Include) > 0 { - if valid, invalidPatterns := filter.ValidatePatterns(opts.Include); !valid { - return errors.Fatalf("--include: invalid pattern(s) provided:\n%s", strings.Join(invalidPatterns, "\n")) + if err := filter.ValidatePatterns(opts.Include); err != nil { + return errors.Fatalf("--include: %s", err) } } if len(opts.InsensitiveInclude) > 0 { - if valid, invalidPatterns := filter.ValidatePatterns(opts.InsensitiveInclude); !valid { - return errors.Fatalf("--iinclude: invalid pattern(s) provided:\n%s", strings.Join(invalidPatterns, "\n")) + if err := filter.ValidatePatterns(opts.InsensitiveInclude); err != nil { + return errors.Fatalf("--iinclude: %s", err) } } diff --git a/internal/filter/filter.go b/internal/filter/filter.go index cfacb8cc5..473f1f4cb 100644 --- a/internal/filter/filter.go +++ b/internal/filter/filter.go @@ -220,10 +220,18 @@ func match(pattern Pattern, strs []string) (matched bool, err error) { return false, nil } +type InvalidPatternError struct { + InvalidPatterns []string +} + +func (e *InvalidPatternError) Error() string { + return "invalid pattern(s) provided:\n" + strings.Join(e.InvalidPatterns, "\n") +} + // ValidatePatterns validates a slice of patterns. // Returns true if all patterns are valid - false otherwise, along with the invalid patterns. -func ValidatePatterns(patterns []string) (allValid bool, invalidPatterns []string) { - invalidPatterns = make([]string, 0) +func ValidatePatterns(patterns []string) error { + invalidPatterns := make([]string, 0) for _, Pattern := range ParsePatterns(patterns) { // Validate all pattern parts @@ -238,7 +246,10 @@ func ValidatePatterns(patterns []string) (allValid bool, invalidPatterns []strin } } - return len(invalidPatterns) == 0, invalidPatterns + if len(invalidPatterns) > 0 { + return &InvalidPatternError{InvalidPatterns: invalidPatterns} + } + return nil } // ParsePatterns prepares a list of patterns for use with List. diff --git a/internal/filter/filter_patterns_test.go b/internal/filter/filter_patterns_test.go index 215471500..64599d698 100644 --- a/internal/filter/filter_patterns_test.go +++ b/internal/filter/filter_patterns_test.go @@ -8,7 +8,6 @@ package filter_test import ( - "strings" "testing" "github.com/restic/restic/internal/filter" @@ -18,11 +17,15 @@ import ( func TestValidPatterns(t *testing.T) { // Test invalid patterns are detected and returned t.Run("detect-invalid-patterns", func(t *testing.T) { - allValid, invalidPatterns := filter.ValidatePatterns([]string{"*.foo", "*[._]log[.-][0-9]", "!*[._]log[.-][0-9]"}) + err := filter.ValidatePatterns([]string{"*.foo", "*[._]log[.-][0-9]", "!*[._]log[.-][0-9]"}) - rtest.Assert(t, allValid == false, "Expected invalid patterns to be detected") + rtest.Assert(t, err != nil, "Expected invalid patterns to be detected") - rtest.Equals(t, invalidPatterns, []string{"*[._]log[.-][0-9]", "!*[._]log[.-][0-9]"}) + if ip, ok := err.(*filter.InvalidPatternError); ok { + rtest.Equals(t, ip.InvalidPatterns, []string{"*[._]log[.-][0-9]", "!*[._]log[.-][0-9]"}) + } else { + t.Errorf("wrong error type %v", err) + } }) // Test all patterns defined in matchTests are valid @@ -33,10 +36,10 @@ func TestValidPatterns(t *testing.T) { } t.Run("validate-patterns", func(t *testing.T) { - allValid, invalidPatterns := filter.ValidatePatterns(patterns) + err := filter.ValidatePatterns(patterns) - if !allValid { - t.Errorf("Found invalid pattern(s):\n%s", strings.Join(invalidPatterns, "\n")) + if err != nil { + t.Error(err) } }) @@ -48,10 +51,10 @@ func TestValidPatterns(t *testing.T) { } t.Run("validate-child-patterns", func(t *testing.T) { - allValid, invalidPatterns := filter.ValidatePatterns(childPatterns) + err := filter.ValidatePatterns(childPatterns) - if !allValid { - t.Errorf("Found invalid child pattern(s):\n%s", strings.Join(invalidPatterns, "\n")) + if err != nil { + t.Error(err) } }) }