diff --git a/cmd/restic/cmd_debug.go b/cmd/restic/cmd_debug.go index 9ac7abad3..8f25933f9 100644 --- a/cmd/restic/cmd_debug.go +++ b/cmd/restic/cmd_debug.go @@ -15,8 +15,6 @@ import ( "github.com/restic/restic/internal/pack" "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/restic" - - "github.com/restic/restic/internal/worker" ) var cmdDebug = &cobra.Command{ @@ -52,26 +50,18 @@ func prettyPrintJSON(wr io.Writer, item interface{}) error { } func debugPrintSnapshots(repo *repository.Repository, wr io.Writer) error { - for id := range repo.List(context.TODO(), restic.SnapshotFile) { + return repo.List(context.TODO(), restic.SnapshotFile, func(id restic.ID, size int64) error { snapshot, err := restic.LoadSnapshot(context.TODO(), repo, id) if err != nil { - fmt.Fprintf(os.Stderr, "LoadSnapshot(%v): %v", id.Str(), err) - continue + return err } fmt.Fprintf(wr, "snapshot_id: %v\n", id) - err = prettyPrintJSON(wr, snapshot) - if err != nil { - return err - } - } - - return nil + return prettyPrintJSON(wr, snapshot) + }) } -const dumpPackWorkers = 10 - // Pack is the struct used in printPacks. type Pack struct { Name string `json:"name"` @@ -88,49 +78,21 @@ type Blob struct { } func printPacks(repo *repository.Repository, wr io.Writer) error { - f := func(ctx context.Context, job worker.Job) (interface{}, error) { - name := job.Data.(string) - h := restic.Handle{Type: restic.DataFile, Name: name} + return repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error { + h := restic.Handle{Type: restic.DataFile, Name: id.String()} - blobInfo, err := repo.Backend().Stat(ctx, h) + blobs, err := pack.List(repo.Key(), restic.ReaderAt(repo.Backend(), h), size) if err != nil { - return nil, err + fmt.Fprintf(os.Stderr, "error for pack %v: %v\n", id.Str(), err) + return nil } - blobs, err := pack.List(repo.Key(), restic.ReaderAt(repo.Backend(), h), blobInfo.Size) - if err != nil { - return nil, err - } - - return blobs, nil - } - - jobCh := make(chan worker.Job) - resCh := make(chan worker.Job) - wp := worker.New(context.TODO(), dumpPackWorkers, f, jobCh, resCh) - - go func() { - for name := range repo.Backend().List(context.TODO(), restic.DataFile) { - jobCh <- worker.Job{Data: name} - } - close(jobCh) - }() - - for job := range resCh { - name := job.Data.(string) - - if job.Error != nil { - fmt.Fprintf(os.Stderr, "error for pack %v: %v\n", name, job.Error) - continue - } - - entries := job.Result.([]restic.Blob) p := Pack{ - Name: name, - Blobs: make([]Blob, len(entries)), + Name: id.String(), + Blobs: make([]Blob, len(blobs)), } - for i, blob := range entries { + for i, blob := range blobs { p.Blobs[i] = Blob{ Type: blob.Type, Length: blob.Length, @@ -139,16 +101,14 @@ func printPacks(repo *repository.Repository, wr io.Writer) error { } } - prettyPrintJSON(os.Stdout, p) - } - - wp.Wait() + return prettyPrintJSON(os.Stdout, p) + }) return nil } func dumpIndexes(repo restic.Repository) error { - for id := range repo.List(context.TODO(), restic.IndexFile) { + return repo.List(context.TODO(), restic.IndexFile, func(id restic.ID, size int64) error { fmt.Printf("index_id: %v\n", id) idx, err := repository.LoadIndex(context.TODO(), repo, id) @@ -156,13 +116,8 @@ func dumpIndexes(repo restic.Repository) error { return err } - err = idx.Dump(os.Stdout) - if err != nil { - return err - } - } - - return nil + return idx.Dump(os.Stdout) + }) } func runDebugDump(gopts GlobalOptions, args []string) error { diff --git a/cmd/restic/cmd_key.go b/cmd/restic/cmd_key.go index e89a69b63..7552c778d 100644 --- a/cmd/restic/cmd_key.go +++ b/cmd/restic/cmd_key.go @@ -32,11 +32,11 @@ func listKeys(ctx context.Context, s *repository.Repository) error { tab.Header = fmt.Sprintf(" %-10s %-10s %-10s %s", "ID", "User", "Host", "Created") tab.RowFormat = "%s%-10s %-10s %-10s %s" - for id := range s.List(ctx, restic.KeyFile) { + err := s.List(ctx, restic.KeyFile, func(id restic.ID, size int64) error { k, err := repository.LoadKey(ctx, s, id.String()) if err != nil { Warnf("LoadKey() failed: %v\n", err) - continue + return nil } var current string @@ -47,6 +47,10 @@ func listKeys(ctx context.Context, s *repository.Repository) error { } tab.Rows = append(tab.Rows, []interface{}{current, id.Str(), k.Username, k.Hostname, k.Created.Format(TimeFormat)}) + return nil + }) + if err != nil { + return err } return tab.Write(globalOptions.stdout) diff --git a/cmd/restic/cmd_list.go b/cmd/restic/cmd_list.go index 0a7e9ca01..431085ff5 100644 --- a/cmd/restic/cmd_list.go +++ b/cmd/restic/cmd_list.go @@ -73,9 +73,8 @@ func runList(opts GlobalOptions, args []string) error { return errors.Fatal("invalid type") } - for id := range repo.List(opts.ctx, t) { + return repo.List(opts.ctx, t, func(id restic.ID, size int64) error { Printf("%s\n", id) - } - - return nil + return nil + }) } diff --git a/cmd/restic/cmd_prune.go b/cmd/restic/cmd_prune.go index 1383d15a4..6baf7ead3 100644 --- a/cmd/restic/cmd_prune.go +++ b/cmd/restic/cmd_prune.go @@ -120,8 +120,12 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error { } Verbosef("counting files in repo\n") - for range repo.List(ctx, restic.DataFile) { + err = repo.List(ctx, restic.DataFile, func(restic.ID, int64) error { stats.packs++ + return nil + }) + if err != nil { + return err } Verbosef("building new index for repo\n") diff --git a/cmd/restic/cmd_rebuild_index.go b/cmd/restic/cmd_rebuild_index.go index 9480374ea..55bcfa047 100644 --- a/cmd/restic/cmd_rebuild_index.go +++ b/cmd/restic/cmd_rebuild_index.go @@ -48,8 +48,12 @@ func rebuildIndex(ctx context.Context, repo restic.Repository, ignorePacks resti Verbosef("counting files in repo\n") var packs uint64 - for range repo.List(ctx, restic.DataFile) { + err := repo.List(ctx, restic.DataFile, func(restic.ID, int64) error { packs++ + return nil + }) + if err != nil { + return err } bar := newProgressMax(!globalOptions.Quiet, packs-uint64(len(ignorePacks)), "packs") @@ -61,8 +65,12 @@ func rebuildIndex(ctx context.Context, repo restic.Repository, ignorePacks resti Verbosef("finding old index files\n") var supersedes restic.IDs - for id := range repo.List(ctx, restic.IndexFile) { + err = repo.List(ctx, restic.IndexFile, func(id restic.ID, size int64) error { supersedes = append(supersedes, id) + return nil + }) + if err != nil { + return err } id, err := idx.Save(ctx, repo, supersedes) diff --git a/cmd/restic/find.go b/cmd/restic/find.go index 8b227fa55..e48a6ab55 100644 --- a/cmd/restic/find.go +++ b/cmd/restic/find.go @@ -58,7 +58,13 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos return } - for _, sn := range restic.FindFilteredSnapshots(ctx, repo, host, tags, paths) { + snapshots, err := restic.FindFilteredSnapshots(ctx, repo, host, tags, paths) + if err != nil { + Warnf("could not load snapshots: %v\n", err) + return + } + + for _, sn := range snapshots { select { case <-ctx.Done(): return diff --git a/internal/archiver/archive_reader_test.go b/internal/archiver/archive_reader_test.go index 56e5fec5f..fafc0ed1a 100644 --- a/internal/archiver/archive_reader_test.go +++ b/internal/archiver/archive_reader_test.go @@ -135,8 +135,12 @@ func (e errReader) Read([]byte) (int, error) { func countSnapshots(t testing.TB, repo restic.Repository) int { snapshots := 0 - for range repo.List(context.TODO(), restic.SnapshotFile) { + err := repo.List(context.TODO(), restic.SnapshotFile, func(id restic.ID, size int64) error { snapshots++ + return nil + }) + if err != nil { + t.Fatal(err) } return snapshots } diff --git a/internal/archiver/archiver_duplication_test.go b/internal/archiver/archiver_duplication_test.go index 783dce11c..bdcecf0c6 100644 --- a/internal/archiver/archiver_duplication_test.go +++ b/internal/archiver/archiver_duplication_test.go @@ -60,10 +60,8 @@ func forgetfulBackend() restic.Backend { return nil } - be.ListFn = func(ctx context.Context, t restic.FileType) <-chan string { - ch := make(chan string) - close(ch) - return ch + be.ListFn = func(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { + return nil } be.DeleteFn = func(ctx context.Context) error { diff --git a/internal/archiver/archiver_test.go b/internal/archiver/archiver_test.go index e578ab3de..6a16a36fc 100644 --- a/internal/archiver/archiver_test.go +++ b/internal/archiver/archiver_test.go @@ -131,9 +131,13 @@ func BenchmarkArchiveDirectory(b *testing.B) { } } -func countPacks(repo restic.Repository, t restic.FileType) (n uint) { - for range repo.Backend().List(context.TODO(), t) { +func countPacks(t testing.TB, repo restic.Repository, tpe restic.FileType) (n uint) { + err := repo.Backend().List(context.TODO(), tpe, func(restic.FileInfo) error { n++ + return nil + }) + if err != nil { + t.Fatal(err) } return n @@ -158,7 +162,7 @@ func archiveWithDedup(t testing.TB) { t.Logf("archived snapshot %v", sn.ID().Str()) // get archive stats - cnt.before.packs = countPacks(repo, restic.DataFile) + cnt.before.packs = countPacks(t, repo, restic.DataFile) cnt.before.dataBlobs = repo.Index().Count(restic.DataBlob) cnt.before.treeBlobs = repo.Index().Count(restic.TreeBlob) t.Logf("packs %v, data blobs %v, tree blobs %v", @@ -169,7 +173,7 @@ func archiveWithDedup(t testing.TB) { t.Logf("archived snapshot %v", sn2.ID().Str()) // get archive stats again - cnt.after.packs = countPacks(repo, restic.DataFile) + cnt.after.packs = countPacks(t, repo, restic.DataFile) cnt.after.dataBlobs = repo.Index().Count(restic.DataBlob) cnt.after.treeBlobs = repo.Index().Count(restic.TreeBlob) t.Logf("packs %v, data blobs %v, tree blobs %v", @@ -186,7 +190,7 @@ func archiveWithDedup(t testing.TB) { t.Logf("archived snapshot %v, parent %v", sn3.ID().Str(), sn2.ID().Str()) // get archive stats again - cnt.after2.packs = countPacks(repo, restic.DataFile) + cnt.after2.packs = countPacks(t, repo, restic.DataFile) cnt.after2.dataBlobs = repo.Index().Count(restic.DataBlob) cnt.after2.treeBlobs = repo.Index().Count(restic.TreeBlob) t.Logf("packs %v, data blobs %v, tree blobs %v", diff --git a/internal/backend/mem/mem_backend.go b/internal/backend/mem/mem_backend.go index ba0ede583..576ff8140 100644 --- a/internal/backend/mem/mem_backend.go +++ b/internal/backend/mem/mem_backend.go @@ -164,17 +164,22 @@ func (be *MemoryBackend) Remove(ctx context.Context, h restic.Handle) error { // List returns a channel which yields entries from the backend. func (be *MemoryBackend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { - be.m.Lock() - defer be.m.Unlock() + entries := make(map[string]int64) + be.m.Lock() for entry, buf := range be.data { if entry.Type != t { continue } + entries[entry.Name] = int64(len(buf)) + } + be.m.Unlock() + + for name, size := range entries { fi := restic.FileInfo{ - Name: entry.Name, - Size: int64(len(buf)), + Name: name, + Size: size, } if ctx.Err() != nil { diff --git a/internal/checker/checker.go b/internal/checker/checker.go index 6f134dbb7..aec691857 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -12,6 +12,7 @@ import ( "github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/hashing" "github.com/restic/restic/internal/restic" + "golang.org/x/sync/errgroup" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/pack" @@ -192,13 +193,14 @@ func (c *Checker) Packs(ctx context.Context, errChan chan<- error) { debug.Log("listing repository packs") repoPacks := restic.NewIDSet() - for id := range c.repo.List(ctx, restic.DataFile) { - select { - case <-ctx.Done(): - return - default: - } + + err := c.repo.List(ctx, restic.DataFile, func(id restic.ID, size int64) error { repoPacks.Insert(id) + return nil + }) + + if err != nil { + errChan <- err } // orphaned: present in the repo but not in c.packs @@ -719,42 +721,58 @@ func (c *Checker) ReadData(ctx context.Context, p *restic.Progress, errChan chan p.Start() defer p.Done() - worker := func(wg *sync.WaitGroup, in <-chan restic.ID) { - defer wg.Done() - for { - var id restic.ID - var ok bool + g, ctx := errgroup.WithContext(ctx) + ch := make(chan restic.ID) + // start producer for channel ch + g.Go(func() error { + defer close(ch) + return c.repo.List(ctx, restic.DataFile, func(id restic.ID, size int64) error { select { case <-ctx.Done(): - return - case id, ok = <-in: - if !ok { - return + case ch <- id: + } + return nil + }) + }) + + // run workers + for i := 0; i < defaultParallelism; i++ { + g.Go(func() error { + for { + var id restic.ID + var ok bool + + select { + case <-ctx.Done(): + return nil + case id, ok = <-ch: + if !ok { + return nil + } + } + + err := checkPack(ctx, c.repo, id) + p.Report(restic.Stat{Blobs: 1}) + if err == nil { + continue + } + + select { + case <-ctx.Done(): + return nil + case errChan <- err: } } + }) + } - err := checkPack(ctx, c.repo, id) - p.Report(restic.Stat{Blobs: 1}) - if err == nil { - continue - } - - select { - case <-ctx.Done(): - return - case errChan <- err: - } + err := g.Wait() + if err != nil { + select { + case <-ctx.Done(): + return + case errChan <- err: } } - - ch := c.repo.List(ctx, restic.DataFile) - - var wg sync.WaitGroup - for i := 0; i < defaultParallelism; i++ { - wg.Add(1) - go worker(&wg, ch) - } - - wg.Wait() } diff --git a/internal/fuse/file_test.go b/internal/fuse/file_test.go index 121c218f5..622b5dd80 100644 --- a/internal/fuse/file_test.go +++ b/internal/fuse/file_test.go @@ -35,11 +35,17 @@ func testRead(t testing.TB, f *file, offset, length int, data []byte) { } func firstSnapshotID(t testing.TB, repo restic.Repository) (first restic.ID) { - for id := range repo.List(context.TODO(), restic.SnapshotFile) { + err := repo.List(context.TODO(), restic.SnapshotFile, func(id restic.ID, size int64) error { if first.IsNull() { first = id } + return nil + }) + + if err != nil { + t.Fatal(err) } + return first } diff --git a/internal/fuse/snapshots_dir.go b/internal/fuse/snapshots_dir.go index 762056611..671a51ba2 100644 --- a/internal/fuse/snapshots_dir.go +++ b/internal/fuse/snapshots_dir.go @@ -227,18 +227,24 @@ func isElem(e string, list []string) bool { const minSnapshotsReloadTime = 60 * time.Second // update snapshots if repository has changed -func updateSnapshots(ctx context.Context, root *Root) { +func updateSnapshots(ctx context.Context, root *Root) error { if time.Since(root.lastCheck) < minSnapshotsReloadTime { - return + return nil + } + + snapshots, err := restic.FindFilteredSnapshots(ctx, root.repo, root.cfg.Host, root.cfg.Tags, root.cfg.Paths) + if err != nil { + return err } - snapshots := restic.FindFilteredSnapshots(ctx, root.repo, root.cfg.Host, root.cfg.Tags, root.cfg.Paths) if root.snCount != len(snapshots) { root.snCount = len(snapshots) root.repo.LoadIndex(ctx) root.snapshots = snapshots } root.lastCheck = time.Now() + + return nil } // read snapshot timestamps from the current repository-state. diff --git a/internal/index/index.go b/internal/index/index.go index 4c9ebeac3..732dc4b9c 100644 --- a/internal/index/index.go +++ b/internal/index/index.go @@ -115,13 +115,13 @@ func Load(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Ind index := newIndex() - for id := range repo.List(ctx, restic.IndexFile) { + err := repo.List(ctx, restic.IndexFile, func(id restic.ID, size int64) error { p.Report(restic.Stat{Blobs: 1}) debug.Log("Load index %v", id.Str()) idx, err := loadIndexJSON(ctx, repo, id) if err != nil { - return nil, err + return err } res := make(map[restic.ID]Pack) @@ -144,12 +144,18 @@ func Load(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Ind } if err = index.AddPack(jpack.ID, 0, entries); err != nil { - return nil, err + return err } } results[id] = res index.IndexIDs.Insert(id) + + return nil + }) + + if err != nil { + return nil, err } for superID, list := range supersedes { diff --git a/internal/index/index_test.go b/internal/index/index_test.go index 28829afe9..00e9a523e 100644 --- a/internal/index/index_test.go +++ b/internal/index/index_test.go @@ -28,7 +28,7 @@ func createFilledRepo(t testing.TB, snapshots int, dup float32) (restic.Reposito } func validateIndex(t testing.TB, repo restic.Repository, idx *Index) { - for id := range repo.List(context.TODO(), restic.DataFile) { + err := repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error { p, ok := idx.Packs[id] if !ok { t.Errorf("pack %v missing from index", id.Str()) @@ -37,6 +37,11 @@ func validateIndex(t testing.TB, repo restic.Repository, idx *Index) { if !p.ID.Equal(id) { t.Errorf("pack %v has invalid ID: want %v, got %v", id.Str(), id, p.ID) } + return nil + }) + + if err != nil { + t.Fatal(err) } } @@ -308,7 +313,14 @@ func TestIndexAddRemovePack(t *testing.T) { t.Fatalf("Load() returned error %v", err) } - packID := <-repo.List(context.TODO(), restic.DataFile) + var packID restic.ID + err = repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error { + packID = id + return nil + }) + if err != nil { + t.Fatal(err) + } t.Logf("selected pack %v", packID.Str()) diff --git a/internal/list/list.go b/internal/list/list.go index 04916b906..ffcc08729 100644 --- a/internal/list/list.go +++ b/internal/list/list.go @@ -11,7 +11,7 @@ const listPackWorkers = 10 // Lister combines lists packs in a repo and blobs in a pack. type Lister interface { - List(context.Context, restic.FileType) <-chan restic.ID + List(context.Context, restic.FileType, func(restic.ID, int64) error) error ListPack(context.Context, restic.ID) ([]restic.Blob, int64, error) } @@ -55,17 +55,19 @@ func AllPacks(ctx context.Context, repo Lister, ignorePacks restic.IDSet, ch cha go func() { defer close(jobCh) - for id := range repo.List(ctx, restic.DataFile) { + + _ = repo.List(ctx, restic.DataFile, func(id restic.ID, size int64) error { if ignorePacks.Has(id) { - continue + return nil } select { case jobCh <- worker.Job{Data: id}: case <-ctx.Done(): - return + return ctx.Err() } - } + return nil + }) }() wp.Wait() diff --git a/internal/migrations/s3_layout.go b/internal/migrations/s3_layout.go index 3d27f0d83..12ffef0ff 100644 --- a/internal/migrations/s3_layout.go +++ b/internal/migrations/s3_layout.go @@ -59,14 +59,14 @@ func (m *S3Layout) moveFiles(ctx context.Context, be *s3.Backend, l backend.Layo fmt.Fprintf(os.Stderr, "renaming file returned error: %v\n", err) } - for name := range be.List(ctx, t) { - h := restic.Handle{Type: t, Name: name} + return be.List(ctx, t, func(fi restic.FileInfo) error { + h := restic.Handle{Type: t, Name: fi.Name} debug.Log("move %v", h) - retry(maxErrors, printErr, func() error { + return retry(maxErrors, printErr, func() error { return be.Rename(h, l) }) - } + }) return nil } diff --git a/internal/repository/key.go b/internal/repository/key.go index fd3ef1a1c..63c35b8e8 100644 --- a/internal/repository/key.go +++ b/internal/repository/key.go @@ -113,42 +113,48 @@ func OpenKey(ctx context.Context, s *Repository, name string, password string) ( // given password. If none could be found, ErrNoKeyFound is returned. When // maxKeys is reached, ErrMaxKeysReached is returned. When setting maxKeys to // zero, all keys in the repo are checked. -func SearchKey(ctx context.Context, s *Repository, password string, maxKeys int) (*Key, error) { +func SearchKey(ctx context.Context, s *Repository, password string, maxKeys int) (k *Key, err error) { checked := 0 // try at most maxKeysForSearch keys in repo - for name := range s.Backend().List(ctx, restic.KeyFile) { + err = s.Backend().List(ctx, restic.KeyFile, func(fi restic.FileInfo) error { if maxKeys > 0 && checked > maxKeys { - return nil, ErrMaxKeysReached + return ErrMaxKeysReached } - _, err := restic.ParseID(name) + _, err := restic.ParseID(fi.Name) if err != nil { - debug.Log("rejecting key with invalid name: %v", name) - continue + debug.Log("rejecting key with invalid name: %v", fi.Name) + return nil } - debug.Log("trying key %q", name) - key, err := OpenKey(ctx, s, name, password) + debug.Log("trying key %q", fi.Name) + key, err := OpenKey(ctx, s, fi.Name, password) if err != nil { - debug.Log("key %v returned error %v", name, err) + debug.Log("key %v returned error %v", fi.Name, err) // ErrUnauthenticated means the password is wrong, try the next key if errors.Cause(err) == crypto.ErrUnauthenticated { - continue + return nil } - if err != nil { - debug.Log("unable to open key %v: %v\n", err) - continue - } + return err } - debug.Log("successfully opened key %v", name) - return key, nil + debug.Log("successfully opened key %v", fi.Name) + k = key + return nil + }) + + if err != nil { + return nil, err } - return nil, ErrNoKeyFound + if k == nil { + return nil, ErrNoKeyFound + } + + return k, nil } // LoadKey loads a key from the backend. diff --git a/internal/repository/parallel.go b/internal/repository/parallel.go index 5f87f94a5..154b58bfa 100644 --- a/internal/repository/parallel.go +++ b/internal/repository/parallel.go @@ -2,10 +2,10 @@ package repository import ( "context" - "sync" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/restic" + "golang.org/x/sync/errgroup" ) // ParallelWorkFunc gets one file ID to work on. If an error is returned, @@ -17,47 +17,36 @@ type ParallelWorkFunc func(ctx context.Context, id string) error type ParallelIDWorkFunc func(ctx context.Context, id restic.ID) error // FilesInParallel runs n workers of f in parallel, on the IDs that -// repo.List(t) yield. If f returns an error, the process is aborted and the +// repo.List(t) yields. If f returns an error, the process is aborted and the // first error is returned. -func FilesInParallel(ctx context.Context, repo restic.Lister, t restic.FileType, n uint, f ParallelWorkFunc) error { - wg := &sync.WaitGroup{} - ch := repo.List(ctx, t) - errors := make(chan error, n) +func FilesInParallel(ctx context.Context, repo restic.Lister, t restic.FileType, n int, f ParallelWorkFunc) error { + g, ctx := errgroup.WithContext(ctx) - for i := 0; uint(i) < n; i++ { - wg.Add(1) - go func() { - defer wg.Done() + ch := make(chan string, n) + g.Go(func() error { + defer close(ch) + return repo.List(ctx, t, func(fi restic.FileInfo) error { + select { + case <-ctx.Done(): + case ch <- fi.Name: + } + return nil + }) + }) - for { - select { - case id, ok := <-ch: - if !ok { - return - } - - err := f(ctx, id) - if err != nil { - errors <- err - return - } - case <-ctx.Done(): - return + for i := 0; i < n; i++ { + g.Go(func() error { + for name := range ch { + err := f(ctx, name) + if err != nil { + return err } } - }() + return nil + }) } - wg.Wait() - - select { - case err := <-errors: - return err - default: - break - } - - return nil + return g.Wait() } // ParallelWorkFuncParseID converts a function that takes a restic.ID to a diff --git a/internal/repository/parallel_test.go b/internal/repository/parallel_test.go index 9fa3687bb..7b4c4a583 100644 --- a/internal/repository/parallel_test.go +++ b/internal/repository/parallel_test.go @@ -74,24 +74,25 @@ var lister = testIDs{ "34dd044c228727f2226a0c9c06a3e5ceb5e30e31cb7854f8fa1cde846b395a58", } -func (tests testIDs) List(ctx context.Context, t restic.FileType) <-chan string { - ch := make(chan string) +func (tests testIDs) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { + for i := 0; i < 500; i++ { + for _, id := range tests { + if ctx.Err() != nil { + return ctx.Err() + } - go func() { - defer close(ch) + fi := restic.FileInfo{ + Name: id, + } - for i := 0; i < 500; i++ { - for _, id := range tests { - select { - case ch <- id: - case <-ctx.Done(): - return - } + err := fn(fi) + if err != nil { + return err } } - }() + } - return ch + return nil } func TestFilesInParallel(t *testing.T) { @@ -100,7 +101,7 @@ func TestFilesInParallel(t *testing.T) { return nil } - for n := uint(1); n < 5; n++ { + for n := 1; n < 5; n++ { err := repository.FilesInParallel(context.TODO(), lister, restic.DataFile, n*100, f) rtest.OK(t, err) } @@ -109,7 +110,6 @@ func TestFilesInParallel(t *testing.T) { var errTest = errors.New("test error") func TestFilesInParallelWithError(t *testing.T) { - f := func(ctx context.Context, id string) error { time.Sleep(1 * time.Millisecond) @@ -120,8 +120,10 @@ func TestFilesInParallelWithError(t *testing.T) { return nil } - for n := uint(1); n < 5; n++ { + for n := 1; n < 5; n++ { err := repository.FilesInParallel(context.TODO(), lister, restic.DataFile, n*100, f) - rtest.Equals(t, errTest, err) + if err != errTest { + t.Fatalf("wrong error returned, want %q, got %v", errTest, err) + } } } diff --git a/internal/repository/repack_test.go b/internal/repository/repack_test.go index 2d29a589a..72eb9d3dd 100644 --- a/internal/repository/repack_test.go +++ b/internal/repository/repack_test.go @@ -74,7 +74,7 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2 blobs := restic.NewBlobSet() - for id := range repo.List(context.TODO(), restic.DataFile) { + err := repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error { entries, _, err := repo.ListPack(context.TODO(), id) if err != nil { t.Fatalf("error listing pack %v: %v", id, err) @@ -84,7 +84,7 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2 h := restic.BlobHandle{ID: entry.ID, Type: entry.Type} if blobs.Has(h) { t.Errorf("ignoring duplicate blob %v", h) - continue + return nil } blobs.Insert(h) @@ -93,8 +93,11 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2 } else { list2.Insert(restic.BlobHandle{ID: entry.ID, Type: entry.Type}) } - } + return nil + }) + if err != nil { + t.Fatal(err) } return list1, list2 @@ -102,8 +105,13 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2 func listPacks(t *testing.T, repo restic.Repository) restic.IDSet { list := restic.NewIDSet() - for id := range repo.List(context.TODO(), restic.DataFile) { + err := repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error { list.Insert(id) + return nil + }) + + if err != nil { + t.Fatal(err) } return list @@ -153,15 +161,15 @@ func rebuildIndex(t *testing.T, repo restic.Repository) { t.Fatal(err) } - for id := range repo.List(context.TODO(), restic.IndexFile) { + err = repo.List(context.TODO(), restic.IndexFile, func(id restic.ID, size int64) error { h := restic.Handle{ Type: restic.IndexFile, Name: id.String(), } - err = repo.Backend().Remove(context.TODO(), h) - if err != nil { - t.Fatal(err) - } + return repo.Backend().Remove(context.TODO(), h) + }) + if err != nil { + t.Fatal(err) } _, err = idx.Save(context.TODO(), repo, nil) diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 193ec1ca7..e772cd8ef 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -536,22 +536,15 @@ func (r *Repository) KeyName() string { return r.keyName } -// List returns a channel that yields all IDs of type t in the backend. -func (r *Repository) List(ctx context.Context, t restic.FileType) <-chan restic.ID { - out := make(chan restic.ID) - go func() { - defer close(out) - for strID := range r.be.List(ctx, t) { - if id, err := restic.ParseID(strID); err == nil { - select { - case out <- id: - case <-ctx.Done(): - return - } - } +// List runs fn for all files of type t in the repo. +func (r *Repository) List(ctx context.Context, t restic.FileType, fn func(restic.ID, int64) error) error { + return r.be.List(ctx, t, func(fi restic.FileInfo) error { + id, err := restic.ParseID(fi.Name) + if err != nil { + debug.Log("unable to parse %v as an ID", fi.Name) } - }() - return out + return fn(id, fi.Size) + }) } // ListPack returns the list of blobs saved in the pack id and the length of diff --git a/internal/repository/repository_test.go b/internal/repository/repository_test.go index a90f0959b..60c1190ce 100644 --- a/internal/repository/repository_test.go +++ b/internal/repository/repository_test.go @@ -369,7 +369,7 @@ func TestRepositoryIncrementalIndex(t *testing.T) { packEntries := make(map[restic.ID]map[restic.ID]struct{}) - for id := range repo.List(context.TODO(), restic.IndexFile) { + err := repo.List(context.TODO(), restic.IndexFile, func(id restic.ID, size int64) error { idx, err := repository.LoadIndex(context.TODO(), repo, id) rtest.OK(t, err) @@ -380,6 +380,10 @@ func TestRepositoryIncrementalIndex(t *testing.T) { packEntries[pb.PackID][id] = struct{}{} } + return nil + }) + if err != nil { + t.Fatal(err) } for packID, ids := range packEntries { diff --git a/internal/restic/lock.go b/internal/restic/lock.go index 177b0d707..882f970f3 100644 --- a/internal/restic/lock.go +++ b/internal/restic/lock.go @@ -157,15 +157,14 @@ func (l *Lock) checkForOtherLocks(ctx context.Context) error { } func eachLock(ctx context.Context, repo Repository, f func(ID, *Lock, error) error) error { - for id := range repo.List(ctx, LockFile) { + return repo.List(ctx, LockFile, func(id ID, size int64) error { lock, err := LoadLock(ctx, repo, id) - err = f(id, lock, err) if err != nil { return err } - } - return nil + return f(id, lock, err) + }) } // createLock acquires the lock by creating a file in the repository. diff --git a/internal/restic/lock_test.go b/internal/restic/lock_test.go index a3b4936c9..daadd479f 100644 --- a/internal/restic/lock_test.go +++ b/internal/restic/lock_test.go @@ -227,21 +227,29 @@ func TestLockRefresh(t *testing.T) { rtest.OK(t, err) var lockID *restic.ID - for id := range repo.List(context.TODO(), restic.LockFile) { + err = repo.List(context.TODO(), restic.LockFile, func(id restic.ID, size int64) error { if lockID != nil { t.Error("more than one lock found") } lockID = &id + return nil + }) + if err != nil { + t.Fatal(err) } rtest.OK(t, lock.Refresh(context.TODO())) var lockID2 *restic.ID - for id := range repo.List(context.TODO(), restic.LockFile) { + err = repo.List(context.TODO(), restic.LockFile, func(id restic.ID, size int64) error { if lockID2 != nil { t.Error("more than one lock found") } lockID2 = &id + return nil + }) + if err != nil { + t.Fatal(err) } rtest.Assert(t, !lockID.Equal(*lockID2), diff --git a/internal/restic/repository.go b/internal/restic/repository.go index defd6174f..f1da9770c 100644 --- a/internal/restic/repository.go +++ b/internal/restic/repository.go @@ -26,7 +26,7 @@ type Repository interface { LookupBlobSize(ID, BlobType) (uint, error) - List(context.Context, FileType) <-chan ID + List(context.Context, FileType, func(ID, int64) error) error ListPack(context.Context, ID) ([]Blob, int64, error) Flush(context.Context) error diff --git a/internal/restic/snapshot.go b/internal/restic/snapshot.go index 47b123240..4622bb530 100644 --- a/internal/restic/snapshot.go +++ b/internal/restic/snapshot.go @@ -64,15 +64,21 @@ func LoadSnapshot(ctx context.Context, repo Repository, id ID) (*Snapshot, error // LoadAllSnapshots returns a list of all snapshots in the repo. func LoadAllSnapshots(ctx context.Context, repo Repository) (snapshots []*Snapshot, err error) { - for id := range repo.List(ctx, SnapshotFile) { + err = repo.List(ctx, SnapshotFile, func(id ID, size int64) error { sn, err := LoadSnapshot(ctx, repo, id) if err != nil { - return nil, err + return err } snapshots = append(snapshots, sn) + return nil + }) + + if err != nil { + return nil, err } - return + + return snapshots, nil } func (sn Snapshot) String() string { diff --git a/internal/restic/snapshot_find.go b/internal/restic/snapshot_find.go index 4c239fb1e..b5d0a8276 100644 --- a/internal/restic/snapshot_find.go +++ b/internal/restic/snapshot_find.go @@ -20,26 +20,31 @@ func FindLatestSnapshot(ctx context.Context, repo Repository, targets []string, found bool ) - for snapshotID := range repo.List(ctx, SnapshotFile) { + err := repo.List(ctx, SnapshotFile, func(snapshotID ID, size int64) error { snapshot, err := LoadSnapshot(ctx, repo, snapshotID) if err != nil { - return ID{}, errors.Errorf("Error listing snapshot: %v", err) + return errors.Errorf("Error loading snapshot %v: %v", snapshotID.Str(), err) } if snapshot.Time.Before(latest) || (hostname != "" && hostname != snapshot.Hostname) { - continue + return nil } if !snapshot.HasTagList(tagLists) { - continue + return nil } if !snapshot.HasPaths(targets) { - continue + return nil } latest = snapshot.Time latestID = snapshotID found = true + return nil + }) + + if err != nil { + return ID{}, err } if !found { @@ -64,20 +69,27 @@ func FindSnapshot(repo Repository, s string) (ID, error) { // FindFilteredSnapshots yields Snapshots filtered from the list of all // snapshots. -func FindFilteredSnapshots(ctx context.Context, repo Repository, host string, tags []TagList, paths []string) Snapshots { +func FindFilteredSnapshots(ctx context.Context, repo Repository, host string, tags []TagList, paths []string) (Snapshots, error) { results := make(Snapshots, 0, 20) - for id := range repo.List(ctx, SnapshotFile) { + err := repo.List(ctx, SnapshotFile, func(id ID, size int64) error { sn, err := LoadSnapshot(ctx, repo, id) if err != nil { fmt.Fprintf(os.Stderr, "could not load snapshot %v: %v\n", id.Str(), err) - continue + return nil } + if (host != "" && host != sn.Hostname) || !sn.HasTagList(tags) || !sn.HasPaths(paths) { - continue + return nil } results = append(results, sn) + return nil + }) + + if err != nil { + return nil, err } - return results + + return results, nil }