diff --git a/internal/restorer/restorer_test.go b/internal/restorer/restorer_test.go index 64e10a6e8..635c30ee6 100644 --- a/internal/restorer/restorer_test.go +++ b/internal/restorer/restorer_test.go @@ -893,3 +893,92 @@ func TestRestorerSparseFiles(t *testing.T) { t.Logf("wrote %d zeros as %d blocks, %.1f%% sparse", len(zeros), blocks, 100*sparsity) } + +func TestRestorerOverwriteBehavior(t *testing.T) { + baseTime := time.Now() + baseSnapshot := Snapshot{ + Nodes: map[string]Node{ + "foo": File{Data: "content: foo\n", ModTime: baseTime}, + "dirtest": Dir{ + Nodes: map[string]Node{ + "file": File{Data: "content: file\n", ModTime: baseTime}, + }, + ModTime: baseTime, + }, + }, + } + overwriteSnapshot := Snapshot{ + Nodes: map[string]Node{ + "foo": File{Data: "content: new\n", ModTime: baseTime.Add(time.Second)}, + "dirtest": Dir{ + Nodes: map[string]Node{ + "file": File{Data: "content: file2\n", ModTime: baseTime.Add(-time.Second)}, + }, + }, + }, + } + + var tests = []struct { + Overwrite OverwriteBehavior + Files map[string]string + }{ + { + Overwrite: OverwriteAlways, + Files: map[string]string{ + "foo": "content: new\n", + "dirtest/file": "content: file2\n", + }, + }, + { + Overwrite: OverwriteIfNewer, + Files: map[string]string{ + "foo": "content: new\n", + "dirtest/file": "content: file\n", + }, + }, + { + Overwrite: OverwriteNever, + Files: map[string]string{ + "foo": "content: foo\n", + "dirtest/file": "content: file\n", + }, + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + repo := repository.TestRepository(t) + tempdir := filepath.Join(rtest.TempDir(t), "target") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // base snapshot + sn, id := saveSnapshot(t, repo, baseSnapshot, noopGetGenericAttributes) + t.Logf("base snapshot saved as %v", id.Str()) + + res := NewRestorer(repo, sn, Options{}) + rtest.OK(t, res.RestoreTo(ctx, tempdir)) + + // overwrite snapshot + sn, id = saveSnapshot(t, repo, overwriteSnapshot, noopGetGenericAttributes) + t.Logf("overwrite snapshot saved as %v", id.Str()) + res = NewRestorer(repo, sn, Options{Overwrite: test.Overwrite}) + rtest.OK(t, res.RestoreTo(ctx, tempdir)) + + _, err := res.VerifyFiles(ctx, tempdir) + rtest.OK(t, err) + + for filename, content := range test.Files { + data, err := os.ReadFile(filepath.Join(tempdir, filepath.FromSlash(filename))) + if err != nil { + t.Errorf("unable to read file %v: %v", filename, err) + continue + } + + if !bytes.Equal(data, []byte(content)) { + t.Errorf("file %v has wrong content: want %q, got %q", filename, content, data) + } + } + }) + } +}