diff --git a/internal/restic/restorer_test.go b/internal/restic/restorer_test.go new file mode 100644 index 000000000..2f9d18998 --- /dev/null +++ b/internal/restic/restorer_test.go @@ -0,0 +1,312 @@ +package restic_test + +import ( + "bytes" + "context" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/restic/restic/internal/fs" + "github.com/restic/restic/internal/repository" + "github.com/restic/restic/internal/restic" + rtest "github.com/restic/restic/internal/test" +) + +type Node interface{} + +type Snapshot struct { + Nodes map[string]Node + treeID restic.ID +} + +type File struct { + Data string +} + +type Dir struct { + Nodes map[string]Node +} + +func saveFile(t testing.TB, repo restic.Repository, node File) restic.ID { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + id, err := repo.SaveBlob(ctx, restic.DataBlob, []byte(node.Data), restic.ID{}) + if err != nil { + t.Fatal(err) + } + + return id +} + +func saveDir(t testing.TB, repo restic.Repository, nodes map[string]Node) restic.ID { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tree := &restic.Tree{} + for name, n := range nodes { + var id restic.ID + switch node := n.(type) { + case File: + id = saveFile(t, repo, node) + tree.Insert(&restic.Node{ + Type: "file", + Mode: 0644, + Name: name, + UID: uint32(os.Getuid()), + GID: uint32(os.Getgid()), + Content: []restic.ID{id}, + }) + case Dir: + id = saveDir(t, repo, node.Nodes) + tree.Insert(&restic.Node{ + Type: "dir", + Mode: 0755, + Name: name, + UID: uint32(os.Getuid()), + GID: uint32(os.Getgid()), + Subtree: &id, + }) + default: + t.Fatalf("unknown node type %T", node) + } + } + + id, err := repo.SaveTree(ctx, tree) + if err != nil { + t.Fatal(err) + } + + return id +} + +func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot) (restic.Repository, restic.ID) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + treeID := saveDir(t, repo, snapshot.Nodes) + + err := repo.Flush() + if err != nil { + t.Fatal(err) + } + + err = repo.SaveIndex(ctx) + if err != nil { + t.Fatal(err) + } + + sn, err := restic.NewSnapshot([]string{"test"}, nil, "", time.Now()) + if err != nil { + t.Fatal(err) + } + + sn.Tree = &treeID + id, err := repo.SaveJSONUnpacked(ctx, restic.SnapshotFile, sn) + if err != nil { + t.Fatal(err) + } + + return repo, id +} + +// toSlash converts the OS specific path dir to a slash-separated path. +func toSlash(dir string) string { + data := strings.Split(dir, string(filepath.Separator)) + return strings.Join(data, "/") +} + +func TestRestorer(t *testing.T) { + var tests = []struct { + Snapshot + Files map[string]string + ErrorsMust map[string]string + ErrorsMay map[string]string + }{ + // valid test cases + { + Snapshot: Snapshot{ + Nodes: map[string]Node{ + "foo": File{"content: foo\n"}, + "dirtest": Dir{ + Nodes: map[string]Node{ + "file": File{"content: file\n"}, + }, + }, + }, + }, + Files: map[string]string{ + "foo": "content: foo\n", + "dirtest/file": "content: file\n", + }, + }, + { + Snapshot: Snapshot{ + Nodes: map[string]Node{ + "top": File{"toplevel file"}, + "dir": Dir{ + Nodes: map[string]Node{ + "file": File{"file in dir"}, + "subdir": Dir{ + Nodes: map[string]Node{ + "file": File{"file in subdir"}, + }, + }, + }, + }, + }, + }, + Files: map[string]string{ + "top": "toplevel file", + "dir/file": "file in dir", + "dir/subdir/file": "file in subdir", + }, + }, + + // test cases with invalid/constructed names + { + Snapshot: Snapshot{ + Nodes: map[string]Node{ + `..\test`: File{"foo\n"}, + `..\..\foo\..\bar\..\xx\test2`: File{"test2\n"}, + }, + }, + ErrorsMay: map[string]string{ + `/#..\test`: "node has invalid name", + `/#..\..\foo\..\bar\..\xx\test2`: "node has invalid name", + }, + }, + { + Snapshot: Snapshot{ + Nodes: map[string]Node{ + `../test`: File{"foo\n"}, + `../../foo/../bar/../xx/test2`: File{"test2\n"}, + }, + }, + ErrorsMay: map[string]string{ + `/#../test`: "node has invalid name", + `/#../../foo/../bar/../xx/test2`: "node has invalid name", + }, + }, + { + Snapshot: Snapshot{ + Nodes: map[string]Node{ + "top": File{"toplevel file"}, + "x": Dir{ + Nodes: map[string]Node{ + "file1": File{"file1"}, + "..": Dir{ + Nodes: map[string]Node{ + "file2": File{"file2"}, + "..": Dir{ + Nodes: map[string]Node{ + "file2": File{"file2"}, + }, + }, + }, + }, + }, + }, + }, + }, + Files: map[string]string{ + "top": "toplevel file", + }, + ErrorsMust: map[string]string{ + `/x#..`: "node has invalid name", + }, + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + repo, cleanup := repository.TestRepository(t) + defer cleanup() + _, id := saveSnapshot(t, repo, test.Snapshot) + t.Logf("snapshot saved as %v", id.Str()) + + res, err := restic.NewRestorer(repo, id) + if err != nil { + t.Fatal(err) + } + + tempdir, cleanup := rtest.TempDir(t) + defer cleanup() + + res.SelectFilter = func(item, dstpath string, node *restic.Node) (selectedForRestore bool, childMayBeSelected bool) { + t.Logf("restore %v to %v", item, dstpath) + if !fs.HasPathPrefix(tempdir, dstpath) { + t.Errorf("would restore %v to %v, which is not within the target dir %v", + item, dstpath, tempdir) + return false, false + } + return true, true + } + + errors := make(map[string]string) + res.Error = func(dir string, node *restic.Node, err error) error { + t.Logf("restore returned error for %q in dir %v: %v", node.Name, dir, err) + dir = toSlash(dir) + errors[dir+"#"+node.Name] = err.Error() + return nil + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err = res.RestoreTo(ctx, tempdir) + if err != nil { + t.Fatal(err) + } + + for filename, errorMessage := range test.ErrorsMust { + msg, ok := errors[filename] + if !ok { + t.Errorf("expected error for %v, found none", filename) + continue + } + + if msg != "" && msg != errorMessage { + t.Errorf("wrong error message for %v: got %q, want %q", + filename, msg, errorMessage) + } + + delete(errors, filename) + } + + for filename, errorMessage := range test.ErrorsMay { + msg, ok := errors[filename] + if !ok { + continue + } + + if msg != "" && msg != errorMessage { + t.Errorf("wrong error message for %v: got %q, want %q", + filename, msg, errorMessage) + } + + delete(errors, filename) + } + + for filename, err := range errors { + t.Errorf("unexpected error for %v found: %v", filename, err) + } + + for filename, content := range test.Files { + data, err := ioutil.ReadFile(filepath.Join(tempdir, filepath.FromSlash(filename))) + if err != nil { + t.Errorf("unable to read file %v: %v", filename, err) + continue + } + + if !bytes.Equal(data, []byte(content)) { + t.Errorf("file %v has wrong content: want %q, got %q", filename, content, data) + } + } + }) + } +}