diff --git a/internal/restorer/filerestorer.go b/internal/restorer/filerestorer.go index fd5b3c5db..56059cb16 100644 --- a/internal/restorer/filerestorer.go +++ b/internal/restorer/filerestorer.go @@ -53,6 +53,8 @@ type fileRestorer struct { sparse bool progress *restore.Progress + allowRecursiveDelete bool + dst string files []*fileInfo Error func(string, error) error @@ -63,21 +65,23 @@ func newFileRestorer(dst string, idx func(restic.BlobType, restic.ID) []restic.PackedBlob, connections uint, sparse bool, + allowRecursiveDelete bool, progress *restore.Progress) *fileRestorer { // as packs are streamed the concurrency is limited by IO workerCount := int(connections) return &fileRestorer{ - idx: idx, - blobsLoader: blobsLoader, - filesWriter: newFilesWriter(workerCount), - zeroChunk: repository.ZeroChunk(), - sparse: sparse, - progress: progress, - workerCount: workerCount, - dst: dst, - Error: restorerAbortOnAllErrors, + idx: idx, + blobsLoader: blobsLoader, + filesWriter: newFilesWriter(workerCount, allowRecursiveDelete), + zeroChunk: repository.ZeroChunk(), + sparse: sparse, + progress: progress, + allowRecursiveDelete: allowRecursiveDelete, + workerCount: workerCount, + dst: dst, + Error: restorerAbortOnAllErrors, } } @@ -207,7 +211,7 @@ func (r *fileRestorer) restoreFiles(ctx context.Context) error { } func (r *fileRestorer) restoreEmptyFileAt(location string) error { - f, err := createFile(r.targetPath(location), 0, false) + f, err := createFile(r.targetPath(location), 0, false, r.allowRecursiveDelete) if err != nil { return err } diff --git a/internal/restorer/filerestorer_test.go b/internal/restorer/filerestorer_test.go index d29c0dcea..f594760e4 100644 --- a/internal/restorer/filerestorer_test.go +++ b/internal/restorer/filerestorer_test.go @@ -144,7 +144,7 @@ func restoreAndVerify(t *testing.T, tempdir string, content []TestFile, files ma t.Helper() repo := newTestRepo(content) - r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, sparse, nil) + r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, sparse, false, nil) if files == nil { r.files = repo.files @@ -285,7 +285,7 @@ func TestErrorRestoreFiles(t *testing.T) { return loadError } - r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, false, nil) + r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, false, false, nil) r.files = repo.files err := r.restoreFiles(context.TODO()) @@ -326,7 +326,7 @@ func TestFatalDownloadError(t *testing.T) { }) } - r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, false, nil) + r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, false, false, nil) r.files = repo.files var errors []string diff --git a/internal/restorer/fileswriter.go b/internal/restorer/fileswriter.go index 034ed2725..962f66619 100644 --- a/internal/restorer/fileswriter.go +++ b/internal/restorer/fileswriter.go @@ -19,7 +19,8 @@ import ( // TODO I am not 100% convinced this is necessary, i.e. it may be okay // to use multiple os.File to write to the same target file type filesWriter struct { - buckets []filesWriterBucket + buckets []filesWriterBucket + allowRecursiveDelete bool } type filesWriterBucket struct { @@ -33,13 +34,14 @@ type partialFile struct { sparse bool } -func newFilesWriter(count int) *filesWriter { +func newFilesWriter(count int, allowRecursiveDelete bool) *filesWriter { buckets := make([]filesWriterBucket, count) for b := 0; b < count; b++ { buckets[b].files = make(map[string]*partialFile) } return &filesWriter{ - buckets: buckets, + buckets: buckets, + allowRecursiveDelete: allowRecursiveDelete, } } @@ -60,7 +62,7 @@ func openFile(path string) (*os.File, error) { return f, nil } -func createFile(path string, createSize int64, sparse bool) (*os.File, error) { +func createFile(path string, createSize int64, sparse bool, allowRecursiveDelete bool) (*os.File, error) { f, err := fs.OpenFile(path, fs.O_CREATE|fs.O_WRONLY|fs.O_NOFOLLOW, 0600) if err != nil && fs.IsAccessDenied(err) { // If file is readonly, clear the readonly flag by resetting the @@ -109,8 +111,14 @@ func createFile(path string, createSize int64, sparse bool) (*os.File, error) { } // not what we expected, try to get rid of it - if err := fs.Remove(path); err != nil { - return nil, err + if allowRecursiveDelete { + if err := fs.RemoveAll(path); err != nil { + return nil, err + } + } else { + if err := fs.Remove(path); err != nil { + return nil, err + } } // create a new file, pass O_EXCL to make sure there are no surprises f, err = fs.OpenFile(path, fs.O_CREATE|fs.O_WRONLY|fs.O_EXCL|fs.O_NOFOLLOW, 0600) @@ -169,7 +177,7 @@ func (w *filesWriter) writeToFile(path string, blob []byte, offset int64, create var f *os.File var err error if createSize >= 0 { - f, err = createFile(path, createSize, sparse) + f, err = createFile(path, createSize, sparse, w.allowRecursiveDelete) if err != nil { return nil, err } diff --git a/internal/restorer/fileswriter_test.go b/internal/restorer/fileswriter_test.go index 383a9e0d7..c69847927 100644 --- a/internal/restorer/fileswriter_test.go +++ b/internal/restorer/fileswriter_test.go @@ -13,7 +13,7 @@ import ( func TestFilesWriterBasic(t *testing.T) { dir := rtest.TempDir(t) - w := newFilesWriter(1) + w := newFilesWriter(1, false) f1 := dir + "/f1" f2 := dir + "/f2" @@ -39,6 +39,29 @@ func TestFilesWriterBasic(t *testing.T) { rtest.Equals(t, []byte{2, 2}, buf) } +func TestFilesWriterRecursiveOverwrite(t *testing.T) { + path := filepath.Join(t.TempDir(), "test") + + // create filled directory + rtest.OK(t, os.Mkdir(path, 0o700)) + rtest.OK(t, os.WriteFile(filepath.Join(path, "file"), []byte("data"), 0o400)) + + // must error if recursive delete is not allowed + w := newFilesWriter(1, false) + err := w.writeToFile(path, []byte{1}, 0, 2, false) + rtest.Assert(t, errors.Is(err, notEmptyDirError()), "unexepected error got %v", err) + rtest.Equals(t, 0, len(w.buckets[0].files)) + + // must replace directory + w = newFilesWriter(1, true) + rtest.OK(t, w.writeToFile(path, []byte{1, 1}, 0, 2, false)) + rtest.Equals(t, 0, len(w.buckets[0].files)) + + buf, err := os.ReadFile(path) + rtest.OK(t, err) + rtest.Equals(t, []byte{1, 1}, buf) +} + func TestCreateFile(t *testing.T) { basepath := filepath.Join(t.TempDir(), "test") @@ -110,7 +133,7 @@ func TestCreateFile(t *testing.T) { for j, test := range tests { path := basepath + fmt.Sprintf("%v%v", i, j) sc.create(t, path) - f, err := createFile(path, test.size, test.isSparse) + f, err := createFile(path, test.size, test.isSparse, false) if sc.err == nil { rtest.OK(t, err) fi, err := f.Stat() @@ -129,3 +152,19 @@ func TestCreateFile(t *testing.T) { }) } } + +func TestCreateFileRecursiveDelete(t *testing.T) { + path := filepath.Join(t.TempDir(), "test") + + // create filled directory + rtest.OK(t, os.Mkdir(path, 0o700)) + rtest.OK(t, os.WriteFile(filepath.Join(path, "file"), []byte("data"), 0o400)) + + // replace it + f, err := createFile(path, 42, false, true) + rtest.OK(t, err) + fi, err := f.Stat() + rtest.OK(t, err) + rtest.Assert(t, fi.Mode().IsRegular(), "wrong filetype %v", fi.Mode()) + rtest.OK(t, f.Close()) +} diff --git a/internal/restorer/restorer.go b/internal/restorer/restorer.go index 6e81812c2..9efaa64df 100644 --- a/internal/restorer/restorer.go +++ b/internal/restorer/restorer.go @@ -349,7 +349,7 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { idx := NewHardlinkIndex[string]() filerestorer := newFileRestorer(dst, res.repo.LoadBlobsFromPack, res.repo.LookupBlob, - res.repo.Connections(), res.opts.Sparse, res.opts.Progress) + res.repo.Connections(), res.opts.Sparse, res.opts.Delete, res.opts.Progress) filerestorer.Error = res.Error debug.Log("first pass for %q", dst) diff --git a/internal/restorer/restorer_test.go b/internal/restorer/restorer_test.go index 8a8f81ce0..3d2323d0f 100644 --- a/internal/restorer/restorer_test.go +++ b/internal/restorer/restorer_test.go @@ -1214,6 +1214,27 @@ func TestRestoreDryRun(t *testing.T) { rtest.Assert(t, errors.Is(err, os.ErrNotExist), "expected no file to be created, got %v", err) } +func TestRestoreOverwriteDirectory(t *testing.T) { + saveSnapshotsAndOverwrite(t, + Snapshot{ + Nodes: map[string]Node{ + "dir": Dir{ + Mode: normalizeFileMode(0755 | os.ModeDir), + Nodes: map[string]Node{ + "anotherfile": File{Data: "content: file\n"}, + }, + }, + }, + }, + Snapshot{ + Nodes: map[string]Node{ + "dir": File{Data: "content: file\n"}, + }, + }, + Options{Delete: true}, + ) +} + func TestRestoreDelete(t *testing.T) { repo := repository.TestRepository(t) tempdir := rtest.TempDir(t)