diff --git a/internal/walker/testing.go b/internal/walker/testing.go new file mode 100644 index 000000000..c06778242 --- /dev/null +++ b/internal/walker/testing.go @@ -0,0 +1 @@ +package walker diff --git a/internal/walker/walker.go b/internal/walker/walker.go new file mode 100644 index 000000000..df2573c57 --- /dev/null +++ b/internal/walker/walker.go @@ -0,0 +1,134 @@ +package walker + +import ( + "context" + "path" + "sort" + + "github.com/pkg/errors" + + "github.com/restic/restic/internal/restic" +) + +// TreeLoader loads a tree from a repository. +type TreeLoader interface { + LoadTree(context.Context, restic.ID) (*restic.Tree, error) +} + +// SkipNode is returned by WalkFunc when a dir node should not be walked. +var SkipNode = errors.New("skip this node") + +// WalkFunc is the type of the function called for each node visited by Walk. +// Path is the slash-separated path from the root node. If there was a problem +// loading a node, err is set to a non-nil error. WalkFunc can chose to ignore +// it by returning nil. +// +// When the special value SkipNode is returned and node is a dir node, it is +// not walked. When the node is not a dir node, the remaining items in this +// tree are skipped. +// +// Setting ignore to true tells Walk that it should not visit the node again. +// For tree nodes, this means that the function is not called for the +// referenced tree. If the node is not a tree, and all nodes in the current +// tree have ignore set to true, the current tree will not be visited again. +// When err is not nil and different from SkipNode, the value returned for +// ignore is ignored. +type WalkFunc func(path string, node *restic.Node, nodeErr error) (ignore bool, err error) + +// Walk calls walkFn recursively for each node in root. If walkFn returns an +// error, it is passed up the call stack. The trees in ignoreTrees are not +// walked. If walkFn ignores trees, these are added to the set. +func Walk(ctx context.Context, repo TreeLoader, root restic.ID, ignoreTrees restic.IDSet, walkFn WalkFunc) error { + tree, err := repo.LoadTree(ctx, root) + _, err = walkFn("/", nil, err) + + if err != nil { + if err == SkipNode { + err = nil + } + return err + } + + _, err = walk(ctx, repo, "/", tree, ignoreTrees, walkFn) + return err +} + +// walk recursively traverses the tree, ignoring subtrees when the ID of the +// subtree is in ignoreTrees. If err is nil and ignore is true, the subtree ID +// will be added to ignoreTrees by walk. +func walk(ctx context.Context, repo TreeLoader, prefix string, tree *restic.Tree, ignoreTrees restic.IDSet, walkFn WalkFunc) (ignore bool, err error) { + var allNodesIgnored = true + + sort.Slice(tree.Nodes, func(i, j int) bool { + return tree.Nodes[i].Name < tree.Nodes[j].Name + }) + + for _, node := range tree.Nodes { + p := path.Join(prefix, node.Name) + + if node.Type == "" { + return false, errors.Errorf("node type is empty for node %q", node.Name) + } + + if node.Type != "dir" { + ignore, err := walkFn(p, node, nil) + if err != nil { + if err == SkipNode { + // skip the remaining entries in this tree + return allNodesIgnored, nil + } + + return false, err + } + + if ignore == false { + allNodesIgnored = false + } + + continue + } + + if node.Subtree == nil { + return false, errors.Errorf("subtree for node %v in tree %v is nil", node.Name, p) + } + + if ignoreTrees.Has(*node.Subtree) { + continue + } + + subtree, err := repo.LoadTree(ctx, *node.Subtree) + ignore, err := walkFn(p, node, err) + if err != nil { + if err == SkipNode { + if ignore { + ignoreTrees.Insert(*node.Subtree) + } + continue + } + return false, err + } + + if ignore { + ignoreTrees.Insert(*node.Subtree) + } + + if !ignore { + allNodesIgnored = false + } + + ignore, err = walk(ctx, repo, p, subtree, ignoreTrees, walkFn) + if err != nil { + return false, err + } + + if ignore { + ignoreTrees.Insert(*node.Subtree) + } + + if !ignore { + allNodesIgnored = false + } + } + + return allNodesIgnored, nil +} diff --git a/internal/walker/walker_test.go b/internal/walker/walker_test.go new file mode 100644 index 000000000..08b4fe405 --- /dev/null +++ b/internal/walker/walker_test.go @@ -0,0 +1,423 @@ +package walker + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/pkg/errors" + "github.com/restic/restic/internal/restic" +) + +// TestTree is used to construct a list of trees for testing the walker. +type TestTree map[string]interface{} + +// TestNode is used to test the walker. +type TestFile struct{} + +func BuildTreeMap(tree TestTree) (m TreeMap, root restic.ID) { + m = TreeMap{} + id := buildTreeMap(tree, m) + return m, id +} + +func buildTreeMap(tree TestTree, m TreeMap) restic.ID { + res := restic.NewTree() + + for name, item := range tree { + switch elem := item.(type) { + case TestFile: + res.Insert(&restic.Node{ + Name: name, + Type: "file", + }) + case TestTree: + id := buildTreeMap(elem, m) + res.Insert(&restic.Node{ + Name: name, + Subtree: &id, + Type: "dir", + }) + default: + panic(fmt.Sprintf("invalid type %T", elem)) + } + } + + buf, err := json.Marshal(res) + if err != nil { + panic(err) + } + + id := restic.Hash(buf) + + if _, ok := m[id]; !ok { + m[id] = res + } + + return id +} + +// TreeMap returns the trees from the map on LoadTree. +type TreeMap map[restic.ID]*restic.Tree + +func (t TreeMap) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree, error) { + tree, ok := t[id] + if !ok { + return nil, errors.New("tree not found") + } + + return tree, nil +} + +// checkFunc returns a function suitable for walking the tree to check +// something, and a function which will check the final result. +type checkFunc func(t testing.TB) (walker WalkFunc, final func(testing.TB)) + +// checkItemOrder ensures that the order of the 'path' arguments is the one passed in as 'want'. +func checkItemOrder(want []string) checkFunc { + pos := 0 + return func(t testing.TB) (walker WalkFunc, final func(testing.TB)) { + walker = func(path string, node *restic.Node, err error) (bool, error) { + if err != nil { + t.Errorf("error walking %v: %v", path, err) + return false, err + } + + if pos >= len(want) { + t.Errorf("additional unexpected path found: %v", path) + return false, nil + } + + if path != want[pos] { + t.Errorf("wrong path found, want %q, got %q", want[pos], path) + } + pos++ + return false, nil + } + + final = func(t testing.TB) { + if pos != len(want) { + t.Errorf("not enough items returned, want %d, got %d", len(want), pos) + } + } + + return walker, final + } +} + +// checkSkipFor returns SkipNode if path is in skipFor, it checks that the +// paths the walk func is called for are exactly the ones in wantPaths. +func checkSkipFor(skipFor map[string]struct{}, wantPaths []string) checkFunc { + var pos int + + return func(t testing.TB) (walker WalkFunc, final func(testing.TB)) { + walker = func(path string, node *restic.Node, err error) (bool, error) { + if err != nil { + t.Errorf("error walking %v: %v", path, err) + return false, err + } + + if pos >= len(wantPaths) { + t.Errorf("additional unexpected path found: %v", path) + return false, nil + } + + if path != wantPaths[pos] { + t.Errorf("wrong path found, want %q, got %q", wantPaths[pos], path) + } + pos++ + + if _, ok := skipFor[path]; ok { + return false, SkipNode + } + + return false, nil + } + + final = func(t testing.TB) { + if pos != len(wantPaths) { + t.Errorf("wrong number of paths returned, want %d, got %d", len(wantPaths), pos) + } + } + + return walker, final + } +} + +// checkIgnore returns SkipNode if path is in skipFor and sets ignore according +// to ignoreFor. It checks that the paths the walk func is called for are exactly +// the ones in wantPaths. +func checkIgnore(skipFor map[string]struct{}, ignoreFor map[string]bool, wantPaths []string) checkFunc { + var pos int + + return func(t testing.TB) (walker WalkFunc, final func(testing.TB)) { + walker = func(path string, node *restic.Node, err error) (bool, error) { + if err != nil { + t.Errorf("error walking %v: %v", path, err) + return false, err + } + + if pos >= len(wantPaths) { + t.Errorf("additional unexpected path found: %v", path) + return ignoreFor[path], nil + } + + if path != wantPaths[pos] { + t.Errorf("wrong path found, want %q, got %q", wantPaths[pos], path) + } + pos++ + + if _, ok := skipFor[path]; ok { + return ignoreFor[path], SkipNode + } + + return ignoreFor[path], nil + } + + final = func(t testing.TB) { + if pos != len(wantPaths) { + t.Errorf("wrong number of paths returned, want %d, got %d", len(wantPaths), pos) + } + } + + return walker, final + } +} + +func TestWalker(t *testing.T) { + var tests = []struct { + tree TestTree + checks []checkFunc + }{ + { + tree: TestTree{ + "foo": TestFile{}, + "subdir": TestTree{ + "subfile": TestFile{}, + }, + }, + checks: []checkFunc{ + checkItemOrder([]string{ + "/", + "/foo", + "/subdir", + "/subdir/subfile", + }), + checkSkipFor( + map[string]struct{}{ + "/subdir": struct{}{}, + }, []string{ + "/", + "/foo", + "/subdir", + }, + ), + checkIgnore( + map[string]struct{}{}, map[string]bool{ + "/subdir": true, + }, []string{ + "/", + "/foo", + "/subdir", + "/subdir/subfile", + }, + ), + }, + }, + { + tree: TestTree{ + "foo": TestFile{}, + "subdir1": TestTree{ + "subfile1": TestFile{}, + }, + "subdir2": TestTree{ + "subfile2": TestFile{}, + "subsubdir2": TestTree{ + "subsubfile3": TestFile{}, + }, + }, + }, + checks: []checkFunc{ + checkItemOrder([]string{ + "/", + "/foo", + "/subdir1", + "/subdir1/subfile1", + "/subdir2", + "/subdir2/subfile2", + "/subdir2/subsubdir2", + "/subdir2/subsubdir2/subsubfile3", + }), + checkSkipFor( + map[string]struct{}{ + "/subdir1": struct{}{}, + }, []string{ + "/", + "/foo", + "/subdir1", + "/subdir2", + "/subdir2/subfile2", + "/subdir2/subsubdir2", + "/subdir2/subsubdir2/subsubfile3", + }, + ), + checkSkipFor( + map[string]struct{}{ + "/subdir1": struct{}{}, + "/subdir2/subsubdir2": struct{}{}, + }, []string{ + "/", + "/foo", + "/subdir1", + "/subdir2", + "/subdir2/subfile2", + "/subdir2/subsubdir2", + }, + ), + checkSkipFor( + map[string]struct{}{ + "/foo": struct{}{}, + }, []string{ + "/", + "/foo", + }, + ), + }, + }, + { + tree: TestTree{ + "foo": TestFile{}, + "subdir1": TestTree{ + "subfile1": TestFile{}, + "subfile2": TestFile{}, + "subfile3": TestFile{}, + }, + "subdir2": TestTree{ + "subfile1": TestFile{}, + "subfile2": TestFile{}, + "subfile3": TestFile{}, + }, + "subdir3": TestTree{ + "subfile1": TestFile{}, + "subfile2": TestFile{}, + "subfile3": TestFile{}, + }, + "zzz other": TestFile{}, + }, + checks: []checkFunc{ + checkItemOrder([]string{ + "/", + "/foo", + "/subdir1", + "/subdir1/subfile1", + "/subdir1/subfile2", + "/subdir1/subfile3", + "/subdir2", + "/subdir2/subfile1", + "/subdir2/subfile2", + "/subdir2/subfile3", + "/subdir3", + "/subdir3/subfile1", + "/subdir3/subfile2", + "/subdir3/subfile3", + "/zzz other", + }), + checkIgnore( + map[string]struct{}{ + "/subdir1": struct{}{}, + }, map[string]bool{ + "/subdir1": true, + }, []string{ + "/", + "/foo", + "/subdir1", + "/zzz other", + }, + ), + checkIgnore( + map[string]struct{}{}, map[string]bool{ + "/subdir1": true, + }, []string{ + "/", + "/foo", + "/subdir1", + "/subdir1/subfile1", + "/subdir1/subfile2", + "/subdir1/subfile3", + "/zzz other", + }, + ), + checkIgnore( + map[string]struct{}{ + "/subdir2": struct{}{}, + }, map[string]bool{ + "/subdir2": true, + }, []string{ + "/", + "/foo", + "/subdir1", + "/subdir1/subfile1", + "/subdir1/subfile2", + "/subdir1/subfile3", + "/subdir2", + "/zzz other", + }, + ), + checkIgnore( + map[string]struct{}{}, map[string]bool{ + "/subdir1/subfile1": true, + "/subdir1/subfile2": true, + "/subdir1/subfile3": true, + }, []string{ + "/", + "/foo", + "/subdir1", + "/subdir1/subfile1", + "/subdir1/subfile2", + "/subdir1/subfile3", + "/zzz other", + }, + ), + checkIgnore( + map[string]struct{}{}, map[string]bool{ + "/subdir2/subfile1": true, + "/subdir2/subfile2": true, + "/subdir2/subfile3": true, + }, []string{ + "/", + "/foo", + "/subdir1", + "/subdir1/subfile1", + "/subdir1/subfile2", + "/subdir1/subfile3", + "/subdir2", + "/subdir2/subfile1", + "/subdir2/subfile2", + "/subdir2/subfile3", + "/zzz other", + }, + ), + }, + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + repo, root := BuildTreeMap(test.tree) + for _, check := range test.checks { + t.Run("", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + fn, last := check(t) + err := Walk(ctx, repo, root, restic.NewIDSet(), fn) + if err != nil { + t.Error(err) + } + last(t) + }) + } + }) + } +}