diff --git a/internal/restorer/filerestorer.go b/internal/restorer/filerestorer.go
index 92ba39e77..bdb3ff708 100644
--- a/internal/restorer/filerestorer.go
+++ b/internal/restorer/filerestorer.go
@@ -9,6 +9,8 @@ import (
 	"sort"
 	"sync"
 
+	"golang.org/x/sync/errgroup"
+
 	"github.com/restic/restic/internal/crypto"
 	"github.com/restic/restic/internal/debug"
 	"github.com/restic/restic/internal/errors"
@@ -23,20 +25,16 @@ import (
 const (
 	workerCount = 8
 
-	// fileInfo flags
-	fileProgress = 1
-	fileError    = 2
-
 	largeFileBlobCount = 25
 )
 
 // information about regular file being restored
 type fileInfo struct {
-	lock     sync.Mutex
-	flags    int
-	size     int64
-	location string      // file on local filesystem relative to restorer basedir
-	blobs    interface{} // blobs of the file
+	lock       sync.Mutex
+	inProgress bool
+	size       int64
+	location   string      // file on local filesystem relative to restorer basedir
+	blobs      interface{} // blobs of the file
 }
 
 type fileBlobInfo struct {
@@ -60,6 +58,7 @@ type fileRestorer struct {
 
 	dst   string
 	files []*fileInfo
+	Error func(string, error) error
 }
 
 func newFileRestorer(dst string,
@@ -73,6 +72,7 @@ func newFileRestorer(dst string,
 		packLoader:  packLoader,
 		filesWriter: newFilesWriter(workerCount),
 		dst:         dst,
+		Error:       restorerAbortOnAllErrors,
 	}
 }
 
@@ -142,47 +142,42 @@ func (r *fileRestorer) restoreFiles(ctx context.Context) error {
 		}
 	}
 
-	var wg sync.WaitGroup
+	wg, ctx := errgroup.WithContext(ctx)
 	downloadCh := make(chan *packInfo)
-	worker := func() {
-		defer wg.Done()
-		for {
-			select {
-			case <-ctx.Done():
-				return // context cancelled
-			case pack, ok := <-downloadCh:
-				if !ok {
-					return // channel closed
-				}
-				r.downloadPack(ctx, pack)
+
+	worker := func() error {
+		for pack := range downloadCh {
+			if err := r.downloadPack(ctx, pack); err != nil {
+				return err
 			}
 		}
+		return nil
 	}
 	for i := 0; i < workerCount; i++ {
-		wg.Add(1)
-		go worker()
+		wg.Go(worker)
 	}
 
 	// the main restore loop
-	for _, id := range packOrder {
-		pack := packs[id]
-		select {
-		case <-ctx.Done():
-			return ctx.Err()
-		case downloadCh <- pack:
-			debug.Log("Scheduled download pack %s", pack.id.Str())
+	wg.Go(func() error {
+		for _, id := range packOrder {
+			pack := packs[id]
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case downloadCh <- pack:
+				debug.Log("Scheduled download pack %s", pack.id.Str())
+			}
 		}
-	}
+		close(downloadCh)
+		return nil
+	})
 
-	close(downloadCh)
-	wg.Wait()
-
-	return nil
+	return wg.Wait()
 }
 
 const maxBufferSize = 4 * 1024 * 1024
 
-func (r *fileRestorer) downloadPack(ctx context.Context, pack *packInfo) {
+func (r *fileRestorer) downloadPack(ctx context.Context, pack *packInfo) error {
 
 	// calculate pack byte range and blob->[]files->[]offsets mappings
 	start, end := int64(math.MaxInt64), int64(0)
@@ -237,12 +232,11 @@ func (r *fileRestorer) downloadPack(ctx context.Context, pack *packInfo) {
 		return blobs[sortedBlobs[i]].offset < blobs[sortedBlobs[j]].offset
 	})
 
-	markFileError := func(file *fileInfo, err error) {
-		file.lock.Lock()
-		defer file.lock.Unlock()
-		if file.flags&fileError == 0 {
-			file.flags |= fileError
+	sanitizeError := func(file *fileInfo, err error) error {
+		if err != nil {
+			err = r.Error(file.location, err)
 		}
+		return err
 	}
 
 	h := restic.Handle{Type: restic.PackFile, Name: pack.id.String()}
@@ -263,7 +257,9 @@ func (r *fileRestorer) downloadPack(ctx context.Context, pack *packInfo) {
 			blobData, buf, err = r.loadBlob(bufRd, blobID, blob.length, buf)
 			if err != nil {
 				for file := range blob.files {
-					markFileError(file, err)
+					if errFile := sanitizeError(file, err); errFile != nil {
+						return errFile
+					}
 				}
 				continue
 			}
@@ -277,37 +273,36 @@ func (r *fileRestorer) downloadPack(ctx context.Context, pack *packInfo) {
 						// - should allow concurrent writes to the file
 						// so write the first blob while holding file lock
 						// write other blobs after releasing the lock
-						file.lock.Lock()
-						create := file.flags&fileProgress == 0
 						createSize := int64(-1)
-						if create {
-							defer file.lock.Unlock()
-							file.flags |= fileProgress
-							createSize = file.size
-						} else {
+						file.lock.Lock()
+						if file.inProgress {
 							file.lock.Unlock()
+						} else {
+							defer file.lock.Unlock()
+							file.inProgress = true
+							createSize = file.size
 						}
 						return r.filesWriter.writeToFile(r.targetPath(file.location), blobData, offset, createSize)
 					}
-					err := writeToFile()
+					err := sanitizeError(file, writeToFile())
 					if err != nil {
-						markFileError(file, err)
-						break
+						return err
 					}
 				}
 			}
 		}
-
 		return nil
 	})
 
 	if err != nil {
 		for file := range pack.files {
-			markFileError(file, err)
+			if errFile := sanitizeError(file, err); errFile != nil {
+				return errFile
+			}
 		}
-		return
 	}
 
+	return nil
 }
 
 func (r *fileRestorer) loadBlob(rd io.Reader, blobID restic.ID, length int, buf []byte) ([]byte, []byte, error) {
diff --git a/internal/restorer/filerestorer_test.go b/internal/restorer/filerestorer_test.go
index 6ccbc2b37..ac8e371da 100644
--- a/internal/restorer/filerestorer_test.go
+++ b/internal/restorer/filerestorer_test.go
@@ -262,5 +262,5 @@ func TestErrorRestoreFiles(t *testing.T) {
 	r.files = repo.files
 
 	err := r.restoreFiles(context.TODO())
-	rtest.Assert(t, err != nil, "restoreFiles should have reported an error!")
+	rtest.Equals(t, loadError, err)
 }
diff --git a/internal/restorer/restorer.go b/internal/restorer/restorer.go
index 292241498..75ed99f86 100644
--- a/internal/restorer/restorer.go
+++ b/internal/restorer/restorer.go
@@ -216,6 +216,7 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error {
 	idx := restic.NewHardlinkIndex()
 
 	filerestorer := newFileRestorer(dst, res.repo.Backend().Load, res.repo.Key(), res.repo.Index().Lookup)
+	filerestorer.Error = res.Error
 
 	debug.Log("first pass for %q", dst)
 
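For context, the standalone sketch below (not restic code) shows the errgroup producer/worker pattern that restoreFiles moves to in this diff: one goroutine schedules work onto a channel, a fixed pool of workers drains it, and wg.Wait() returns the first error while the shared context cancels the remaining goroutines. The job type, the pool size of 4, and the simulated failure are illustrative assumptions only. The sketch uses defer close(jobs) so idle workers always unblock; the diff instead closes downloadCh at the end of the scheduling loop.

// Minimal sketch of the errgroup fan-out pattern, assuming golang.org/x/sync/errgroup.
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	wg, ctx := errgroup.WithContext(context.Background())
	jobs := make(chan int)

	// worker: drain the channel until it is closed, stop on the first error.
	worker := func() error {
		for job := range jobs {
			if job == 3 { // simulated failure, stands in for a failed pack download
				return fmt.Errorf("job %d failed", job)
			}
		}
		return nil
	}
	for i := 0; i < 4; i++ {
		wg.Go(worker)
	}

	// producer: runs inside the group, so a worker failure cancels ctx and stops
	// scheduling; defer close(jobs) lets idle workers exit their range loop.
	wg.Go(func() error {
		defer close(jobs)
		for i := 0; i < 10; i++ {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case jobs <- i:
			}
		}
		return nil
	})

	// Wait blocks until every goroutine returns and yields the first non-nil error.
	fmt.Println(wg.Wait())
}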