diff --git a/cmd/restic/cmd_rewrite.go b/cmd/restic/cmd_rewrite.go index 4019d9264..a60fdc8fc 100644 --- a/cmd/restic/cmd_rewrite.go +++ b/cmd/restic/cmd_rewrite.go @@ -87,12 +87,19 @@ func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *resti return true } + rewriter := walker.NewTreeRewriter(walker.RewriteOpts{ + RewriteNode: func(node *restic.Node, path string) *restic.Node { + if selectByName(path) { + return node + } + Verbosef(fmt.Sprintf("excluding %s\n", path)) + return nil + }, + }) + return filterAndReplaceSnapshot(ctx, repo, sn, func(ctx context.Context, sn *restic.Snapshot) (restic.ID, error) { - return walker.FilterTree(ctx, repo, "/", *sn.Tree, &walker.TreeFilterVisitor{ - SelectByName: selectByName, - PrintExclude: func(path string) { Verbosef(fmt.Sprintf("excluding %s\n", path)) }, - }) + return rewriter.RewriteTree(ctx, repo, "/", *sn.Tree) }, opts.DryRun, opts.Forget, "rewrite") } diff --git a/internal/walker/rewriter.go b/internal/walker/rewriter.go index bef3bd688..48f16a53a 100644 --- a/internal/walker/rewriter.go +++ b/internal/walker/rewriter.go @@ -9,13 +9,28 @@ import ( "github.com/restic/restic/internal/restic" ) -// SelectByNameFunc returns true for all items that should be included (files and -// dirs). If false is returned, files are ignored and dirs are not even walked. -type SelectByNameFunc func(item string) bool +type NodeRewriteFunc func(node *restic.Node, path string) *restic.Node -type TreeFilterVisitor struct { - SelectByName SelectByNameFunc - PrintExclude func(string) +type RewriteOpts struct { + // return nil to remove the node + RewriteNode NodeRewriteFunc +} + +type TreeRewriter struct { + opts RewriteOpts +} + +func NewTreeRewriter(opts RewriteOpts) *TreeRewriter { + rw := &TreeRewriter{ + opts: opts, + } + // setup default implementations + if rw.opts.RewriteNode == nil { + rw.opts.RewriteNode = func(node *restic.Node, path string) *restic.Node { + return node + } + } + return rw } type BlobLoadSaver interface { @@ -23,7 +38,7 @@ type BlobLoadSaver interface { restic.BlobLoader } -func FilterTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID restic.ID, visitor *TreeFilterVisitor) (newNodeID restic.ID, err error) { +func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID restic.ID) (newNodeID restic.ID, err error) { curTree, err := restic.LoadTree(ctx, repo, nodeID) if err != nil { return restic.ID{}, err @@ -45,10 +60,8 @@ func FilterTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID tb := restic.NewTreeJSONBuilder() for _, node := range curTree.Nodes { path := path.Join(nodepath, node.Name) - if !visitor.SelectByName(path) { - if visitor.PrintExclude != nil { - visitor.PrintExclude(path) - } + node = t.opts.RewriteNode(node, path) + if node == nil { continue } @@ -59,7 +72,7 @@ func FilterTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID } continue } - newID, err := FilterTree(ctx, repo, path, *node.Subtree, visitor) + newID, err := t.RewriteTree(ctx, repo, path, *node.Subtree) if err != nil { return restic.ID{}, err } diff --git a/internal/walker/rewriter_test.go b/internal/walker/rewriter_test.go index 3dcf0ac9e..8f99fe9bd 100644 --- a/internal/walker/rewriter_test.go +++ b/internal/walker/rewriter_test.go @@ -5,7 +5,6 @@ import ( "fmt" "testing" - "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "github.com/restic/restic/internal/restic" ) @@ -38,26 +37,26 @@ func (t WritableTreeMap) Dump() { } } -type checkRewriteFunc func(t testing.TB) (visitor TreeFilterVisitor, final func(testing.TB)) +type checkRewriteFunc func(t testing.TB) (rewriter *TreeRewriter, final func(testing.TB)) // checkRewriteItemOrder ensures that the order of the 'path' arguments is the one passed in as 'want'. func checkRewriteItemOrder(want []string) checkRewriteFunc { pos := 0 - return func(t testing.TB) (visitor TreeFilterVisitor, final func(testing.TB)) { - vis := TreeFilterVisitor{ - SelectByName: func(path string) bool { + return func(t testing.TB) (rewriter *TreeRewriter, final func(testing.TB)) { + rewriter = NewTreeRewriter(RewriteOpts{ + RewriteNode: func(node *restic.Node, path string) *restic.Node { if pos >= len(want) { t.Errorf("additional unexpected path found: %v", path) - return false + return nil } if path != want[pos] { t.Errorf("wrong path found, want %q, got %q", want[pos], path) } pos++ - return true + return node }, - } + }) final = func(t testing.TB) { if pos != len(want) { @@ -65,21 +64,20 @@ func checkRewriteItemOrder(want []string) checkRewriteFunc { } } - return vis, final + return rewriter, final } } -// checkRewriteSkips excludes nodes if path is in skipFor, it checks that all excluded entries are printed. +// checkRewriteSkips excludes nodes if path is in skipFor, it checks that rewriting proceedes in the correct order. func checkRewriteSkips(skipFor map[string]struct{}, want []string) checkRewriteFunc { var pos int - printed := make(map[string]struct{}) - return func(t testing.TB) (visitor TreeFilterVisitor, final func(testing.TB)) { - vis := TreeFilterVisitor{ - SelectByName: func(path string) bool { + return func(t testing.TB) (rewriter *TreeRewriter, final func(testing.TB)) { + rewriter = NewTreeRewriter(RewriteOpts{ + RewriteNode: func(node *restic.Node, path string) *restic.Node { if pos >= len(want) { t.Errorf("additional unexpected path found: %v", path) - return false + return nil } if path != want[pos] { @@ -87,27 +85,39 @@ func checkRewriteSkips(skipFor map[string]struct{}, want []string) checkRewriteF } pos++ - _, ok := skipFor[path] - return !ok - }, - PrintExclude: func(s string) { - if _, ok := printed[s]; ok { - t.Errorf("path was already printed %v", s) + _, skip := skipFor[path] + if skip { + return nil } - printed[s] = struct{}{} + return node }, - } + }) final = func(t testing.TB) { - if !cmp.Equal(skipFor, printed) { - t.Errorf("unexpected paths skipped: %s", cmp.Diff(skipFor, printed)) - } if pos != len(want) { t.Errorf("not enough items returned, want %d, got %d", len(want), pos) } } - return vis, final + return rewriter, final + } +} + +// checkIncreaseNodeSize modifies each node by changing its size. +func checkIncreaseNodeSize(increase uint64) checkRewriteFunc { + return func(t testing.TB) (rewriter *TreeRewriter, final func(testing.TB)) { + rewriter = NewTreeRewriter(RewriteOpts{ + RewriteNode: func(node *restic.Node, path string) *restic.Node { + if node.Type == "file" { + node.Size += increase + } + return node + }, + }) + + final = func(t testing.TB) {} + + return rewriter, final } } @@ -172,6 +182,21 @@ func TestRewriter(t *testing.T) { }, ), }, + { // modify node + tree: TestTree{ + "foo": TestFile{Size: 21}, + "subdir": TestTree{ + "subfile": TestFile{Size: 21}, + }, + }, + newTree: TestTree{ + "foo": TestFile{Size: 42}, + "subdir": TestTree{ + "subfile": TestFile{Size: 42}, + }, + }, + check: checkIncreaseNodeSize(21), + }, } for _, test := range tests { @@ -186,8 +211,8 @@ func TestRewriter(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - vis, last := test.check(t) - newRoot, err := FilterTree(ctx, modrepo, "/", root, &vis) + rewriter, last := test.check(t) + newRoot, err := rewriter.RewriteTree(ctx, modrepo, "/", root) if err != nil { t.Error(err) } @@ -213,8 +238,15 @@ func TestRewriterFailOnUnknownFields(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - // use nil visitor to crash if the tree loading works unexpectedly - _, err := FilterTree(ctx, tm, "/", id, nil) + + rewriter := NewTreeRewriter(RewriteOpts{ + RewriteNode: func(node *restic.Node, path string) *restic.Node { + // tree loading must not succeed + t.Fail() + return node + }, + }) + _, err := rewriter.RewriteTree(ctx, tm, "/", id) if err == nil { t.Error("missing error on unknown field") diff --git a/internal/walker/walker_test.go b/internal/walker/walker_test.go index 6c4fd3436..8de1a9dc4 100644 --- a/internal/walker/walker_test.go +++ b/internal/walker/walker_test.go @@ -14,7 +14,9 @@ import ( type TestTree map[string]interface{} // TestNode is used to test the walker. -type TestFile struct{} +type TestFile struct { + Size uint64 +} func BuildTreeMap(tree TestTree) (m TreeMap, root restic.ID) { m = TreeMap{} @@ -37,6 +39,7 @@ func buildTreeMap(tree TestTree, m TreeMap) restic.ID { err := tb.AddNode(&restic.Node{ Name: name, Type: "file", + Size: elem.Size, }) if err != nil { panic(err)