diff --git a/src/cmds/restic/cmd_backup.go b/src/cmds/restic/cmd_backup.go index 077bba85a..53204f2d2 100644 --- a/src/cmds/restic/cmd_backup.go +++ b/src/cmds/restic/cmd_backup.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "fmt" "io" "os" @@ -263,7 +264,7 @@ func readBackupFromStdin(opts BackupOptions, gopts GlobalOptions, args []string) return err } - err = repo.LoadIndex() + err = repo.LoadIndex(context.TODO()) if err != nil { return err } @@ -274,7 +275,7 @@ func readBackupFromStdin(opts BackupOptions, gopts GlobalOptions, args []string) Hostname: opts.Hostname, } - _, id, err := r.Archive(opts.StdinFilename, os.Stdin, newArchiveStdinProgress(gopts)) + _, id, err := r.Archive(context.TODO(), opts.StdinFilename, os.Stdin, newArchiveStdinProgress(gopts)) if err != nil { return err } @@ -372,7 +373,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error { return err } - err = repo.LoadIndex() + err = repo.LoadIndex(context.TODO()) if err != nil { return err } @@ -391,7 +392,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error { // Find last snapshot to set it as parent, if not already set if !opts.Force && parentSnapshotID == nil { - id, err := restic.FindLatestSnapshot(repo, target, opts.Tags, opts.Hostname) + id, err := restic.FindLatestSnapshot(context.TODO(), repo, target, opts.Tags, opts.Hostname) if err == nil { parentSnapshotID = &id } else if err != restic.ErrNoSnapshotFound { @@ -489,7 +490,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error { Warnf("%s\rwarning for %s: %v\n", ClearLine(), dir, err) } - _, id, err := arch.Snapshot(newArchiveProgress(gopts, stat), target, opts.Tags, opts.Hostname, parentSnapshotID) + _, id, err := arch.Snapshot(context.TODO(), newArchiveProgress(gopts, stat), target, opts.Tags, opts.Hostname, parentSnapshotID) if err != nil { return err } diff --git a/src/cmds/restic/cmd_cat.go b/src/cmds/restic/cmd_cat.go index ee5798d21..565cccbbe 100644 --- a/src/cmds/restic/cmd_cat.go +++ b/src/cmds/restic/cmd_cat.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "os" @@ -73,7 +74,7 @@ func runCat(gopts GlobalOptions, args []string) error { fmt.Println(string(buf)) return nil case "index": - buf, err := repo.LoadAndDecrypt(restic.IndexFile, id) + buf, err := repo.LoadAndDecrypt(context.TODO(), restic.IndexFile, id) if err != nil { return err } @@ -83,7 +84,7 @@ func runCat(gopts GlobalOptions, args []string) error { case "snapshot": sn := &restic.Snapshot{} - err = repo.LoadJSONUnpacked(restic.SnapshotFile, id, sn) + err = repo.LoadJSONUnpacked(context.TODO(), restic.SnapshotFile, id, sn) if err != nil { return err } @@ -98,7 +99,7 @@ func runCat(gopts GlobalOptions, args []string) error { return nil case "key": h := restic.Handle{Type: restic.KeyFile, Name: id.String()} - buf, err := backend.LoadAll(repo.Backend(), h) + buf, err := backend.LoadAll(context.TODO(), repo.Backend(), h) if err != nil { return err } @@ -125,7 +126,7 @@ func runCat(gopts GlobalOptions, args []string) error { fmt.Println(string(buf)) return nil case "lock": - lock, err := restic.LoadLock(repo, id) + lock, err := restic.LoadLock(context.TODO(), repo, id) if err != nil { return err } @@ -141,7 +142,7 @@ func runCat(gopts GlobalOptions, args []string) error { } // load index, handle all the other types - err = repo.LoadIndex() + err = repo.LoadIndex(context.TODO()) if err != nil { return err } @@ -149,7 +150,7 @@ func runCat(gopts GlobalOptions, args []string) error { switch tpe { case "pack": h := restic.Handle{Type: restic.DataFile, Name: id.String()} - buf, err := backend.LoadAll(repo.Backend(), h) + buf, err := backend.LoadAll(context.TODO(), repo.Backend(), h) if err != nil { return err } @@ -171,7 +172,7 @@ func runCat(gopts GlobalOptions, args []string) error { blob := list[0] buf := make([]byte, blob.Length) - n, err := repo.LoadBlob(t, id, buf) + n, err := repo.LoadBlob(context.TODO(), t, id, buf) if err != nil { return err } diff --git a/src/cmds/restic/cmd_check.go b/src/cmds/restic/cmd_check.go index 2f0064f1a..985542d6e 100644 --- a/src/cmds/restic/cmd_check.go +++ b/src/cmds/restic/cmd_check.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "time" @@ -92,7 +93,7 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error { chkr := checker.New(repo) Verbosef("Load indexes\n") - hints, errs := chkr.LoadIndex() + hints, errs := chkr.LoadIndex(context.TODO()) dupFound := false for _, hint := range hints { @@ -113,14 +114,11 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error { return errors.Fatal("LoadIndex returned errors") } - done := make(chan struct{}) - defer close(done) - errorsFound := false errChan := make(chan error) Verbosef("Check all packs\n") - go chkr.Packs(errChan, done) + go chkr.Packs(context.TODO(), errChan) for err := range errChan { errorsFound = true @@ -129,7 +127,7 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error { Verbosef("Check snapshots, trees and blobs\n") errChan = make(chan error) - go chkr.Structure(errChan, done) + go chkr.Structure(context.TODO(), errChan) for err := range errChan { errorsFound = true @@ -156,7 +154,7 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error { p := newReadProgress(gopts, restic.Stat{Blobs: chkr.CountPacks()}) errChan := make(chan error) - go chkr.ReadData(p, errChan, done) + go chkr.ReadData(context.TODO(), p, errChan) for err := range errChan { errorsFound = true diff --git a/src/cmds/restic/cmd_dump.go b/src/cmds/restic/cmd_dump.go index 350e4d7dd..86fd8e93e 100644 --- a/src/cmds/restic/cmd_dump.go +++ b/src/cmds/restic/cmd_dump.go @@ -1,8 +1,9 @@ -// +build debug +// xbuild debug package main import ( + "context" "encoding/json" "fmt" "io" @@ -44,11 +45,8 @@ func prettyPrintJSON(wr io.Writer, item interface{}) error { } func debugPrintSnapshots(repo *repository.Repository, wr io.Writer) error { - done := make(chan struct{}) - defer close(done) - - for id := range repo.List(restic.SnapshotFile, done) { - snapshot, err := restic.LoadSnapshot(repo, id) + for id := range repo.List(context.TODO(), restic.SnapshotFile) { + snapshot, err := restic.LoadSnapshot(context.TODO(), repo, id) if err != nil { fmt.Fprintf(os.Stderr, "LoadSnapshot(%v): %v", id.Str(), err) continue @@ -83,15 +81,12 @@ type Blob struct { } func printPacks(repo *repository.Repository, wr io.Writer) error { - done := make(chan struct{}) - defer close(done) - - f := func(job worker.Job, done <-chan struct{}) (interface{}, error) { + f := func(ctx context.Context, job worker.Job) (interface{}, error) { name := job.Data.(string) h := restic.Handle{Type: restic.DataFile, Name: name} - blobInfo, err := repo.Backend().Stat(h) + blobInfo, err := repo.Backend().Stat(ctx, h) if err != nil { return nil, err } @@ -106,10 +101,10 @@ func printPacks(repo *repository.Repository, wr io.Writer) error { jobCh := make(chan worker.Job) resCh := make(chan worker.Job) - wp := worker.New(dumpPackWorkers, f, jobCh, resCh) + wp := worker.New(context.TODO(), dumpPackWorkers, f, jobCh, resCh) go func() { - for name := range repo.Backend().List(restic.DataFile, done) { + for name := range repo.Backend().List(context.TODO(), restic.DataFile) { jobCh <- worker.Job{Data: name} } close(jobCh) @@ -146,13 +141,10 @@ func printPacks(repo *repository.Repository, wr io.Writer) error { } func dumpIndexes(repo restic.Repository) error { - done := make(chan struct{}) - defer close(done) - - for id := range repo.List(restic.IndexFile, done) { + for id := range repo.List(context.TODO(), restic.IndexFile) { fmt.Printf("index_id: %v\n", id) - idx, err := repository.LoadIndex(repo, id) + idx, err := repository.LoadIndex(context.TODO(), repo, id) if err != nil { return err } @@ -184,7 +176,7 @@ func runDump(gopts GlobalOptions, args []string) error { } } - err = repo.LoadIndex() + err = repo.LoadIndex(context.TODO()) if err != nil { return err } diff --git a/src/cmds/restic/cmd_find.go b/src/cmds/restic/cmd_find.go index 70d1cc127..36932491c 100644 --- a/src/cmds/restic/cmd_find.go +++ b/src/cmds/restic/cmd_find.go @@ -187,7 +187,7 @@ func (f *Finder) findInTree(treeID restic.ID, prefix string) error { debug.Log("%v checking tree %v\n", prefix, treeID.Str()) - tree, err := f.repo.LoadTree(treeID) + tree, err := f.repo.LoadTree(context.TODO(), treeID) if err != nil { return err } @@ -283,7 +283,7 @@ func runFind(opts FindOptions, gopts GlobalOptions, args []string) error { } } - if err = repo.LoadIndex(); err != nil { + if err = repo.LoadIndex(context.TODO()); err != nil { return err } diff --git a/src/cmds/restic/cmd_forget.go b/src/cmds/restic/cmd_forget.go index 4885cf617..f4c03b4f1 100644 --- a/src/cmds/restic/cmd_forget.go +++ b/src/cmds/restic/cmd_forget.go @@ -97,7 +97,7 @@ func runForget(opts ForgetOptions, gopts GlobalOptions, args []string) error { // When explicit snapshots args are given, remove them immediately. if !opts.DryRun { h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()} - if err = repo.Backend().Remove(h); err != nil { + if err = repo.Backend().Remove(context.TODO(), h); err != nil { return err } Verbosef("removed snapshot %v\n", sn.ID().Str()) @@ -167,7 +167,7 @@ func runForget(opts ForgetOptions, gopts GlobalOptions, args []string) error { if !opts.DryRun { for _, sn := range remove { h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()} - err = repo.Backend().Remove(h) + err = repo.Backend().Remove(context.TODO(), h) if err != nil { return err } diff --git a/src/cmds/restic/cmd_init.go b/src/cmds/restic/cmd_init.go index ffa9cf272..69cacc6ca 100644 --- a/src/cmds/restic/cmd_init.go +++ b/src/cmds/restic/cmd_init.go @@ -1,6 +1,7 @@ package main import ( + "context" "restic/errors" "restic/repository" @@ -43,7 +44,7 @@ func runInit(gopts GlobalOptions, args []string) error { s := repository.New(be) - err = s.Init(gopts.password) + err = s.Init(context.TODO(), gopts.password) if err != nil { return errors.Fatalf("create key in backend at %s failed: %v\n", gopts.Repo, err) } diff --git a/src/cmds/restic/cmd_key.go b/src/cmds/restic/cmd_key.go index 052dd5b8a..864c520e6 100644 --- a/src/cmds/restic/cmd_key.go +++ b/src/cmds/restic/cmd_key.go @@ -30,8 +30,8 @@ func listKeys(ctx context.Context, s *repository.Repository) error { tab.Header = fmt.Sprintf(" %-10s %-10s %-10s %s", "ID", "User", "Host", "Created") tab.RowFormat = "%s%-10s %-10s %-10s %s" - for id := range s.List(restic.KeyFile, ctx.Done()) { - k, err := repository.LoadKey(s, id.String()) + for id := range s.List(ctx, restic.KeyFile) { + k, err := repository.LoadKey(ctx, s, id.String()) if err != nil { Warnf("LoadKey() failed: %v\n", err) continue @@ -69,7 +69,7 @@ func addKey(gopts GlobalOptions, repo *repository.Repository) error { return err } - id, err := repository.AddKey(repo, pw, repo.Key()) + id, err := repository.AddKey(context.TODO(), repo, pw, repo.Key()) if err != nil { return errors.Fatalf("creating new key failed: %v\n", err) } @@ -85,7 +85,7 @@ func deleteKey(repo *repository.Repository, name string) error { } h := restic.Handle{Type: restic.KeyFile, Name: name} - err := repo.Backend().Remove(h) + err := repo.Backend().Remove(context.TODO(), h) if err != nil { return err } @@ -100,13 +100,13 @@ func changePassword(gopts GlobalOptions, repo *repository.Repository) error { return err } - id, err := repository.AddKey(repo, pw, repo.Key()) + id, err := repository.AddKey(context.TODO(), repo, pw, repo.Key()) if err != nil { return errors.Fatalf("creating new key failed: %v\n", err) } h := restic.Handle{Type: restic.KeyFile, Name: repo.KeyName()} - err = repo.Backend().Remove(h) + err = repo.Backend().Remove(context.TODO(), h) if err != nil { return err } diff --git a/src/cmds/restic/cmd_list.go b/src/cmds/restic/cmd_list.go index 105f70a5e..dd4ee8053 100644 --- a/src/cmds/restic/cmd_list.go +++ b/src/cmds/restic/cmd_list.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "restic" "restic/errors" @@ -55,7 +56,7 @@ func runList(opts GlobalOptions, args []string) error { case "locks": t = restic.LockFile case "blobs": - idx, err := index.Load(repo, nil) + idx, err := index.Load(context.TODO(), repo, nil) if err != nil { return err } @@ -71,7 +72,7 @@ func runList(opts GlobalOptions, args []string) error { return errors.Fatal("invalid type") } - for id := range repo.List(t, nil) { + for id := range repo.List(context.TODO(), t) { Printf("%s\n", id) } diff --git a/src/cmds/restic/cmd_ls.go b/src/cmds/restic/cmd_ls.go index 7d613b45c..be38cf914 100644 --- a/src/cmds/restic/cmd_ls.go +++ b/src/cmds/restic/cmd_ls.go @@ -46,7 +46,7 @@ func init() { } func printTree(repo *repository.Repository, id *restic.ID, prefix string) error { - tree, err := repo.LoadTree(*id) + tree, err := repo.LoadTree(context.TODO(), *id) if err != nil { return err } @@ -74,7 +74,7 @@ func runLs(opts LsOptions, gopts GlobalOptions, args []string) error { return err } - if err = repo.LoadIndex(); err != nil { + if err = repo.LoadIndex(context.TODO()); err != nil { return err } diff --git a/src/cmds/restic/cmd_mount.go b/src/cmds/restic/cmd_mount.go index 20ce5dec3..86e27b351 100644 --- a/src/cmds/restic/cmd_mount.go +++ b/src/cmds/restic/cmd_mount.go @@ -4,6 +4,7 @@ package main import ( + "context" "os" "github.com/spf13/cobra" @@ -64,7 +65,7 @@ func mount(opts MountOptions, gopts GlobalOptions, mountpoint string) error { return err } - err = repo.LoadIndex() + err = repo.LoadIndex(context.TODO()) if err != nil { return err } diff --git a/src/cmds/restic/cmd_prune.go b/src/cmds/restic/cmd_prune.go index 2fb5a4c12..b46419a22 100644 --- a/src/cmds/restic/cmd_prune.go +++ b/src/cmds/restic/cmd_prune.go @@ -1,7 +1,6 @@ package main import ( - "context" "fmt" "restic" "restic/debug" @@ -76,14 +75,13 @@ func runPrune(gopts GlobalOptions) error { } func pruneRepository(gopts GlobalOptions, repo restic.Repository) error { - err := repo.LoadIndex() + ctx := gopts.ctx + + err := repo.LoadIndex(ctx) if err != nil { return err } - ctx, cancel := context.WithCancel(gopts.ctx) - defer cancel() - var stats struct { blobs int packs int @@ -92,14 +90,14 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error { } Verbosef("counting files in repo\n") - for range repo.List(restic.DataFile, ctx.Done()) { + for range repo.List(ctx, restic.DataFile) { stats.packs++ } Verbosef("building new index for repo\n") bar := newProgressMax(!gopts.Quiet, uint64(stats.packs), "packs") - idx, err := index.New(repo, bar) + idx, err := index.New(ctx, repo, bar) if err != nil { return err } @@ -135,7 +133,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error { Verbosef("load all snapshots\n") // find referenced blobs - snapshots, err := restic.LoadAllSnapshots(repo) + snapshots, err := restic.LoadAllSnapshots(ctx, repo) if err != nil { return err } @@ -152,7 +150,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error { for _, sn := range snapshots { debug.Log("process snapshot %v", sn.ID().Str()) - err = restic.FindUsedBlobs(repo, *sn.Tree, usedBlobs, seenBlobs) + err = restic.FindUsedBlobs(ctx, repo, *sn.Tree, usedBlobs, seenBlobs) if err != nil { return err } @@ -217,7 +215,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error { if len(rewritePacks) != 0 { bar = newProgressMax(!gopts.Quiet, uint64(len(rewritePacks)), "packs rewritten") bar.Start() - err = repository.Repack(repo, rewritePacks, usedBlobs, bar) + err = repository.Repack(ctx, repo, rewritePacks, usedBlobs, bar) if err != nil { return err } @@ -229,7 +227,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error { bar.Start() for packID := range removePacks { h := restic.Handle{Type: restic.DataFile, Name: packID.String()} - err = repo.Backend().Remove(h) + err = repo.Backend().Remove(ctx, h) if err != nil { Warnf("unable to remove file %v from the repository\n", packID.Str()) } diff --git a/src/cmds/restic/cmd_rebuild_index.go b/src/cmds/restic/cmd_rebuild_index.go index 72fa5d574..6a60ea900 100644 --- a/src/cmds/restic/cmd_rebuild_index.go +++ b/src/cmds/restic/cmd_rebuild_index.go @@ -45,12 +45,12 @@ func rebuildIndex(ctx context.Context, repo restic.Repository) error { Verbosef("counting files in repo\n") var packs uint64 - for range repo.List(restic.DataFile, ctx.Done()) { + for range repo.List(ctx, restic.DataFile) { packs++ } bar := newProgressMax(!globalOptions.Quiet, packs, "packs") - idx, err := index.New(repo, bar) + idx, err := index.New(ctx, repo, bar) if err != nil { return err } @@ -58,11 +58,11 @@ func rebuildIndex(ctx context.Context, repo restic.Repository) error { Verbosef("finding old index files\n") var supersedes restic.IDs - for id := range repo.List(restic.IndexFile, ctx.Done()) { + for id := range repo.List(ctx, restic.IndexFile) { supersedes = append(supersedes, id) } - id, err := idx.Save(repo, supersedes) + id, err := idx.Save(ctx, repo, supersedes) if err != nil { return err } @@ -72,7 +72,7 @@ func rebuildIndex(ctx context.Context, repo restic.Repository) error { Verbosef("remove %d old index files\n", len(supersedes)) for _, id := range supersedes { - if err := repo.Backend().Remove(restic.Handle{ + if err := repo.Backend().Remove(ctx, restic.Handle{ Type: restic.IndexFile, Name: id.String(), }); err != nil { diff --git a/src/cmds/restic/cmd_restore.go b/src/cmds/restic/cmd_restore.go index b19b58075..9dec03851 100644 --- a/src/cmds/restic/cmd_restore.go +++ b/src/cmds/restic/cmd_restore.go @@ -50,6 +50,8 @@ func init() { } func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error { + ctx := gopts.ctx + if len(args) != 1 { return errors.Fatal("no snapshot ID specified") } @@ -79,7 +81,7 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error { } } - err = repo.LoadIndex() + err = repo.LoadIndex(ctx) if err != nil { return err } @@ -87,7 +89,7 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error { var id restic.ID if snapshotIDString == "latest" { - id, err = restic.FindLatestSnapshot(repo, opts.Paths, opts.Tags, opts.Host) + id, err = restic.FindLatestSnapshot(ctx, repo, opts.Paths, opts.Tags, opts.Host) if err != nil { Exitf(1, "latest snapshot for criteria not found: %v Paths:%v Host:%v", err, opts.Paths, opts.Host) } @@ -136,7 +138,7 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error { Verbosef("restoring %s to %s\n", res.Snapshot(), opts.Target) - err = res.RestoreTo(opts.Target) + err = res.RestoreTo(ctx, opts.Target) if totalErrors > 0 { Printf("There were %d errors\n", totalErrors) } diff --git a/src/cmds/restic/cmd_tag.go b/src/cmds/restic/cmd_tag.go index 17ed81919..32c2ba583 100644 --- a/src/cmds/restic/cmd_tag.go +++ b/src/cmds/restic/cmd_tag.go @@ -76,7 +76,7 @@ func changeTags(repo *repository.Repository, sn *restic.Snapshot, setTags, addTa } // Save the new snapshot. - id, err := repo.SaveJSONUnpacked(restic.SnapshotFile, sn) + id, err := repo.SaveJSONUnpacked(context.TODO(), restic.SnapshotFile, sn) if err != nil { return false, err } @@ -89,7 +89,7 @@ func changeTags(repo *repository.Repository, sn *restic.Snapshot, setTags, addTa // Remove the old snapshot. h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()} - if err = repo.Backend().Remove(h); err != nil { + if err = repo.Backend().Remove(context.TODO(), h); err != nil { return false, err } diff --git a/src/cmds/restic/cmd_unlock.go b/src/cmds/restic/cmd_unlock.go index 6601909cb..33735d26a 100644 --- a/src/cmds/restic/cmd_unlock.go +++ b/src/cmds/restic/cmd_unlock.go @@ -1,6 +1,7 @@ package main import ( + "context" "restic" "github.com/spf13/cobra" @@ -41,7 +42,7 @@ func runUnlock(opts UnlockOptions, gopts GlobalOptions) error { fn = restic.RemoveAllLocks } - err = fn(repo) + err = fn(context.TODO(), repo) if err != nil { return err } diff --git a/src/cmds/restic/find.go b/src/cmds/restic/find.go index fa9d71694..9a0e9bb56 100644 --- a/src/cmds/restic/find.go +++ b/src/cmds/restic/find.go @@ -22,7 +22,7 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos // Process all snapshot IDs given as arguments. for _, s := range snapshotIDs { if s == "latest" { - id, err = restic.FindLatestSnapshot(repo, paths, tags, host) + id, err = restic.FindLatestSnapshot(ctx, repo, paths, tags, host) if err != nil { Warnf("Ignoring %q, no snapshot matched given filter (Paths:%v Tags:%v Host:%v)\n", s, paths, tags, host) usedFilter = true @@ -44,7 +44,7 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos } for _, id := range ids.Uniq() { - sn, err := restic.LoadSnapshot(repo, id) + sn, err := restic.LoadSnapshot(ctx, repo, id) if err != nil { Warnf("Ignoring %q, could not load snapshot: %v\n", id, err) continue @@ -58,8 +58,8 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos return } - for id := range repo.List(restic.SnapshotFile, ctx.Done()) { - sn, err := restic.LoadSnapshot(repo, id) + for id := range repo.List(ctx, restic.SnapshotFile) { + sn, err := restic.LoadSnapshot(ctx, repo, id) if err != nil { Warnf("Ignoring %q, could not load snapshot: %v\n", id, err) continue diff --git a/src/cmds/restic/global.go b/src/cmds/restic/global.go index ae3106285..fe266aa56 100644 --- a/src/cmds/restic/global.go +++ b/src/cmds/restic/global.go @@ -310,7 +310,7 @@ func OpenRepository(opts GlobalOptions) (*repository.Repository, error) { } } - err = s.SearchKey(opts.password, maxKeys) + err = s.SearchKey(context.TODO(), opts.password, maxKeys) if err != nil { return nil, errors.Fatalf("unable to open repo: %v", err) } @@ -440,7 +440,7 @@ func open(s string, opts options.Options) (restic.Backend, error) { } // check if config is there - fi, err := be.Stat(restic.Handle{Type: restic.ConfigFile}) + fi, err := be.Stat(context.TODO(), restic.Handle{Type: restic.ConfigFile}) if err != nil { return nil, errors.Fatalf("unable to open config file: %v\nIs there a repository at the following location?\n%v", err, s) } diff --git a/src/cmds/restic/lock.go b/src/cmds/restic/lock.go index 81bdafbc2..647233b9a 100644 --- a/src/cmds/restic/lock.go +++ b/src/cmds/restic/lock.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "sync" @@ -32,7 +33,7 @@ func lockRepository(repo *repository.Repository, exclusive bool) (*restic.Lock, lockFn = restic.NewExclusiveLock } - lock, err := lockFn(repo) + lock, err := lockFn(context.TODO(), repo) if err != nil { return nil, err } @@ -75,7 +76,7 @@ func refreshLocks(wg *sync.WaitGroup, done <-chan struct{}) { debug.Log("refreshing locks") globalLocks.Lock() for _, lock := range globalLocks.locks { - err := lock.Refresh() + err := lock.Refresh(context.TODO()) if err != nil { fmt.Fprintf(os.Stderr, "unable to refresh lock: %v\n", err) } diff --git a/src/restic/archiver/archive_reader.go b/src/restic/archiver/archive_reader.go index 28f228828..d67130744 100644 --- a/src/restic/archiver/archive_reader.go +++ b/src/restic/archiver/archive_reader.go @@ -1,6 +1,7 @@ package archiver import ( + "context" "io" "restic" "restic/debug" @@ -20,7 +21,7 @@ type Reader struct { } // Archive reads data from the reader and saves it to the repo. -func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic.Snapshot, restic.ID, error) { +func (r *Reader) Archive(ctx context.Context, name string, rd io.Reader, p *restic.Progress) (*restic.Snapshot, restic.ID, error) { if name == "" { return nil, restic.ID{}, errors.New("no filename given") } @@ -53,7 +54,7 @@ func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic id := restic.Hash(chunk.Data) if !repo.Index().Has(id, restic.DataBlob) { - _, err := repo.SaveBlob(restic.DataBlob, chunk.Data, id) + _, err := repo.SaveBlob(ctx, restic.DataBlob, chunk.Data, id) if err != nil { return nil, restic.ID{}, err } @@ -87,14 +88,14 @@ func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic }, } - treeID, err := repo.SaveTree(tree) + treeID, err := repo.SaveTree(ctx, tree) if err != nil { return nil, restic.ID{}, err } sn.Tree = &treeID debug.Log("tree saved as %v", treeID.Str()) - id, err := repo.SaveJSONUnpacked(restic.SnapshotFile, sn) + id, err := repo.SaveJSONUnpacked(ctx, restic.SnapshotFile, sn) if err != nil { return nil, restic.ID{}, err } @@ -106,7 +107,7 @@ func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic return nil, restic.ID{}, err } - err = repo.SaveIndex() + err = repo.SaveIndex(ctx) if err != nil { return nil, restic.ID{}, err } diff --git a/src/restic/archiver/archive_reader_test.go b/src/restic/archiver/archive_reader_test.go index a8ab18668..03c644846 100644 --- a/src/restic/archiver/archive_reader_test.go +++ b/src/restic/archiver/archive_reader_test.go @@ -2,6 +2,7 @@ package archiver import ( "bytes" + "context" "errors" "io" "math/rand" @@ -12,7 +13,7 @@ import ( ) func loadBlob(t *testing.T, repo restic.Repository, id restic.ID, buf []byte) int { - n, err := repo.LoadBlob(restic.DataBlob, id, buf) + n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf) if err != nil { t.Fatalf("LoadBlob(%v) returned error %v", id, err) } @@ -21,7 +22,7 @@ func loadBlob(t *testing.T, repo restic.Repository, id restic.ID, buf []byte) in } func checkSavedFile(t *testing.T, repo restic.Repository, treeID restic.ID, name string, rd io.Reader) { - tree, err := repo.LoadTree(treeID) + tree, err := repo.LoadTree(context.TODO(), treeID) if err != nil { t.Fatalf("LoadTree() returned error %v", err) } @@ -85,7 +86,7 @@ func TestArchiveReader(t *testing.T) { Tags: []string{"test"}, } - sn, id, err := r.Archive("fakefile", f, nil) + sn, id, err := r.Archive(context.TODO(), "fakefile", f, nil) if err != nil { t.Fatalf("ArchiveReader() returned error %v", err) } @@ -111,7 +112,7 @@ func TestArchiveReaderNull(t *testing.T) { Tags: []string{"test"}, } - sn, id, err := r.Archive("fakefile", bytes.NewReader(nil), nil) + sn, id, err := r.Archive(context.TODO(), "fakefile", bytes.NewReader(nil), nil) if err != nil { t.Fatalf("ArchiveReader() returned error %v", err) } @@ -132,11 +133,8 @@ func (e errReader) Read([]byte) (int, error) { } func countSnapshots(t testing.TB, repo restic.Repository) int { - done := make(chan struct{}) - defer close(done) - snapshots := 0 - for range repo.List(restic.SnapshotFile, done) { + for range repo.List(context.TODO(), restic.SnapshotFile) { snapshots++ } return snapshots @@ -152,7 +150,7 @@ func TestArchiveReaderError(t *testing.T) { Tags: []string{"test"}, } - sn, id, err := r.Archive("fakefile", errReader("error returned by reading stdin"), nil) + sn, id, err := r.Archive(context.TODO(), "fakefile", errReader("error returned by reading stdin"), nil) if err == nil { t.Errorf("expected error not returned") } @@ -195,7 +193,7 @@ func BenchmarkArchiveReader(t *testing.B) { t.ResetTimer() for i := 0; i < t.N; i++ { - _, _, err := r.Archive("fakefile", bytes.NewReader(buf), nil) + _, _, err := r.Archive(context.TODO(), "fakefile", bytes.NewReader(buf), nil) if err != nil { t.Fatal(err) } diff --git a/src/restic/archiver/archiver.go b/src/restic/archiver/archiver.go index e90cfd8f8..fd471a273 100644 --- a/src/restic/archiver/archiver.go +++ b/src/restic/archiver/archiver.go @@ -1,6 +1,7 @@ package archiver import ( + "context" "encoding/json" "fmt" "io" @@ -92,7 +93,7 @@ func (arch *Archiver) isKnownBlob(id restic.ID, t restic.BlobType) bool { } // Save stores a blob read from rd in the repository. -func (arch *Archiver) Save(t restic.BlobType, data []byte, id restic.ID) error { +func (arch *Archiver) Save(ctx context.Context, t restic.BlobType, data []byte, id restic.ID) error { debug.Log("Save(%v, %v)\n", t, id.Str()) if arch.isKnownBlob(id, restic.DataBlob) { @@ -100,7 +101,7 @@ func (arch *Archiver) Save(t restic.BlobType, data []byte, id restic.ID) error { return nil } - _, err := arch.repo.SaveBlob(t, data, id) + _, err := arch.repo.SaveBlob(ctx, t, data, id) if err != nil { debug.Log("Save(%v, %v): error %v\n", t, id.Str(), err) return err @@ -111,7 +112,7 @@ func (arch *Archiver) Save(t restic.BlobType, data []byte, id restic.ID) error { } // SaveTreeJSON stores a tree in the repository. -func (arch *Archiver) SaveTreeJSON(tree *restic.Tree) (restic.ID, error) { +func (arch *Archiver) SaveTreeJSON(ctx context.Context, tree *restic.Tree) (restic.ID, error) { data, err := json.Marshal(tree) if err != nil { return restic.ID{}, errors.Wrap(err, "Marshal") @@ -124,7 +125,7 @@ func (arch *Archiver) SaveTreeJSON(tree *restic.Tree) (restic.ID, error) { return id, nil } - return arch.repo.SaveBlob(restic.TreeBlob, data, id) + return arch.repo.SaveBlob(ctx, restic.TreeBlob, data, id) } func (arch *Archiver) reloadFileIfChanged(node *restic.Node, file fs.File) (*restic.Node, error) { @@ -153,11 +154,11 @@ type saveResult struct { bytes uint64 } -func (arch *Archiver) saveChunk(chunk chunker.Chunk, p *restic.Progress, token struct{}, file fs.File, resultChannel chan<- saveResult) { +func (arch *Archiver) saveChunk(ctx context.Context, chunk chunker.Chunk, p *restic.Progress, token struct{}, file fs.File, resultChannel chan<- saveResult) { defer freeBuf(chunk.Data) id := restic.Hash(chunk.Data) - err := arch.Save(restic.DataBlob, chunk.Data, id) + err := arch.Save(ctx, restic.DataBlob, chunk.Data, id) // TODO handle error if err != nil { panic(err) @@ -206,7 +207,7 @@ func updateNodeContent(node *restic.Node, results []saveResult) error { // SaveFile stores the content of the file on the backend as a Blob by calling // Save for each chunk. -func (arch *Archiver) SaveFile(p *restic.Progress, node *restic.Node) (*restic.Node, error) { +func (arch *Archiver) SaveFile(ctx context.Context, p *restic.Progress, node *restic.Node) (*restic.Node, error) { file, err := fs.Open(node.Path) defer file.Close() if err != nil { @@ -234,7 +235,7 @@ func (arch *Archiver) SaveFile(p *restic.Progress, node *restic.Node) (*restic.N } resCh := make(chan saveResult, 1) - go arch.saveChunk(chunk, p, <-arch.blobToken, file, resCh) + go arch.saveChunk(ctx, chunk, p, <-arch.blobToken, file, resCh) resultChannels = append(resultChannels, resCh) } @@ -247,7 +248,7 @@ func (arch *Archiver) SaveFile(p *restic.Progress, node *restic.Node) (*restic.N return node, err } -func (arch *Archiver) fileWorker(wg *sync.WaitGroup, p *restic.Progress, done <-chan struct{}, entCh <-chan pipe.Entry) { +func (arch *Archiver) fileWorker(ctx context.Context, wg *sync.WaitGroup, p *restic.Progress, entCh <-chan pipe.Entry) { defer func() { debug.Log("done") wg.Done() @@ -305,7 +306,7 @@ func (arch *Archiver) fileWorker(wg *sync.WaitGroup, p *restic.Progress, done <- // otherwise read file normally if node.Type == "file" && len(node.Content) == 0 { debug.Log(" read and save %v", e.Path()) - node, err = arch.SaveFile(p, node) + node, err = arch.SaveFile(ctx, p, node) if err != nil { fmt.Fprintf(os.Stderr, "error for %v: %v\n", node.Path, err) arch.Warn(e.Path(), nil, err) @@ -322,14 +323,14 @@ func (arch *Archiver) fileWorker(wg *sync.WaitGroup, p *restic.Progress, done <- debug.Log(" processed %v, %d blobs", e.Path(), len(node.Content)) e.Result() <- node p.Report(restic.Stat{Files: 1}) - case <-done: + case <-ctx.Done(): // pipeline was cancelled return } } } -func (arch *Archiver) dirWorker(wg *sync.WaitGroup, p *restic.Progress, done <-chan struct{}, dirCh <-chan pipe.Dir) { +func (arch *Archiver) dirWorker(ctx context.Context, wg *sync.WaitGroup, p *restic.Progress, dirCh <-chan pipe.Dir) { debug.Log("start") defer func() { debug.Log("done") @@ -398,7 +399,7 @@ func (arch *Archiver) dirWorker(wg *sync.WaitGroup, p *restic.Progress, done <-c node.Error = err.Error() } - id, err := arch.SaveTreeJSON(tree) + id, err := arch.SaveTreeJSON(ctx, tree) if err != nil { panic(err) } @@ -415,7 +416,7 @@ func (arch *Archiver) dirWorker(wg *sync.WaitGroup, p *restic.Progress, done <-c if dir.Path() != "" { p.Report(restic.Stat{Dirs: 1}) } - case <-done: + case <-ctx.Done(): // pipeline was cancelled return } @@ -427,7 +428,7 @@ type archivePipe struct { New <-chan pipe.Job } -func copyJobs(done <-chan struct{}, in <-chan pipe.Job, out chan<- pipe.Job) { +func copyJobs(ctx context.Context, in <-chan pipe.Job, out chan<- pipe.Job) { var ( // disable sending on the outCh until we received a job outCh chan<- pipe.Job @@ -439,7 +440,7 @@ func copyJobs(done <-chan struct{}, in <-chan pipe.Job, out chan<- pipe.Job) { for { select { - case <-done: + case <-ctx.Done(): return case job, ok = <-inCh: if !ok { @@ -462,7 +463,7 @@ type archiveJob struct { new pipe.Job } -func (a *archivePipe) compare(done <-chan struct{}, out chan<- pipe.Job) { +func (a *archivePipe) compare(ctx context.Context, out chan<- pipe.Job) { defer func() { close(out) debug.Log("done") @@ -488,7 +489,7 @@ func (a *archivePipe) compare(done <-chan struct{}, out chan<- pipe.Job) { out <- archiveJob{new: newJob}.Copy() } - copyJobs(done, a.New, out) + copyJobs(ctx, a.New, out) return } @@ -585,7 +586,7 @@ func (j archiveJob) Copy() pipe.Job { const saveIndexTime = 30 * time.Second // saveIndexes regularly queries the master index for full indexes and saves them. -func (arch *Archiver) saveIndexes(wg *sync.WaitGroup, done <-chan struct{}) { +func (arch *Archiver) saveIndexes(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() ticker := time.NewTicker(saveIndexTime) @@ -593,11 +594,11 @@ func (arch *Archiver) saveIndexes(wg *sync.WaitGroup, done <-chan struct{}) { for { select { - case <-done: + case <-ctx.Done(): return case <-ticker.C: debug.Log("saving full indexes") - err := arch.repo.SaveFullIndex() + err := arch.repo.SaveFullIndex(ctx) if err != nil { debug.Log("save indexes returned an error: %v", err) fmt.Fprintf(os.Stderr, "error saving preliminary index: %v\n", err) @@ -634,7 +635,7 @@ func (p baseNameSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // Snapshot creates a snapshot of the given paths. If parentrestic.ID is set, this is // used to compare the files to the ones archived at the time this snapshot was // taken. -func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostname string, parentID *restic.ID) (*restic.Snapshot, restic.ID, error) { +func (arch *Archiver) Snapshot(ctx context.Context, p *restic.Progress, paths, tags []string, hostname string, parentID *restic.ID) (*restic.Snapshot, restic.ID, error) { paths = unique(paths) sort.Sort(baseNameSlice(paths)) @@ -643,7 +644,6 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam debug.RunHook("Archiver.Snapshot", nil) // signal the whole pipeline to stop - done := make(chan struct{}) var err error p.Start() @@ -663,14 +663,14 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam sn.Parent = parentID // load parent snapshot - parent, err := restic.LoadSnapshot(arch.repo, *parentID) + parent, err := restic.LoadSnapshot(ctx, arch.repo, *parentID) if err != nil { return nil, restic.ID{}, err } // start walker on old tree ch := make(chan walk.TreeJob) - go walk.Tree(arch.repo, *parent.Tree, done, ch) + go walk.Tree(ctx, arch.repo, *parent.Tree, ch) jobs.Old = ch } else { // use closed channel @@ -683,13 +683,13 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam pipeCh := make(chan pipe.Job) resCh := make(chan pipe.Result, 1) go func() { - pipe.Walk(paths, arch.SelectFilter, done, pipeCh, resCh) + pipe.Walk(ctx, paths, arch.SelectFilter, pipeCh, resCh) debug.Log("pipe.Walk done") }() jobs.New = pipeCh ch := make(chan pipe.Job) - go jobs.compare(done, ch) + go jobs.compare(ctx, ch) var wg sync.WaitGroup entCh := make(chan pipe.Entry) @@ -708,22 +708,22 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam // run workers for i := 0; i < maxConcurrency; i++ { wg.Add(2) - go arch.fileWorker(&wg, p, done, entCh) - go arch.dirWorker(&wg, p, done, dirCh) + go arch.fileWorker(ctx, &wg, p, entCh) + go arch.dirWorker(ctx, &wg, p, dirCh) } // run index saver var wgIndexSaver sync.WaitGroup - stopIndexSaver := make(chan struct{}) + indexCtx, indexCancel := context.WithCancel(ctx) wgIndexSaver.Add(1) - go arch.saveIndexes(&wgIndexSaver, stopIndexSaver) + go arch.saveIndexes(indexCtx, &wgIndexSaver) // wait for all workers to terminate debug.Log("wait for workers") wg.Wait() // stop index saver - close(stopIndexSaver) + indexCancel() wgIndexSaver.Wait() debug.Log("workers terminated") @@ -740,7 +740,7 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam sn.Tree = root.Subtree // load top-level tree again to see if it is empty - toptree, err := arch.repo.LoadTree(*root.Subtree) + toptree, err := arch.repo.LoadTree(ctx, *root.Subtree) if err != nil { return nil, restic.ID{}, err } @@ -750,7 +750,7 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam } // save index - err = arch.repo.SaveIndex() + err = arch.repo.SaveIndex(ctx) if err != nil { debug.Log("error saving index: %v", err) return nil, restic.ID{}, err @@ -759,7 +759,7 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam debug.Log("saved indexes") // save snapshot - id, err := arch.repo.SaveJSONUnpacked(restic.SnapshotFile, sn) + id, err := arch.repo.SaveJSONUnpacked(ctx, restic.SnapshotFile, sn) if err != nil { return nil, restic.ID{}, err } diff --git a/src/restic/archiver/archiver_duplication_test.go b/src/restic/archiver/archiver_duplication_test.go index c7de1cc49..9c01d5035 100644 --- a/src/restic/archiver/archiver_duplication_test.go +++ b/src/restic/archiver/archiver_duplication_test.go @@ -1,6 +1,7 @@ package archiver_test import ( + "context" "crypto/rand" "io" mrand "math/rand" @@ -39,33 +40,33 @@ func randomID() restic.ID { func forgetfulBackend() restic.Backend { be := &mock.Backend{} - be.TestFn = func(h restic.Handle) (bool, error) { + be.TestFn = func(ctx context.Context, h restic.Handle) (bool, error) { return false, nil } - be.LoadFn = func(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { + be.LoadFn = func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { return nil, errors.New("not found") } - be.SaveFn = func(h restic.Handle, rd io.Reader) error { + be.SaveFn = func(ctx context.Context, h restic.Handle, rd io.Reader) error { return nil } - be.StatFn = func(h restic.Handle) (restic.FileInfo, error) { + be.StatFn = func(ctx context.Context, h restic.Handle) (restic.FileInfo, error) { return restic.FileInfo{}, errors.New("not found") } - be.RemoveFn = func(h restic.Handle) error { + be.RemoveFn = func(ctx context.Context, h restic.Handle) error { return nil } - be.ListFn = func(t restic.FileType, done <-chan struct{}) <-chan string { + be.ListFn = func(ctx context.Context, t restic.FileType) <-chan string { ch := make(chan string) close(ch) return ch } - be.DeleteFn = func() error { + be.DeleteFn = func(ctx context.Context) error { return nil } @@ -80,7 +81,7 @@ func testArchiverDuplication(t *testing.T) { repo := repository.New(forgetfulBackend()) - err = repo.Init("foo") + err = repo.Init(context.TODO(), "foo") if err != nil { t.Fatal(err) } @@ -108,7 +109,7 @@ func testArchiverDuplication(t *testing.T) { buf := make([]byte, 50) - err := arch.Save(restic.DataBlob, buf, id) + err := arch.Save(context.TODO(), restic.DataBlob, buf, id) if err != nil { t.Fatal(err) } @@ -127,7 +128,7 @@ func testArchiverDuplication(t *testing.T) { case <-done: return case <-ticker.C: - err := repo.SaveFullIndex() + err := repo.SaveFullIndex(context.TODO()) if err != nil { t.Fatal(err) } diff --git a/src/restic/archiver/archiver_int_test.go b/src/restic/archiver/archiver_int_test.go index c4014f5b0..eb135436e 100644 --- a/src/restic/archiver/archiver_int_test.go +++ b/src/restic/archiver/archiver_int_test.go @@ -1,6 +1,7 @@ package archiver import ( + "context" "os" "testing" @@ -83,10 +84,10 @@ func (j testPipeJob) Error() error { return j.err } func (j testPipeJob) Info() os.FileInfo { return j.fi } func (j testPipeJob) Result() chan<- pipe.Result { return j.res } -func testTreeWalker(done <-chan struct{}, out chan<- walk.TreeJob) { +func testTreeWalker(ctx context.Context, out chan<- walk.TreeJob) { for _, e := range treeJobs { select { - case <-done: + case <-ctx.Done(): return case out <- walk.TreeJob{Path: e}: } @@ -95,10 +96,10 @@ func testTreeWalker(done <-chan struct{}, out chan<- walk.TreeJob) { close(out) } -func testPipeWalker(done <-chan struct{}, out chan<- pipe.Job) { +func testPipeWalker(ctx context.Context, out chan<- pipe.Job) { for _, e := range pipeJobs { select { - case <-done: + case <-ctx.Done(): return case out <- testPipeJob{path: e}: } @@ -108,19 +109,19 @@ func testPipeWalker(done <-chan struct{}, out chan<- pipe.Job) { } func TestArchivePipe(t *testing.T) { - done := make(chan struct{}) + ctx := context.TODO() treeCh := make(chan walk.TreeJob) pipeCh := make(chan pipe.Job) - go testTreeWalker(done, treeCh) - go testPipeWalker(done, pipeCh) + go testTreeWalker(ctx, treeCh) + go testPipeWalker(ctx, pipeCh) p := archivePipe{Old: treeCh, New: pipeCh} ch := make(chan pipe.Job) - go p.compare(done, ch) + go p.compare(ctx, ch) i := 0 for job := range ch { diff --git a/src/restic/archiver/archiver_test.go b/src/restic/archiver/archiver_test.go index 3536ff727..f4928089b 100644 --- a/src/restic/archiver/archiver_test.go +++ b/src/restic/archiver/archiver_test.go @@ -2,6 +2,7 @@ package archiver_test import ( "bytes" + "context" "io" "testing" "time" @@ -104,7 +105,7 @@ func archiveDirectory(b testing.TB) { arch := archiver.New(repo) - _, id, err := arch.Snapshot(nil, []string{BenchArchiveDirectory}, nil, "localhost", nil) + _, id, err := arch.Snapshot(context.TODO(), nil, []string{BenchArchiveDirectory}, nil, "localhost", nil) OK(b, err) b.Logf("snapshot archived as %v", id) @@ -129,7 +130,7 @@ func BenchmarkArchiveDirectory(b *testing.B) { } func countPacks(repo restic.Repository, t restic.FileType) (n uint) { - for range repo.Backend().List(t, nil) { + for range repo.Backend().List(context.TODO(), t) { n++ } @@ -234,7 +235,7 @@ func testParallelSaveWithDuplication(t *testing.T, seed int) { id := restic.Hash(c.Data) time.Sleep(time.Duration(id[0])) - err := arch.Save(restic.DataBlob, c.Data, id) + err := arch.Save(context.TODO(), restic.DataBlob, c.Data, id) <-barrier errChan <- err }(c, errChan) @@ -246,7 +247,7 @@ func testParallelSaveWithDuplication(t *testing.T, seed int) { } OK(t, repo.Flush()) - OK(t, repo.SaveIndex()) + OK(t, repo.SaveIndex(context.TODO())) chkr := createAndInitChecker(t, repo) assertNoUnreferencedPacks(t, chkr) @@ -271,7 +272,7 @@ func getRandomData(seed int, size int) []chunker.Chunk { func createAndInitChecker(t *testing.T, repo restic.Repository) *checker.Checker { chkr := checker.New(repo) - hints, errs := chkr.LoadIndex() + hints, errs := chkr.LoadIndex(context.TODO()) if len(errs) > 0 { t.Fatalf("expected no errors, got %v: %v", len(errs), errs) } @@ -284,11 +285,8 @@ func createAndInitChecker(t *testing.T, repo restic.Repository) *checker.Checker } func assertNoUnreferencedPacks(t *testing.T, chkr *checker.Checker) { - done := make(chan struct{}) - defer close(done) - errChan := make(chan error) - go chkr.Packs(errChan, done) + go chkr.Packs(context.TODO(), errChan) for err := range errChan { OK(t, err) @@ -301,7 +299,7 @@ func TestArchiveEmptySnapshot(t *testing.T) { arch := archiver.New(repo) - sn, id, err := arch.Snapshot(nil, []string{"file-does-not-exist-123123213123", "file2-does-not-exist-too-123123123"}, nil, "localhost", nil) + sn, id, err := arch.Snapshot(context.TODO(), nil, []string{"file-does-not-exist-123123213123", "file2-does-not-exist-too-123123123"}, nil, "localhost", nil) if err == nil { t.Errorf("expected error for empty snapshot, got nil") } diff --git a/src/restic/archiver/testing.go b/src/restic/archiver/testing.go index aad8ea1af..40af1ec57 100644 --- a/src/restic/archiver/testing.go +++ b/src/restic/archiver/testing.go @@ -1,6 +1,7 @@ package archiver import ( + "context" "restic" "testing" ) @@ -8,7 +9,7 @@ import ( // TestSnapshot creates a new snapshot of path. func TestSnapshot(t testing.TB, repo restic.Repository, path string, parent *restic.ID) *restic.Snapshot { arch := New(repo) - sn, _, err := arch.Snapshot(nil, []string{path}, []string{"test"}, "localhost", parent) + sn, _, err := arch.Snapshot(context.TODO(), nil, []string{path}, []string{"test"}, "localhost", parent) if err != nil { t.Fatal(err) } diff --git a/src/restic/backend.go b/src/restic/backend.go index 4f776b167..0020a76a9 100644 --- a/src/restic/backend.go +++ b/src/restic/backend.go @@ -1,6 +1,9 @@ package restic -import "io" +import ( + "context" + "io" +) // Backend is used to store and access data. type Backend interface { @@ -9,30 +12,30 @@ type Backend interface { Location() string // Test a boolean value whether a File with the name and type exists. - Test(h Handle) (bool, error) + Test(ctx context.Context, h Handle) (bool, error) // Remove removes a File with type t and name. - Remove(h Handle) error + Remove(ctx context.Context, h Handle) error // Close the backend Close() error // Save stores the data in the backend under the given handle. - Save(h Handle, rd io.Reader) error + Save(ctx context.Context, h Handle, rd io.Reader) error // Load returns a reader that yields the contents of the file at h at the // given offset. If length is larger than zero, only a portion of the file // is returned. rd must be closed after use. If an error is returned, the // ReadCloser must be nil. - Load(h Handle, length int, offset int64) (io.ReadCloser, error) + Load(ctx context.Context, h Handle, length int, offset int64) (io.ReadCloser, error) // Stat returns information about the File identified by h. - Stat(h Handle) (FileInfo, error) + Stat(ctx context.Context, h Handle) (FileInfo, error) // List returns a channel that yields all names of files of type t in an - // arbitrary order. A goroutine is started for this. If the channel done is - // closed, sending stops. - List(t FileType, done <-chan struct{}) <-chan string + // arbitrary order. A goroutine is started for this, which is stopped when + // ctx is cancelled. + List(ctx context.Context, t FileType) <-chan string } // FileInfo is returned by Stat() and contains information about a file in the diff --git a/src/restic/backend/b2/b2.go b/src/restic/backend/b2/b2.go index c209c13ab..9b80f2133 100644 --- a/src/restic/backend/b2/b2.go +++ b/src/restic/backend/b2/b2.go @@ -23,6 +23,9 @@ type b2Backend struct { sem *backend.Semaphore } +// ensure statically that *b2Backend implements restic.Backend. +var _ restic.Backend = &b2Backend{} + func newClient(ctx context.Context, cfg Config) (*b2.Client, error) { opts := []b2.ClientOption{b2.Transport(backend.Transport())} @@ -96,7 +99,7 @@ func Create(cfg Config) (restic.Backend, error) { sem: backend.NewSemaphore(cfg.Connections), } - present, err := be.Test(restic.Handle{Type: restic.ConfigFile}) + present, err := be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile}) if err != nil { return nil, err } @@ -140,7 +143,7 @@ func (wr *wrapReader) Close() error { // Load returns the data stored in the backend for h at the given offset // and saves it in p. Load has the same semantics as io.ReaderAt. -func (be *b2Backend) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { +func (be *b2Backend) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { debug.Log("Load %v, length %v, offset %v from %v", h, length, offset, be.Filename(h)) if err := h.Valid(); err != nil { return nil, err @@ -154,7 +157,7 @@ func (be *b2Backend) Load(h restic.Handle, length int, offset int64) (io.ReadClo return nil, errors.Errorf("invalid length %d", length) } - ctx, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancel(ctx) be.sem.GetToken() @@ -191,8 +194,8 @@ func (be *b2Backend) Load(h restic.Handle, length int, offset int64) (io.ReadClo } // Save stores data in the backend at the handle. -func (be *b2Backend) Save(h restic.Handle, rd io.Reader) (err error) { - ctx, cancel := context.WithCancel(context.TODO()) +func (be *b2Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { + ctx, cancel := context.WithCancel(ctx) defer cancel() if err := h.Valid(); err != nil { @@ -225,12 +228,9 @@ func (be *b2Backend) Save(h restic.Handle, rd io.Reader) (err error) { } // Stat returns information about a blob. -func (be *b2Backend) Stat(h restic.Handle) (bi restic.FileInfo, err error) { +func (be *b2Backend) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInfo, err error) { debug.Log("Stat %v", h) - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - be.sem.GetToken() defer be.sem.ReleaseToken() @@ -245,12 +245,9 @@ func (be *b2Backend) Stat(h restic.Handle) (bi restic.FileInfo, err error) { } // Test returns true if a blob of the given type and name exists in the backend. -func (be *b2Backend) Test(h restic.Handle) (bool, error) { +func (be *b2Backend) Test(ctx context.Context, h restic.Handle) (bool, error) { debug.Log("Test %v", h) - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - be.sem.GetToken() defer be.sem.ReleaseToken() @@ -265,12 +262,9 @@ func (be *b2Backend) Test(h restic.Handle) (bool, error) { } // Remove removes the blob with the given name and type. -func (be *b2Backend) Remove(h restic.Handle) error { +func (be *b2Backend) Remove(ctx context.Context, h restic.Handle) error { debug.Log("Remove %v", h) - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - be.sem.GetToken() defer be.sem.ReleaseToken() @@ -281,11 +275,11 @@ func (be *b2Backend) Remove(h restic.Handle) error { // List returns a channel that yields all names of blobs of type t. A // goroutine is started for this. If the channel done is closed, sending // stops. -func (be *b2Backend) List(t restic.FileType, done <-chan struct{}) <-chan string { +func (be *b2Backend) List(ctx context.Context, t restic.FileType) <-chan string { debug.Log("List %v", t) ch := make(chan string) - ctx, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancel(ctx) be.sem.GetToken() @@ -315,7 +309,7 @@ func (be *b2Backend) List(t restic.FileType, done <-chan struct{}) <-chan string select { case ch <- m: - case <-done: + case <-ctx.Done(): return } } @@ -330,13 +324,10 @@ func (be *b2Backend) List(t restic.FileType, done <-chan struct{}) <-chan string } // Remove keys for a specified backend type. -func (be *b2Backend) removeKeys(t restic.FileType) error { +func (be *b2Backend) removeKeys(ctx context.Context, t restic.FileType) error { debug.Log("removeKeys %v", t) - - done := make(chan struct{}) - defer close(done) - for key := range be.List(t, done) { - err := be.Remove(restic.Handle{Type: t, Name: key}) + for key := range be.List(ctx, t) { + err := be.Remove(ctx, restic.Handle{Type: t, Name: key}) if err != nil { return err } @@ -345,7 +336,7 @@ func (be *b2Backend) removeKeys(t restic.FileType) error { } // Delete removes all restic keys in the bucket. It will not remove the bucket itself. -func (be *b2Backend) Delete() error { +func (be *b2Backend) Delete(ctx context.Context) error { alltypes := []restic.FileType{ restic.DataFile, restic.KeyFile, @@ -354,12 +345,12 @@ func (be *b2Backend) Delete() error { restic.IndexFile} for _, t := range alltypes { - err := be.removeKeys(t) + err := be.removeKeys(ctx, t) if err != nil { return nil } } - err := be.Remove(restic.Handle{Type: restic.ConfigFile}) + err := be.Remove(ctx, restic.Handle{Type: restic.ConfigFile}) if err != nil && b2.IsNotExist(errors.Cause(err)) { err = nil } diff --git a/src/restic/backend/b2/b2_test.go b/src/restic/backend/b2/b2_test.go index 64c00c9ff..6cf5c1bc6 100644 --- a/src/restic/backend/b2/b2_test.go +++ b/src/restic/backend/b2/b2_test.go @@ -1,6 +1,7 @@ package b2_test import ( + "context" "fmt" "os" "testing" @@ -52,7 +53,7 @@ func newB2TestSuite(t testing.TB) *test.Suite { return err } - if err := be.(restic.Deleter).Delete(); err != nil { + if err := be.(restic.Deleter).Delete(context.TODO()); err != nil { return err } diff --git a/src/restic/backend/local/layout_test.go b/src/restic/backend/local/layout_test.go index 16b6b16e3..3f009b49d 100644 --- a/src/restic/backend/local/layout_test.go +++ b/src/restic/backend/local/layout_test.go @@ -1,6 +1,7 @@ package local import ( + "context" "path/filepath" "restic" . "restic/test" @@ -47,7 +48,7 @@ func TestLayout(t *testing.T) { } datafiles := make(map[string]bool) - for id := range be.List(restic.DataFile, nil) { + for id := range be.List(context.TODO(), restic.DataFile) { datafiles[id] = false } diff --git a/src/restic/backend/local/local.go b/src/restic/backend/local/local.go index 3b97f761e..1a3c0158c 100644 --- a/src/restic/backend/local/local.go +++ b/src/restic/backend/local/local.go @@ -1,6 +1,7 @@ package local import ( + "context" "io" "os" "path/filepath" @@ -75,7 +76,7 @@ func (b *Local) Location() string { } // Save stores data in the backend at the handle. -func (b *Local) Save(h restic.Handle, rd io.Reader) (err error) { +func (b *Local) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { debug.Log("Save %v", h) if err := h.Valid(); err != nil { return err @@ -100,7 +101,7 @@ func (b *Local) Save(h restic.Handle, rd io.Reader) (err error) { return errors.Wrap(err, "MkdirAll") } - return b.Save(h, rd) + return b.Save(ctx, h, rd) } if err != nil { @@ -110,12 +111,12 @@ func (b *Local) Save(h restic.Handle, rd io.Reader) (err error) { // save data, then sync _, err = io.Copy(f, rd) if err != nil { - f.Close() + _ = f.Close() return errors.Wrap(err, "Write") } if err = f.Sync(); err != nil { - f.Close() + _ = f.Close() return errors.Wrap(err, "Sync") } @@ -136,7 +137,7 @@ func (b *Local) Save(h restic.Handle, rd io.Reader) (err error) { // Load returns a reader that yields the contents of the file at h at the // given offset. If length is nonzero, only a portion of the file is // returned. rd must be closed after use. -func (b *Local) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { +func (b *Local) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { debug.Log("Load %v, length %v, offset %v", h, length, offset) if err := h.Valid(); err != nil { return nil, err @@ -154,7 +155,7 @@ func (b *Local) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, if offset > 0 { _, err = f.Seek(offset, 0) if err != nil { - f.Close() + _ = f.Close() return nil, err } } @@ -167,7 +168,7 @@ func (b *Local) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, } // Stat returns information about a blob. -func (b *Local) Stat(h restic.Handle) (restic.FileInfo, error) { +func (b *Local) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) { debug.Log("Stat %v", h) if err := h.Valid(); err != nil { return restic.FileInfo{}, err @@ -182,7 +183,7 @@ func (b *Local) Stat(h restic.Handle) (restic.FileInfo, error) { } // Test returns true if a blob of the given type and name exists in the backend. -func (b *Local) Test(h restic.Handle) (bool, error) { +func (b *Local) Test(ctx context.Context, h restic.Handle) (bool, error) { debug.Log("Test %v", h) _, err := fs.Stat(b.Filename(h)) if err != nil { @@ -196,7 +197,7 @@ func (b *Local) Test(h restic.Handle) (bool, error) { } // Remove removes the blob with the given name and type. -func (b *Local) Remove(h restic.Handle) error { +func (b *Local) Remove(ctx context.Context, h restic.Handle) error { debug.Log("Remove %v", h) fn := b.Filename(h) @@ -214,9 +215,8 @@ func isFile(fi os.FileInfo) bool { } // List returns a channel that yields all names of blobs of type t. A -// goroutine is started for this. If the channel done is closed, sending -// stops. -func (b *Local) List(t restic.FileType, done <-chan struct{}) <-chan string { +// goroutine is started for this. +func (b *Local) List(ctx context.Context, t restic.FileType) <-chan string { debug.Log("List %v", t) ch := make(chan string) @@ -235,7 +235,7 @@ func (b *Local) List(t restic.FileType, done <-chan struct{}) <-chan string { select { case ch <- filepath.Base(path): - case <-done: + case <-ctx.Done(): return err } diff --git a/src/restic/backend/mem/mem_backend.go b/src/restic/backend/mem/mem_backend.go index 3e96f6a36..bbb4dbd1a 100644 --- a/src/restic/backend/mem/mem_backend.go +++ b/src/restic/backend/mem/mem_backend.go @@ -2,6 +2,7 @@ package mem import ( "bytes" + "context" "io" "io/ioutil" "restic" @@ -37,7 +38,7 @@ func New() *MemoryBackend { } // Test returns whether a file exists. -func (be *MemoryBackend) Test(h restic.Handle) (bool, error) { +func (be *MemoryBackend) Test(ctx context.Context, h restic.Handle) (bool, error) { be.m.Lock() defer be.m.Unlock() @@ -51,7 +52,7 @@ func (be *MemoryBackend) Test(h restic.Handle) (bool, error) { } // Save adds new Data to the backend. -func (be *MemoryBackend) Save(h restic.Handle, rd io.Reader) error { +func (be *MemoryBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { if err := h.Valid(); err != nil { return err } @@ -81,7 +82,7 @@ func (be *MemoryBackend) Save(h restic.Handle, rd io.Reader) error { // Load returns a reader that yields the contents of the file at h at the // given offset. If length is nonzero, only a portion of the file is // returned. rd must be closed after use. -func (be *MemoryBackend) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { +func (be *MemoryBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { if err := h.Valid(); err != nil { return nil, err } @@ -117,7 +118,7 @@ func (be *MemoryBackend) Load(h restic.Handle, length int, offset int64) (io.Rea } // Stat returns information about a file in the backend. -func (be *MemoryBackend) Stat(h restic.Handle) (restic.FileInfo, error) { +func (be *MemoryBackend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) { be.m.Lock() defer be.m.Unlock() @@ -140,7 +141,7 @@ func (be *MemoryBackend) Stat(h restic.Handle) (restic.FileInfo, error) { } // Remove deletes a file from the backend. -func (be *MemoryBackend) Remove(h restic.Handle) error { +func (be *MemoryBackend) Remove(ctx context.Context, h restic.Handle) error { be.m.Lock() defer be.m.Unlock() @@ -156,7 +157,7 @@ func (be *MemoryBackend) Remove(h restic.Handle) error { } // List returns a channel which yields entries from the backend. -func (be *MemoryBackend) List(t restic.FileType, done <-chan struct{}) <-chan string { +func (be *MemoryBackend) List(ctx context.Context, t restic.FileType) <-chan string { be.m.Lock() defer be.m.Unlock() @@ -177,7 +178,7 @@ func (be *MemoryBackend) List(t restic.FileType, done <-chan struct{}) <-chan st for _, id := range ids { select { case ch <- id: - case <-done: + case <-ctx.Done(): return } } @@ -192,7 +193,7 @@ func (be *MemoryBackend) Location() string { } // Delete removes all data in the backend. -func (be *MemoryBackend) Delete() error { +func (be *MemoryBackend) Delete(ctx context.Context) error { be.m.Lock() defer be.m.Unlock() diff --git a/src/restic/backend/mem/mem_backend_test.go b/src/restic/backend/mem/mem_backend_test.go index 06da32661..422920da1 100644 --- a/src/restic/backend/mem/mem_backend_test.go +++ b/src/restic/backend/mem/mem_backend_test.go @@ -1,6 +1,7 @@ package mem_test import ( + "context" "restic" "testing" @@ -25,7 +26,7 @@ func newTestSuite() *test.Suite { Create: func(cfg interface{}) (restic.Backend, error) { c := cfg.(*memConfig) if c.be != nil { - ok, err := c.be.Test(restic.Handle{Type: restic.ConfigFile}) + ok, err := c.be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile}) if err != nil { return nil, err } diff --git a/src/restic/backend/rest/rest.go b/src/restic/backend/rest/rest.go index cb3fdadc8..f9ce1f681 100644 --- a/src/restic/backend/rest/rest.go +++ b/src/restic/backend/rest/rest.go @@ -1,6 +1,7 @@ package rest import ( + "context" "encoding/json" "fmt" "io" @@ -11,6 +12,8 @@ import ( "restic" "strings" + "golang.org/x/net/context/ctxhttp" + "restic/debug" "restic/errors" @@ -25,7 +28,7 @@ var _ restic.Backend = &restBackend{} type restBackend struct { url *url.URL connChan chan struct{} - client http.Client + client *http.Client backend.Layout } @@ -36,7 +39,7 @@ func Open(cfg Config) (restic.Backend, error) { connChan <- struct{}{} } - client := http.Client{Transport: backend.Transport()} + client := &http.Client{Transport: backend.Transport()} // use url without trailing slash for layout url := cfg.URL.String() @@ -61,7 +64,7 @@ func Create(cfg Config) (restic.Backend, error) { return nil, err } - _, err = be.Stat(restic.Handle{Type: restic.ConfigFile}) + _, err = be.Stat(context.TODO(), restic.Handle{Type: restic.ConfigFile}) if err == nil { return nil, errors.Fatal("config file already exists") } @@ -99,22 +102,25 @@ func (b *restBackend) Location() string { } // Save stores data in the backend at the handle. -func (b *restBackend) Save(h restic.Handle, rd io.Reader) (err error) { +func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { if err := h.Valid(); err != nil { return err } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + // make sure that client.Post() cannot close the reader by wrapping it in // backend.Closer, which has a noop method. rd = backend.Closer{Reader: rd} <-b.connChan - resp, err := b.client.Post(b.Filename(h), "binary/octet-stream", rd) + resp, err := ctxhttp.Post(ctx, b.client, b.Filename(h), "binary/octet-stream", rd) b.connChan <- struct{}{} if resp != nil { defer func() { - io.Copy(ioutil.Discard, resp.Body) + _, _ = io.Copy(ioutil.Discard, resp.Body) e := resp.Body.Close() if err == nil { @@ -137,7 +143,7 @@ func (b *restBackend) Save(h restic.Handle, rd io.Reader) (err error) { // Load returns a reader that yields the contents of the file at h at the // given offset. If length is nonzero, only a portion of the file is // returned. rd must be closed after use. -func (b *restBackend) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { +func (b *restBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { debug.Log("Load %v, length %v, offset %v", h, length, offset) if err := h.Valid(); err != nil { return nil, err @@ -164,20 +170,19 @@ func (b *restBackend) Load(h restic.Handle, length int, offset int64) (io.ReadCl debug.Log("Load(%v) send range %v", h, byteRange) <-b.connChan - resp, err := b.client.Do(req) + resp, err := ctxhttp.Do(ctx, b.client, req) b.connChan <- struct{}{} if err != nil { if resp != nil { - io.Copy(ioutil.Discard, resp.Body) - resp.Body.Close() + _, _ = io.Copy(ioutil.Discard, resp.Body) + _ = resp.Body.Close() } return nil, errors.Wrap(err, "client.Do") } if resp.StatusCode != 200 && resp.StatusCode != 206 { - io.Copy(ioutil.Discard, resp.Body) - resp.Body.Close() + _ = resp.Body.Close() return nil, errors.Errorf("unexpected HTTP response (%v): %v", resp.StatusCode, resp.Status) } @@ -185,19 +190,19 @@ func (b *restBackend) Load(h restic.Handle, length int, offset int64) (io.ReadCl } // Stat returns information about a blob. -func (b *restBackend) Stat(h restic.Handle) (restic.FileInfo, error) { +func (b *restBackend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) { if err := h.Valid(); err != nil { return restic.FileInfo{}, err } <-b.connChan - resp, err := b.client.Head(b.Filename(h)) + resp, err := ctxhttp.Head(ctx, b.client, b.Filename(h)) b.connChan <- struct{}{} if err != nil { return restic.FileInfo{}, errors.Wrap(err, "client.Head") } - io.Copy(ioutil.Discard, resp.Body) + _, _ = io.Copy(ioutil.Discard, resp.Body) if err = resp.Body.Close(); err != nil { return restic.FileInfo{}, errors.Wrap(err, "Close") } @@ -218,8 +223,8 @@ func (b *restBackend) Stat(h restic.Handle) (restic.FileInfo, error) { } // Test returns true if a blob of the given type and name exists in the backend. -func (b *restBackend) Test(h restic.Handle) (bool, error) { - _, err := b.Stat(h) +func (b *restBackend) Test(ctx context.Context, h restic.Handle) (bool, error) { + _, err := b.Stat(ctx, h) if err != nil { return false, nil } @@ -228,7 +233,7 @@ func (b *restBackend) Test(h restic.Handle) (bool, error) { } // Remove removes the blob with the given name and type. -func (b *restBackend) Remove(h restic.Handle) error { +func (b *restBackend) Remove(ctx context.Context, h restic.Handle) error { if err := h.Valid(); err != nil { return err } @@ -238,7 +243,7 @@ func (b *restBackend) Remove(h restic.Handle) error { return errors.Wrap(err, "http.NewRequest") } <-b.connChan - resp, err := b.client.Do(req) + resp, err := ctxhttp.Do(ctx, b.client, req) b.connChan <- struct{}{} if err != nil { @@ -249,14 +254,18 @@ func (b *restBackend) Remove(h restic.Handle) error { return errors.Errorf("blob not removed, server response: %v (%v)", resp.Status, resp.StatusCode) } - io.Copy(ioutil.Discard, resp.Body) - return resp.Body.Close() + _, err = io.Copy(ioutil.Discard, resp.Body) + if err != nil { + return errors.Wrap(err, "Copy") + } + + return errors.Wrap(resp.Body.Close(), "Close") } // List returns a channel that yields all names of blobs of type t. A // goroutine is started for this. If the channel done is closed, sending // stops. -func (b *restBackend) List(t restic.FileType, done <-chan struct{}) <-chan string { +func (b *restBackend) List(ctx context.Context, t restic.FileType) <-chan string { ch := make(chan string) url := b.Dirname(restic.Handle{Type: t}) @@ -265,12 +274,12 @@ func (b *restBackend) List(t restic.FileType, done <-chan struct{}) <-chan strin } <-b.connChan - resp, err := b.client.Get(url) + resp, err := ctxhttp.Get(ctx, b.client, url) b.connChan <- struct{}{} if resp != nil { defer func() { - io.Copy(ioutil.Discard, resp.Body) + _, _ = io.Copy(ioutil.Discard, resp.Body) e := resp.Body.Close() if err == nil { @@ -296,7 +305,7 @@ func (b *restBackend) List(t restic.FileType, done <-chan struct{}) <-chan strin for _, m := range list { select { case ch <- m: - case <-done: + case <-ctx.Done(): return } } diff --git a/src/restic/backend/s3/s3.go b/src/restic/backend/s3/s3.go index 960258392..b7c15f2fe 100644 --- a/src/restic/backend/s3/s3.go +++ b/src/restic/backend/s3/s3.go @@ -1,6 +1,7 @@ package s3 import ( + "context" "fmt" "io" "os" @@ -31,6 +32,9 @@ type s3 struct { backend.Layout } +// make sure that *s3 implements backend.Backend +var _ restic.Backend = &s3{} + const defaultLayout = "s3legacy" // Open opens the S3 backend at bucket and region. The bucket is created if it @@ -202,7 +206,7 @@ func (wr preventCloser) Close() error { } // Save stores data in the backend at the handle. -func (be *s3) Save(h restic.Handle, rd io.Reader) (err error) { +func (be *s3) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { if err := h.Valid(); err != nil { return err } @@ -259,7 +263,7 @@ func (wr wrapReader) Close() error { // Load returns a reader that yields the contents of the file at h at the // given offset. If length is nonzero, only a portion of the file is // returned. rd must be closed after use. -func (be *s3) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { +func (be *s3) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { debug.Log("Load %v, length %v, offset %v from %v", h, length, offset, be.Filename(h)) if err := h.Valid(); err != nil { return nil, err @@ -307,7 +311,7 @@ func (be *s3) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, er } // Stat returns information about a blob. -func (be *s3) Stat(h restic.Handle) (bi restic.FileInfo, err error) { +func (be *s3) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInfo, err error) { debug.Log("%v", h) objName := be.Filename(h) @@ -337,7 +341,7 @@ func (be *s3) Stat(h restic.Handle) (bi restic.FileInfo, err error) { } // Test returns true if a blob of the given type and name exists in the backend. -func (be *s3) Test(h restic.Handle) (bool, error) { +func (be *s3) Test(ctx context.Context, h restic.Handle) (bool, error) { found := false objName := be.Filename(h) _, err := be.client.StatObject(be.bucketname, objName) @@ -350,7 +354,7 @@ func (be *s3) Test(h restic.Handle) (bool, error) { } // Remove removes the blob with the given name and type. -func (be *s3) Remove(h restic.Handle) error { +func (be *s3) Remove(ctx context.Context, h restic.Handle) error { objName := be.Filename(h) err := be.client.RemoveObject(be.bucketname, objName) debug.Log("Remove(%v) at %v -> err %v", h, objName, err) @@ -360,7 +364,7 @@ func (be *s3) Remove(h restic.Handle) error { // List returns a channel that yields all names of blobs of type t. A // goroutine is started for this. If the channel done is closed, sending // stops. -func (be *s3) List(t restic.FileType, done <-chan struct{}) <-chan string { +func (be *s3) List(ctx context.Context, t restic.FileType) <-chan string { debug.Log("listing %v", t) ch := make(chan string) @@ -371,7 +375,7 @@ func (be *s3) List(t restic.FileType, done <-chan struct{}) <-chan string { prefix += "/" } - listresp := be.client.ListObjects(be.bucketname, prefix, true, done) + listresp := be.client.ListObjects(be.bucketname, prefix, true, ctx.Done()) go func() { defer close(ch) @@ -383,7 +387,7 @@ func (be *s3) List(t restic.FileType, done <-chan struct{}) <-chan string { select { case ch <- path.Base(m): - case <-done: + case <-ctx.Done(): return } } @@ -393,11 +397,9 @@ func (be *s3) List(t restic.FileType, done <-chan struct{}) <-chan string { } // Remove keys for a specified backend type. -func (be *s3) removeKeys(t restic.FileType) error { - done := make(chan struct{}) - defer close(done) - for key := range be.List(restic.DataFile, done) { - err := be.Remove(restic.Handle{Type: restic.DataFile, Name: key}) +func (be *s3) removeKeys(ctx context.Context, t restic.FileType) error { + for key := range be.List(ctx, restic.DataFile) { + err := be.Remove(ctx, restic.Handle{Type: restic.DataFile, Name: key}) if err != nil { return err } @@ -407,7 +409,7 @@ func (be *s3) removeKeys(t restic.FileType) error { } // Delete removes all restic keys in the bucket. It will not remove the bucket itself. -func (be *s3) Delete() error { +func (be *s3) Delete(ctx context.Context) error { alltypes := []restic.FileType{ restic.DataFile, restic.KeyFile, @@ -416,13 +418,13 @@ func (be *s3) Delete() error { restic.IndexFile} for _, t := range alltypes { - err := be.removeKeys(t) + err := be.removeKeys(ctx, t) if err != nil { return nil } } - return be.Remove(restic.Handle{Type: restic.ConfigFile}) + return be.Remove(ctx, restic.Handle{Type: restic.ConfigFile}) } // Close does nothing diff --git a/src/restic/backend/s3/s3_test.go b/src/restic/backend/s3/s3_test.go index 787166994..d3f870c0a 100644 --- a/src/restic/backend/s3/s3_test.go +++ b/src/restic/backend/s3/s3_test.go @@ -134,7 +134,7 @@ func newMinioTestSuite(ctx context.Context, t testing.TB) *test.Suite { return nil, err } - exists, err := be.Test(restic.Handle{Type: restic.ConfigFile}) + exists, err := be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile}) if err != nil { return nil, err } @@ -228,7 +228,7 @@ func newS3TestSuite(t testing.TB) *test.Suite { return nil, err } - exists, err := be.Test(restic.Handle{Type: restic.ConfigFile}) + exists, err := be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile}) if err != nil { return nil, err } @@ -255,7 +255,7 @@ func newS3TestSuite(t testing.TB) *test.Suite { return err } - if err := be.(restic.Deleter).Delete(); err != nil { + if err := be.(restic.Deleter).Delete(context.TODO()); err != nil { return err } diff --git a/src/restic/backend/sftp/layout_test.go b/src/restic/backend/sftp/layout_test.go index 166fa97e3..aa030ee05 100644 --- a/src/restic/backend/sftp/layout_test.go +++ b/src/restic/backend/sftp/layout_test.go @@ -1,6 +1,7 @@ package sftp_test import ( + "context" "fmt" "path/filepath" "restic" @@ -54,7 +55,7 @@ func TestLayout(t *testing.T) { } datafiles := make(map[string]bool) - for id := range be.List(restic.DataFile, nil) { + for id := range be.List(context.TODO(), restic.DataFile) { datafiles[id] = false } diff --git a/src/restic/backend/sftp/sftp.go b/src/restic/backend/sftp/sftp.go index 8070d01fe..f871da324 100644 --- a/src/restic/backend/sftp/sftp.go +++ b/src/restic/backend/sftp/sftp.go @@ -2,6 +2,7 @@ package sftp import ( "bufio" + "context" "fmt" "io" "os" @@ -262,7 +263,7 @@ func Join(parts ...string) string { } // Save stores data in the backend at the handle. -func (r *SFTP) Save(h restic.Handle, rd io.Reader) (err error) { +func (r *SFTP) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { debug.Log("Save %v", h) if err := r.clientError(); err != nil { return err @@ -283,7 +284,7 @@ func (r *SFTP) Save(h restic.Handle, rd io.Reader) (err error) { return errors.Wrap(err, "MkdirAll") } - return r.Save(h, rd) + return r.Save(ctx, h, rd) } if err != nil { @@ -315,7 +316,7 @@ func (r *SFTP) Save(h restic.Handle, rd io.Reader) (err error) { // Load returns a reader that yields the contents of the file at h at the // given offset. If length is nonzero, only a portion of the file is // returned. rd must be closed after use. -func (r *SFTP) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { +func (r *SFTP) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { debug.Log("Load %v, length %v, offset %v", h, length, offset) if err := h.Valid(); err != nil { return nil, err @@ -346,7 +347,7 @@ func (r *SFTP) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, e } // Stat returns information about a blob. -func (r *SFTP) Stat(h restic.Handle) (restic.FileInfo, error) { +func (r *SFTP) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) { debug.Log("Stat(%v)", h) if err := r.clientError(); err != nil { return restic.FileInfo{}, err @@ -365,7 +366,7 @@ func (r *SFTP) Stat(h restic.Handle) (restic.FileInfo, error) { } // Test returns true if a blob of the given type and name exists in the backend. -func (r *SFTP) Test(h restic.Handle) (bool, error) { +func (r *SFTP) Test(ctx context.Context, h restic.Handle) (bool, error) { debug.Log("Test(%v)", h) if err := r.clientError(); err != nil { return false, err @@ -384,7 +385,7 @@ func (r *SFTP) Test(h restic.Handle) (bool, error) { } // Remove removes the content stored at name. -func (r *SFTP) Remove(h restic.Handle) error { +func (r *SFTP) Remove(ctx context.Context, h restic.Handle) error { debug.Log("Remove(%v)", h) if err := r.clientError(); err != nil { return err @@ -396,7 +397,7 @@ func (r *SFTP) Remove(h restic.Handle) error { // List returns a channel that yields all names of blobs of type t. A // goroutine is started for this. If the channel done is closed, sending // stops. -func (r *SFTP) List(t restic.FileType, done <-chan struct{}) <-chan string { +func (r *SFTP) List(ctx context.Context, t restic.FileType) <-chan string { debug.Log("List %v", t) ch := make(chan string) @@ -416,7 +417,7 @@ func (r *SFTP) List(t restic.FileType, done <-chan struct{}) <-chan string { select { case ch <- path.Base(walker.Path()): - case <-done: + case <-ctx.Done(): return } } diff --git a/src/restic/backend/swift/swift.go b/src/restic/backend/swift/swift.go index 733dc3221..b18b61947 100644 --- a/src/restic/backend/swift/swift.go +++ b/src/restic/backend/swift/swift.go @@ -1,6 +1,7 @@ package swift import ( + "context" "fmt" "io" "net/http" @@ -27,6 +28,9 @@ type beSwift struct { backend.Layout } +// ensure statically that *beSwift implements restic.Backend. +var _ restic.Backend = &beSwift{} + // Open opens the swift backend at a container in region. The container is // created if it does not exist yet. func Open(cfg Config) (restic.Backend, error) { @@ -120,7 +124,7 @@ func (be *beSwift) Location() string { // Load returns a reader that yields the contents of the file at h at the // given offset. If length is nonzero, only a portion of the file is // returned. rd must be closed after use. -func (be *beSwift) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { +func (be *beSwift) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { debug.Log("Load %v, length %v, offset %v", h, length, offset) if err := h.Valid(); err != nil { return nil, err @@ -164,7 +168,7 @@ func (be *beSwift) Load(h restic.Handle, length int, offset int64) (io.ReadClose } // Save stores data in the backend at the handle. -func (be *beSwift) Save(h restic.Handle, rd io.Reader) (err error) { +func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { if err = h.Valid(); err != nil { return err } @@ -201,7 +205,7 @@ func (be *beSwift) Save(h restic.Handle, rd io.Reader) (err error) { } // Stat returns information about a blob. -func (be *beSwift) Stat(h restic.Handle) (bi restic.FileInfo, err error) { +func (be *beSwift) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInfo, err error) { debug.Log("%v", h) objName := be.Filename(h) @@ -216,7 +220,7 @@ func (be *beSwift) Stat(h restic.Handle) (bi restic.FileInfo, err error) { } // Test returns true if a blob of the given type and name exists in the backend. -func (be *beSwift) Test(h restic.Handle) (bool, error) { +func (be *beSwift) Test(ctx context.Context, h restic.Handle) (bool, error) { objName := be.Filename(h) switch _, _, err := be.conn.Object(be.container, objName); err { case nil: @@ -231,7 +235,7 @@ func (be *beSwift) Test(h restic.Handle) (bool, error) { } // Remove removes the blob with the given name and type. -func (be *beSwift) Remove(h restic.Handle) error { +func (be *beSwift) Remove(ctx context.Context, h restic.Handle) error { objName := be.Filename(h) err := be.conn.ObjectDelete(be.container, objName) debug.Log("Remove(%v) -> err %v", h, err) @@ -241,7 +245,7 @@ func (be *beSwift) Remove(h restic.Handle) error { // List returns a channel that yields all names of blobs of type t. A // goroutine is started for this. If the channel done is closed, sending // stops. -func (be *beSwift) List(t restic.FileType, done <-chan struct{}) <-chan string { +func (be *beSwift) List(ctx context.Context, t restic.FileType) <-chan string { debug.Log("listing %v", t) ch := make(chan string) @@ -264,7 +268,7 @@ func (be *beSwift) List(t restic.FileType, done <-chan struct{}) <-chan string { select { case ch <- m: - case <-done: + case <-ctx.Done(): return nil, io.EOF } } @@ -280,11 +284,9 @@ func (be *beSwift) List(t restic.FileType, done <-chan struct{}) <-chan string { } // Remove keys for a specified backend type. -func (be *beSwift) removeKeys(t restic.FileType) error { - done := make(chan struct{}) - defer close(done) - for key := range be.List(t, done) { - err := be.Remove(restic.Handle{Type: t, Name: key}) +func (be *beSwift) removeKeys(ctx context.Context, t restic.FileType) error { + for key := range be.List(ctx, t) { + err := be.Remove(ctx, restic.Handle{Type: t, Name: key}) if err != nil { return err } @@ -304,7 +306,7 @@ func (be *beSwift) IsNotExist(err error) bool { // Delete removes all restic objects in the container. // It will not remove the container itself. -func (be *beSwift) Delete() error { +func (be *beSwift) Delete(ctx context.Context) error { alltypes := []restic.FileType{ restic.DataFile, restic.KeyFile, @@ -313,13 +315,13 @@ func (be *beSwift) Delete() error { restic.IndexFile} for _, t := range alltypes { - err := be.removeKeys(t) + err := be.removeKeys(ctx, t) if err != nil { return nil } } - err := be.Remove(restic.Handle{Type: restic.ConfigFile}) + err := be.Remove(ctx, restic.Handle{Type: restic.ConfigFile}) if err != nil && !be.IsNotExist(err) { return err } diff --git a/src/restic/backend/swift/swift_test.go b/src/restic/backend/swift/swift_test.go index b53b7bb64..843efdcd4 100644 --- a/src/restic/backend/swift/swift_test.go +++ b/src/restic/backend/swift/swift_test.go @@ -1,6 +1,7 @@ package swift_test import ( + "context" "fmt" "os" "restic" @@ -44,7 +45,7 @@ func newSwiftTestSuite(t testing.TB) *test.Suite { return nil, err } - exists, err := be.Test(restic.Handle{Type: restic.ConfigFile}) + exists, err := be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile}) if err != nil { return nil, err } @@ -71,7 +72,7 @@ func newSwiftTestSuite(t testing.TB) *test.Suite { return err } - if err := be.(restic.Deleter).Delete(); err != nil { + if err := be.(restic.Deleter).Delete(context.TODO()); err != nil { return err } diff --git a/src/restic/backend/test/benchmarks.go b/src/restic/backend/test/benchmarks.go index 2b2b0666d..fb7106561 100644 --- a/src/restic/backend/test/benchmarks.go +++ b/src/restic/backend/test/benchmarks.go @@ -2,6 +2,7 @@ package test import ( "bytes" + "context" "io" "restic" "restic/test" @@ -12,14 +13,14 @@ func saveRandomFile(t testing.TB, be restic.Backend, length int) ([]byte, restic data := test.Random(23, length) id := restic.Hash(data) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - if err := be.Save(handle, bytes.NewReader(data)); err != nil { + if err := be.Save(context.TODO(), handle, bytes.NewReader(data)); err != nil { t.Fatalf("Save() error: %+v", err) } return data, handle } func remove(t testing.TB, be restic.Backend, h restic.Handle) { - if err := be.Remove(h); err != nil { + if err := be.Remove(context.TODO(), h); err != nil { t.Fatalf("Remove() returned error: %v", err) } } @@ -40,7 +41,7 @@ func (s *Suite) BenchmarkLoadFile(t *testing.B) { t.ResetTimer() for i := 0; i < t.N; i++ { - rd, err := be.Load(handle, 0, 0) + rd, err := be.Load(context.TODO(), handle, 0, 0) if err != nil { t.Fatal(err) } @@ -82,7 +83,7 @@ func (s *Suite) BenchmarkLoadPartialFile(t *testing.B) { t.ResetTimer() for i := 0; i < t.N; i++ { - rd, err := be.Load(handle, testLength, 0) + rd, err := be.Load(context.TODO(), handle, testLength, 0) if err != nil { t.Fatal(err) } @@ -126,7 +127,7 @@ func (s *Suite) BenchmarkLoadPartialFileOffset(t *testing.B) { t.ResetTimer() for i := 0; i < t.N; i++ { - rd, err := be.Load(handle, testLength, int64(testOffset)) + rd, err := be.Load(context.TODO(), handle, testLength, int64(testOffset)) if err != nil { t.Fatal(err) } @@ -171,11 +172,11 @@ func (s *Suite) BenchmarkSave(t *testing.B) { t.Fatal(err) } - if err := be.Save(handle, rd); err != nil { + if err := be.Save(context.TODO(), handle, rd); err != nil { t.Fatal(err) } - if err := be.Remove(handle); err != nil { + if err := be.Remove(context.TODO(), handle); err != nil { t.Fatal(err) } } diff --git a/src/restic/backend/test/tests.go b/src/restic/backend/test/tests.go index 61136cea7..b6da7182d 100644 --- a/src/restic/backend/test/tests.go +++ b/src/restic/backend/test/tests.go @@ -2,6 +2,7 @@ package test import ( "bytes" + "context" "fmt" "io" "io/ioutil" @@ -34,7 +35,7 @@ func (s *Suite) TestCreateWithConfig(t *testing.T) { // remove a config if present cfgHandle := restic.Handle{Type: restic.ConfigFile} - cfgPresent, err := b.Test(cfgHandle) + cfgPresent, err := b.Test(context.TODO(), cfgHandle) if err != nil { t.Fatalf("unable to test for config: %+v", err) } @@ -53,7 +54,7 @@ func (s *Suite) TestCreateWithConfig(t *testing.T) { } // remove config - err = b.Remove(restic.Handle{Type: restic.ConfigFile, Name: ""}) + err = b.Remove(context.TODO(), restic.Handle{Type: restic.ConfigFile, Name: ""}) if err != nil { t.Fatalf("unexpected error removing config: %+v", err) } @@ -78,12 +79,12 @@ func (s *Suite) TestConfig(t *testing.T) { var testString = "Config" // create config and read it back - _, err := backend.LoadAll(b, restic.Handle{Type: restic.ConfigFile}) + _, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.ConfigFile}) if err == nil { t.Fatalf("did not get expected error for non-existing config") } - err = b.Save(restic.Handle{Type: restic.ConfigFile}, strings.NewReader(testString)) + err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, strings.NewReader(testString)) if err != nil { t.Fatalf("Save() error: %+v", err) } @@ -92,7 +93,7 @@ func (s *Suite) TestConfig(t *testing.T) { // same config for _, name := range []string{"", "foo", "bar", "0000000000000000000000000000000000000000000000000000000000000000"} { h := restic.Handle{Type: restic.ConfigFile, Name: name} - buf, err := backend.LoadAll(b, h) + buf, err := backend.LoadAll(context.TODO(), b, h) if err != nil { t.Fatalf("unable to read config with name %q: %+v", name, err) } @@ -113,12 +114,12 @@ func (s *Suite) TestLoad(t *testing.T) { b := s.open(t) defer s.close(t, b) - rd, err := b.Load(restic.Handle{}, 0, 0) + rd, err := b.Load(context.TODO(), restic.Handle{}, 0, 0) if err == nil { t.Fatalf("Load() did not return an error for invalid handle") } if rd != nil { - rd.Close() + _ = rd.Close() } err = testLoad(b, restic.Handle{Type: restic.DataFile, Name: "foobar"}, 0, 0) @@ -132,14 +133,14 @@ func (s *Suite) TestLoad(t *testing.T) { id := restic.Hash(data) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - err = b.Save(handle, bytes.NewReader(data)) + err = b.Save(context.TODO(), handle, bytes.NewReader(data)) if err != nil { t.Fatalf("Save() error: %+v", err) } t.Logf("saved %d bytes as %v", length, handle) - rd, err = b.Load(handle, 100, -1) + rd, err = b.Load(context.TODO(), handle, 100, -1) if err == nil { t.Fatalf("Load() returned no error for negative offset!") } @@ -174,7 +175,7 @@ func (s *Suite) TestLoad(t *testing.T) { d = d[:l] } - rd, err := b.Load(handle, getlen, int64(o)) + rd, err := b.Load(context.TODO(), handle, getlen, int64(o)) if err != nil { t.Logf("Load, l %v, o %v, len(d) %v, getlen %v", l, o, len(d), getlen) t.Errorf("Load(%d, %d) returned unexpected error: %+v", l, o, err) @@ -235,7 +236,7 @@ func (s *Suite) TestLoad(t *testing.T) { } } - test.OK(t, b.Remove(handle)) + test.OK(t, b.Remove(context.TODO(), handle)) } type errorCloser struct { @@ -276,10 +277,10 @@ func (s *Suite) TestSave(t *testing.T) { Type: restic.DataFile, Name: fmt.Sprintf("%s-%d", id, i), } - err := b.Save(h, bytes.NewReader(data)) + err := b.Save(context.TODO(), h, bytes.NewReader(data)) test.OK(t, err) - buf, err := backend.LoadAll(b, h) + buf, err := backend.LoadAll(context.TODO(), b, h) test.OK(t, err) if len(buf) != len(data) { t.Fatalf("number of bytes does not match, want %v, got %v", len(data), len(buf)) @@ -289,14 +290,14 @@ func (s *Suite) TestSave(t *testing.T) { t.Fatalf("data not equal") } - fi, err := b.Stat(h) + fi, err := b.Stat(context.TODO(), h) test.OK(t, err) if fi.Size != int64(len(data)) { t.Fatalf("Stat() returned different size, want %q, got %d", len(data), fi.Size) } - err = b.Remove(h) + err = b.Remove(context.TODO(), h) if err != nil { t.Fatalf("error removing item: %+v", err) } @@ -324,12 +325,12 @@ func (s *Suite) TestSave(t *testing.T) { // wrap the tempfile in an errorCloser, so we can detect if the backend // closes the reader - err = b.Save(h, errorCloser{t: t, size: int64(length), Reader: tmpfile}) + err = b.Save(context.TODO(), h, errorCloser{t: t, size: int64(length), Reader: tmpfile}) if err != nil { t.Fatal(err) } - err = b.Remove(h) + err = b.Remove(context.TODO(), h) if err != nil { t.Fatalf("error removing item: %+v", err) } @@ -339,7 +340,7 @@ func (s *Suite) TestSave(t *testing.T) { t.Fatal(err) } - err = b.Save(h, tmpfile) + err = b.Save(context.TODO(), h, tmpfile) if err != nil { t.Fatal(err) } @@ -348,7 +349,7 @@ func (s *Suite) TestSave(t *testing.T) { t.Fatal(err) } - err = b.Remove(h) + err = b.Remove(context.TODO(), h) if err != nil { t.Fatalf("error removing item: %+v", err) } @@ -377,13 +378,13 @@ func (s *Suite) TestSaveFilenames(t *testing.T) { for i, test := range filenameTests { h := restic.Handle{Name: test.name, Type: restic.DataFile} - err := b.Save(h, strings.NewReader(test.data)) + err := b.Save(context.TODO(), h, strings.NewReader(test.data)) if err != nil { t.Errorf("test %d failed: Save() returned %+v", i, err) continue } - buf, err := backend.LoadAll(b, h) + buf, err := backend.LoadAll(context.TODO(), b, h) if err != nil { t.Errorf("test %d failed: Load() returned %+v", i, err) continue @@ -393,7 +394,7 @@ func (s *Suite) TestSaveFilenames(t *testing.T) { t.Errorf("test %d: returned wrong bytes", i) } - err = b.Remove(h) + err = b.Remove(context.TODO(), h) if err != nil { t.Errorf("test %d failed: Remove() returned %+v", i, err) continue @@ -414,14 +415,14 @@ var testStrings = []struct { func store(t testing.TB, b restic.Backend, tpe restic.FileType, data []byte) restic.Handle { id := restic.Hash(data) h := restic.Handle{Name: id.String(), Type: tpe} - err := b.Save(h, bytes.NewReader(data)) + err := b.Save(context.TODO(), h, bytes.NewReader(data)) test.OK(t, err) return h } // testLoad loads a blob (but discards its contents). func testLoad(b restic.Backend, h restic.Handle, length int, offset int64) error { - rd, err := b.Load(h, 0, 0) + rd, err := b.Load(context.TODO(), h, 0, 0) if err != nil { return err } @@ -437,14 +438,14 @@ func testLoad(b restic.Backend, h restic.Handle, length int, offset int64) error func delayedRemove(b restic.Backend, h restic.Handle) error { // Some backend (swift, I'm looking at you) may implement delayed // removal of data. Let's wait a bit if this happens. - err := b.Remove(h) + err := b.Remove(context.TODO(), h) if err != nil { return err } - found, err := b.Test(h) + found, err := b.Test(context.TODO(), h) for i := 0; found && i < 20; i++ { - found, err = b.Test(h) + found, err = b.Test(context.TODO(), h) if found { time.Sleep(100 * time.Millisecond) } @@ -468,12 +469,12 @@ func (s *Suite) TestBackend(t *testing.T) { // test if blob is already in repository h := restic.Handle{Type: tpe, Name: id.String()} - ret, err := b.Test(h) + ret, err := b.Test(context.TODO(), h) test.OK(t, err) test.Assert(t, !ret, "blob was found to exist before creating") // try to stat a not existing blob - _, err = b.Stat(h) + _, err = b.Stat(context.TODO(), h) test.Assert(t, err != nil, "blob data could be extracted before creation") // try to read not existing blob @@ -481,7 +482,7 @@ func (s *Suite) TestBackend(t *testing.T) { test.Assert(t, err != nil, "blob could be read before creation") // try to get string out, should fail - ret, err = b.Test(h) + ret, err = b.Test(context.TODO(), h) test.OK(t, err) test.Assert(t, !ret, "id %q was found (but should not have)", ts.id) } @@ -492,7 +493,7 @@ func (s *Suite) TestBackend(t *testing.T) { // test Load() h := restic.Handle{Type: tpe, Name: ts.id} - buf, err := backend.LoadAll(b, h) + buf, err := backend.LoadAll(context.TODO(), b, h) test.OK(t, err) test.Equals(t, ts.data, string(buf)) @@ -502,7 +503,7 @@ func (s *Suite) TestBackend(t *testing.T) { length := end - start buf2 := make([]byte, length) - rd, err := b.Load(h, len(buf2), int64(start)) + rd, err := b.Load(context.TODO(), h, len(buf2), int64(start)) test.OK(t, err) n, err := io.ReadFull(rd, buf2) test.OK(t, err) @@ -522,7 +523,7 @@ func (s *Suite) TestBackend(t *testing.T) { // create blob h := restic.Handle{Type: tpe, Name: ts.id} - err := b.Save(h, strings.NewReader(ts.data)) + err := b.Save(context.TODO(), h, strings.NewReader(ts.data)) test.Assert(t, err != nil, "expected error for %v, got %v", h, err) // remove and recreate @@ -530,12 +531,12 @@ func (s *Suite) TestBackend(t *testing.T) { test.OK(t, err) // test that the blob is gone - ok, err := b.Test(h) + ok, err := b.Test(context.TODO(), h) test.OK(t, err) test.Assert(t, !ok, "removed blob still present") // create blob - err = b.Save(h, strings.NewReader(ts.data)) + err = b.Save(context.TODO(), h, strings.NewReader(ts.data)) test.OK(t, err) // list items @@ -549,7 +550,7 @@ func (s *Suite) TestBackend(t *testing.T) { list := restic.IDs{} - for s := range b.List(tpe, nil) { + for s := range b.List(context.TODO(), tpe) { list = append(list, restic.TestParseID(s)) } @@ -572,13 +573,13 @@ func (s *Suite) TestBackend(t *testing.T) { h := restic.Handle{Type: tpe, Name: id.String()} - found, err := b.Test(h) + found, err := b.Test(context.TODO(), h) test.OK(t, err) test.Assert(t, found, fmt.Sprintf("id %q not found", id)) test.OK(t, delayedRemove(b, h)) - found, err = b.Test(h) + found, err = b.Test(context.TODO(), h) test.OK(t, err) test.Assert(t, !found, fmt.Sprintf("id %q not found after removal", id)) } @@ -600,7 +601,7 @@ func (s *Suite) TestDelete(t *testing.T) { return } - err := be.Delete() + err := be.Delete(context.TODO()) if err != nil { t.Fatalf("error deleting backend: %+v", err) } diff --git a/src/restic/backend/test/tests_test.go b/src/restic/backend/test/tests_test.go index 010fcfe6e..d662e5a32 100644 --- a/src/restic/backend/test/tests_test.go +++ b/src/restic/backend/test/tests_test.go @@ -1,6 +1,7 @@ package test_test import ( + "context" "restic" "restic/errors" "testing" @@ -26,7 +27,7 @@ func newTestSuite(t testing.TB) *test.Suite { Create: func(cfg interface{}) (restic.Backend, error) { c := cfg.(*memConfig) if c.be != nil { - ok, err := c.be.Test(restic.Handle{Type: restic.ConfigFile}) + ok, err := c.be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile}) if err != nil { return nil, err } diff --git a/src/restic/backend/utils.go b/src/restic/backend/utils.go index 3f3a85749..a07c7e86e 100644 --- a/src/restic/backend/utils.go +++ b/src/restic/backend/utils.go @@ -1,14 +1,15 @@ package backend import ( + "context" "io" "io/ioutil" "restic" ) // LoadAll reads all data stored in the backend for the handle. -func LoadAll(be restic.Backend, h restic.Handle) (buf []byte, err error) { - rd, err := be.Load(h, 0, 0) +func LoadAll(ctx context.Context, be restic.Backend, h restic.Handle) (buf []byte, err error) { + rd, err := be.Load(ctx, h, 0, 0) if err != nil { return nil, err } diff --git a/src/restic/backend/utils_test.go b/src/restic/backend/utils_test.go index 51481ed0b..15829f46e 100644 --- a/src/restic/backend/utils_test.go +++ b/src/restic/backend/utils_test.go @@ -2,6 +2,7 @@ package backend_test import ( "bytes" + "context" "math/rand" "restic" "testing" @@ -21,10 +22,10 @@ func TestLoadAll(t *testing.T) { data := Random(23+i, rand.Intn(MiB)+500*KiB) id := restic.Hash(data) - err := b.Save(restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) + err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) OK(t, err) - buf, err := backend.LoadAll(b, restic.Handle{Type: restic.DataFile, Name: id.String()}) + buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) OK(t, err) if len(buf) != len(data) { @@ -46,10 +47,10 @@ func TestLoadSmallBuffer(t *testing.T) { data := Random(23+i, rand.Intn(MiB)+500*KiB) id := restic.Hash(data) - err := b.Save(restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) + err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) OK(t, err) - buf, err := backend.LoadAll(b, restic.Handle{Type: restic.DataFile, Name: id.String()}) + buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) OK(t, err) if len(buf) != len(data) { @@ -71,10 +72,10 @@ func TestLoadLargeBuffer(t *testing.T) { data := Random(23+i, rand.Intn(MiB)+500*KiB) id := restic.Hash(data) - err := b.Save(restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) + err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) OK(t, err) - buf, err := backend.LoadAll(b, restic.Handle{Type: restic.DataFile, Name: id.String()}) + buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) OK(t, err) if len(buf) != len(data) { diff --git a/src/restic/backend_find.go b/src/restic/backend_find.go index 193fd165b..e445972c9 100644 --- a/src/restic/backend_find.go +++ b/src/restic/backend_find.go @@ -1,6 +1,9 @@ package restic -import "restic/errors" +import ( + "context" + "restic/errors" +) // ErrNoIDPrefixFound is returned by Find() when no ID for the given prefix // could be found. @@ -14,13 +17,10 @@ var ErrMultipleIDMatches = errors.New("multiple IDs with prefix found") // start with prefix. If none is found, nil and ErrNoIDPrefixFound is returned. // If more than one is found, nil and ErrMultipleIDMatches is returned. func Find(be Lister, t FileType, prefix string) (string, error) { - done := make(chan struct{}) - defer close(done) - match := "" // TODO: optimize by sorting list etc. - for name := range be.List(t, done) { + for name := range be.List(context.TODO(), t) { if prefix == name[:len(prefix)] { if match == "" { match = name @@ -42,12 +42,9 @@ const minPrefixLength = 8 // PrefixLength returns the number of bytes required so that all prefixes of // all names of type t are unique. func PrefixLength(be Lister, t FileType) (int, error) { - done := make(chan struct{}) - defer close(done) - // load all IDs of the given type list := make([]string, 0, 100) - for name := range be.List(t, done) { + for name := range be.List(context.TODO(), t) { list = append(list, name) } diff --git a/src/restic/backend_find_test.go b/src/restic/backend_find_test.go index cc86cd810..032c8a9d9 100644 --- a/src/restic/backend_find_test.go +++ b/src/restic/backend_find_test.go @@ -1,15 +1,16 @@ package restic import ( + "context" "testing" ) type mockBackend struct { - list func(FileType, <-chan struct{}) <-chan string + list func(context.Context, FileType) <-chan string } -func (m mockBackend) List(t FileType, done <-chan struct{}) <-chan string { - return m.list(t, done) +func (m mockBackend) List(ctx context.Context, t FileType) <-chan string { + return m.list(ctx, t) } var samples = IDs{ @@ -27,14 +28,14 @@ func TestPrefixLength(t *testing.T) { list := samples m := mockBackend{} - m.list = func(t FileType, done <-chan struct{}) <-chan string { + m.list = func(ctx context.Context, t FileType) <-chan string { ch := make(chan string) go func() { defer close(ch) for _, id := range list { select { case ch <- id.String(): - case <-done: + case <-ctx.Done(): return } } diff --git a/src/restic/checker/checker.go b/src/restic/checker/checker.go index 80dee3005..99d9f18c1 100644 --- a/src/restic/checker/checker.go +++ b/src/restic/checker/checker.go @@ -1,6 +1,7 @@ package checker import ( + "context" "crypto/sha256" "fmt" "io" @@ -76,7 +77,7 @@ func (err ErrOldIndexFormat) Error() string { } // LoadIndex loads all index files. -func (c *Checker) LoadIndex() (hints []error, errs []error) { +func (c *Checker) LoadIndex(ctx context.Context) (hints []error, errs []error) { debug.Log("Start") type indexRes struct { Index *repository.Index @@ -86,21 +87,21 @@ func (c *Checker) LoadIndex() (hints []error, errs []error) { indexCh := make(chan indexRes) - worker := func(id restic.ID, done <-chan struct{}) error { + worker := func(ctx context.Context, id restic.ID) error { debug.Log("worker got index %v", id) - idx, err := repository.LoadIndexWithDecoder(c.repo, id, repository.DecodeIndex) + idx, err := repository.LoadIndexWithDecoder(ctx, c.repo, id, repository.DecodeIndex) if errors.Cause(err) == repository.ErrOldIndexFormat { debug.Log("index %v has old format", id.Str()) hints = append(hints, ErrOldIndexFormat{id}) - idx, err = repository.LoadIndexWithDecoder(c.repo, id, repository.DecodeOldIndex) + idx, err = repository.LoadIndexWithDecoder(ctx, c.repo, id, repository.DecodeOldIndex) } err = errors.Wrapf(err, "error loading index %v", id.Str()) select { case indexCh <- indexRes{Index: idx, ID: id.String(), err: err}: - case <-done: + case <-ctx.Done(): } return nil @@ -109,7 +110,7 @@ func (c *Checker) LoadIndex() (hints []error, errs []error) { go func() { defer close(indexCh) debug.Log("start loading indexes in parallel") - err := repository.FilesInParallel(c.repo.Backend(), restic.IndexFile, defaultParallelism, + err := repository.FilesInParallel(ctx, c.repo.Backend(), restic.IndexFile, defaultParallelism, repository.ParallelWorkFuncParseID(worker)) debug.Log("loading indexes finished, error: %v", err) if err != nil { @@ -183,7 +184,7 @@ func (e PackError) Error() string { return "pack " + e.ID.String() + ": " + e.Err.Error() } -func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan<- error, wg *sync.WaitGroup, done <-chan struct{}) { +func packIDTester(ctx context.Context, repo restic.Repository, inChan <-chan restic.ID, errChan chan<- error, wg *sync.WaitGroup) { debug.Log("worker start") defer debug.Log("worker done") @@ -191,7 +192,7 @@ func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan< for id := range inChan { h := restic.Handle{Type: restic.DataFile, Name: id.String()} - ok, err := repo.Backend().Test(h) + ok, err := repo.Backend().Test(ctx, h) if err != nil { err = PackError{ID: id, Err: err} } else { @@ -203,7 +204,7 @@ func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan< if err != nil { debug.Log("error checking for pack %s: %v", id.Str(), err) select { - case <-done: + case <-ctx.Done(): return case errChan <- err: } @@ -218,7 +219,7 @@ func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan< // Packs checks that all packs referenced in the index are still available and // there are no packs that aren't in an index. errChan is closed after all // packs have been checked. -func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) { +func (c *Checker) Packs(ctx context.Context, errChan chan<- error) { defer close(errChan) debug.Log("checking for %d packs", len(c.packs)) @@ -229,7 +230,7 @@ func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) { IDChan := make(chan restic.ID) for i := 0; i < defaultParallelism; i++ { workerWG.Add(1) - go packIDTester(c.repo, IDChan, errChan, &workerWG, done) + go packIDTester(ctx, c.repo, IDChan, errChan, &workerWG) } for id := range c.packs { @@ -242,12 +243,12 @@ func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) { workerWG.Wait() debug.Log("workers terminated") - for id := range c.repo.List(restic.DataFile, done) { + for id := range c.repo.List(ctx, restic.DataFile) { debug.Log("check data blob %v", id.Str()) if !seenPacks.Has(id) { c.orphanedPacks = append(c.orphanedPacks, id) select { - case <-done: + case <-ctx.Done(): return case errChan <- PackError{ID: id, Orphaned: true, Err: errors.New("not referenced in any index")}: } @@ -277,8 +278,8 @@ func (e Error) Error() string { return e.Err.Error() } -func loadTreeFromSnapshot(repo restic.Repository, id restic.ID) (restic.ID, error) { - sn, err := restic.LoadSnapshot(repo, id) +func loadTreeFromSnapshot(ctx context.Context, repo restic.Repository, id restic.ID) (restic.ID, error) { + sn, err := restic.LoadSnapshot(ctx, repo, id) if err != nil { debug.Log("error loading snapshot %v: %v", id.Str(), err) return restic.ID{}, err @@ -293,7 +294,7 @@ func loadTreeFromSnapshot(repo restic.Repository, id restic.ID) (restic.ID, erro } // loadSnapshotTreeIDs loads all snapshots from backend and returns the tree IDs. -func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) { +func loadSnapshotTreeIDs(ctx context.Context, repo restic.Repository) (restic.IDs, []error) { var trees struct { IDs restic.IDs sync.Mutex @@ -304,7 +305,7 @@ func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) { sync.Mutex } - snapshotWorker := func(strID string, done <-chan struct{}) error { + snapshotWorker := func(ctx context.Context, strID string) error { id, err := restic.ParseID(strID) if err != nil { return err @@ -312,7 +313,7 @@ func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) { debug.Log("load snapshot %v", id.Str()) - treeID, err := loadTreeFromSnapshot(repo, id) + treeID, err := loadTreeFromSnapshot(ctx, repo, id) if err != nil { errs.Lock() errs.errs = append(errs.errs, err) @@ -328,7 +329,7 @@ func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) { return nil } - err := repository.FilesInParallel(repo.Backend(), restic.SnapshotFile, defaultParallelism, snapshotWorker) + err := repository.FilesInParallel(ctx, repo.Backend(), restic.SnapshotFile, defaultParallelism, snapshotWorker) if err != nil { errs.errs = append(errs.errs, err) } @@ -353,9 +354,9 @@ type treeJob struct { } // loadTreeWorker loads trees from repo and sends them to out. -func loadTreeWorker(repo restic.Repository, +func loadTreeWorker(ctx context.Context, repo restic.Repository, in <-chan restic.ID, out chan<- treeJob, - done <-chan struct{}, wg *sync.WaitGroup) { + wg *sync.WaitGroup) { defer func() { debug.Log("exiting") @@ -371,7 +372,7 @@ func loadTreeWorker(repo restic.Repository, outCh = nil for { select { - case <-done: + case <-ctx.Done(): return case treeID, ok := <-inCh: @@ -380,7 +381,7 @@ func loadTreeWorker(repo restic.Repository, } debug.Log("load tree %v", treeID.Str()) - tree, err := repo.LoadTree(treeID) + tree, err := repo.LoadTree(ctx, treeID) debug.Log("load tree %v (%v) returned err: %v", tree, treeID.Str(), err) job = treeJob{ID: treeID, error: err, Tree: tree} outCh = out @@ -395,7 +396,7 @@ func loadTreeWorker(repo restic.Repository, } // checkTreeWorker checks the trees received and sends out errors to errChan. -func (c *Checker) checkTreeWorker(in <-chan treeJob, out chan<- error, done <-chan struct{}, wg *sync.WaitGroup) { +func (c *Checker) checkTreeWorker(ctx context.Context, in <-chan treeJob, out chan<- error, wg *sync.WaitGroup) { defer func() { debug.Log("exiting") wg.Done() @@ -410,7 +411,7 @@ func (c *Checker) checkTreeWorker(in <-chan treeJob, out chan<- error, done <-ch outCh = nil for { select { - case <-done: + case <-ctx.Done(): debug.Log("done channel closed, exiting") return @@ -458,7 +459,7 @@ func (c *Checker) checkTreeWorker(in <-chan treeJob, out chan<- error, done <-ch } } -func filterTrees(backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan treeJob, out chan<- treeJob, done <-chan struct{}) { +func filterTrees(ctx context.Context, backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan treeJob, out chan<- treeJob) { defer func() { debug.Log("closing output channels") close(loaderChan) @@ -489,7 +490,7 @@ func filterTrees(backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan tree } select { - case <-done: + case <-ctx.Done(): return case loadCh <- nextTreeID: @@ -549,15 +550,15 @@ func filterTrees(backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan tree // Structure checks that for all snapshots all referenced data blobs and // subtrees are available in the index. errChan is closed after all trees have // been traversed. -func (c *Checker) Structure(errChan chan<- error, done <-chan struct{}) { +func (c *Checker) Structure(ctx context.Context, errChan chan<- error) { defer close(errChan) - trees, errs := loadSnapshotTreeIDs(c.repo) + trees, errs := loadSnapshotTreeIDs(ctx, c.repo) debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs)) for _, err := range errs { select { - case <-done: + case <-ctx.Done(): return case errChan <- err: } @@ -570,11 +571,11 @@ func (c *Checker) Structure(errChan chan<- error, done <-chan struct{}) { var wg sync.WaitGroup for i := 0; i < defaultParallelism; i++ { wg.Add(2) - go loadTreeWorker(c.repo, treeIDChan, treeJobChan1, done, &wg) - go c.checkTreeWorker(treeJobChan2, errChan, done, &wg) + go loadTreeWorker(ctx, c.repo, treeIDChan, treeJobChan1, &wg) + go c.checkTreeWorker(ctx, treeJobChan2, errChan, &wg) } - filterTrees(trees, treeIDChan, treeJobChan1, treeJobChan2, done) + filterTrees(ctx, trees, treeIDChan, treeJobChan1, treeJobChan2) wg.Wait() } @@ -659,11 +660,11 @@ func (c *Checker) CountPacks() uint64 { } // checkPack reads a pack and checks the integrity of all blobs. -func checkPack(r restic.Repository, id restic.ID) error { +func checkPack(ctx context.Context, r restic.Repository, id restic.ID) error { debug.Log("checking pack %v", id.Str()) h := restic.Handle{Type: restic.DataFile, Name: id.String()} - rd, err := r.Backend().Load(h, 0, 0) + rd, err := r.Backend().Load(ctx, h, 0, 0) if err != nil { return err } @@ -748,7 +749,7 @@ func checkPack(r restic.Repository, id restic.ID) error { } // ReadData loads all data from the repository and checks the integrity. -func (c *Checker) ReadData(p *restic.Progress, errChan chan<- error, done <-chan struct{}) { +func (c *Checker) ReadData(ctx context.Context, p *restic.Progress, errChan chan<- error) { defer close(errChan) p.Start() @@ -761,7 +762,7 @@ func (c *Checker) ReadData(p *restic.Progress, errChan chan<- error, done <-chan var ok bool select { - case <-done: + case <-ctx.Done(): return case id, ok = <-in: if !ok { @@ -769,21 +770,21 @@ func (c *Checker) ReadData(p *restic.Progress, errChan chan<- error, done <-chan } } - err := checkPack(c.repo, id) + err := checkPack(ctx, c.repo, id) p.Report(restic.Stat{Blobs: 1}) if err == nil { continue } select { - case <-done: + case <-ctx.Done(): return case errChan <- err: } } } - ch := c.repo.List(restic.DataFile, done) + ch := c.repo.List(ctx, restic.DataFile) var wg sync.WaitGroup for i := 0; i < defaultParallelism; i++ { diff --git a/src/restic/checker/checker_test.go b/src/restic/checker/checker_test.go index 65e764137..ee345a97b 100644 --- a/src/restic/checker/checker_test.go +++ b/src/restic/checker/checker_test.go @@ -1,6 +1,7 @@ package checker_test import ( + "context" "io" "math/rand" "path/filepath" @@ -16,13 +17,13 @@ import ( var checkerTestData = filepath.Join("testdata", "checker-test-repo.tar.gz") -func collectErrors(f func(chan<- error, <-chan struct{})) (errs []error) { - done := make(chan struct{}) - defer close(done) +func collectErrors(ctx context.Context, f func(context.Context, chan<- error)) (errs []error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() errChan := make(chan error) - go f(errChan, done) + go f(ctx, errChan) for err := range errChan { errs = append(errs, err) @@ -32,17 +33,18 @@ func collectErrors(f func(chan<- error, <-chan struct{})) (errs []error) { } func checkPacks(chkr *checker.Checker) []error { - return collectErrors(chkr.Packs) + return collectErrors(context.TODO(), chkr.Packs) } func checkStruct(chkr *checker.Checker) []error { - return collectErrors(chkr.Structure) + return collectErrors(context.TODO(), chkr.Structure) } func checkData(chkr *checker.Checker) []error { return collectErrors( - func(errCh chan<- error, done <-chan struct{}) { - chkr.ReadData(nil, errCh, done) + context.TODO(), + func(ctx context.Context, errCh chan<- error) { + chkr.ReadData(ctx, nil, errCh) }, ) } @@ -54,7 +56,7 @@ func TestCheckRepo(t *testing.T) { repo := repository.TestOpenLocal(t, repodir) chkr := checker.New(repo) - hints, errs := chkr.LoadIndex() + hints, errs := chkr.LoadIndex(context.TODO()) if len(errs) > 0 { t.Fatalf("expected no errors, got %v: %v", len(errs), errs) } @@ -77,10 +79,10 @@ func TestMissingPack(t *testing.T) { Type: restic.DataFile, Name: "657f7fb64f6a854fff6fe9279998ee09034901eded4e6db9bcee0e59745bbce6", } - test.OK(t, repo.Backend().Remove(packHandle)) + test.OK(t, repo.Backend().Remove(context.TODO(), packHandle)) chkr := checker.New(repo) - hints, errs := chkr.LoadIndex() + hints, errs := chkr.LoadIndex(context.TODO()) if len(errs) > 0 { t.Fatalf("expected no errors, got %v: %v", len(errs), errs) } @@ -113,10 +115,10 @@ func TestUnreferencedPack(t *testing.T) { Type: restic.IndexFile, Name: "3f1abfcb79c6f7d0a3be517d2c83c8562fba64ef2c8e9a3544b4edaf8b5e3b44", } - test.OK(t, repo.Backend().Remove(indexHandle)) + test.OK(t, repo.Backend().Remove(context.TODO(), indexHandle)) chkr := checker.New(repo) - hints, errs := chkr.LoadIndex() + hints, errs := chkr.LoadIndex(context.TODO()) if len(errs) > 0 { t.Fatalf("expected no errors, got %v: %v", len(errs), errs) } @@ -147,7 +149,7 @@ func TestUnreferencedBlobs(t *testing.T) { Type: restic.SnapshotFile, Name: "51d249d28815200d59e4be7b3f21a157b864dc343353df9d8e498220c2499b02", } - test.OK(t, repo.Backend().Remove(snapshotHandle)) + test.OK(t, repo.Backend().Remove(context.TODO(), snapshotHandle)) unusedBlobsBySnapshot := restic.IDs{ restic.TestParseID("58c748bbe2929fdf30c73262bd8313fe828f8925b05d1d4a87fe109082acb849"), @@ -161,7 +163,7 @@ func TestUnreferencedBlobs(t *testing.T) { sort.Sort(unusedBlobsBySnapshot) chkr := checker.New(repo) - hints, errs := chkr.LoadIndex() + hints, errs := chkr.LoadIndex(context.TODO()) if len(errs) > 0 { t.Fatalf("expected no errors, got %v: %v", len(errs), errs) } @@ -192,7 +194,7 @@ func TestModifiedIndex(t *testing.T) { Type: restic.IndexFile, Name: "90f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", } - f, err := repo.Backend().Load(h, 0, 0) + f, err := repo.Backend().Load(context.TODO(), h, 0, 0) test.OK(t, err) // save the index again with a modified name so that the hash doesn't match @@ -201,13 +203,13 @@ func TestModifiedIndex(t *testing.T) { Type: restic.IndexFile, Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", } - err = repo.Backend().Save(h2, f) + err = repo.Backend().Save(context.TODO(), h2, f) test.OK(t, err) test.OK(t, f.Close()) chkr := checker.New(repo) - hints, errs := chkr.LoadIndex() + hints, errs := chkr.LoadIndex(context.TODO()) if len(errs) == 0 { t.Fatalf("expected errors not found") } @@ -230,7 +232,7 @@ func TestDuplicatePacksInIndex(t *testing.T) { repo := repository.TestOpenLocal(t, repodir) chkr := checker.New(repo) - hints, errs := chkr.LoadIndex() + hints, errs := chkr.LoadIndex(context.TODO()) if len(hints) == 0 { t.Fatalf("did not get expected checker hints for duplicate packs in indexes") } @@ -259,8 +261,8 @@ type errorBackend struct { ProduceErrors bool } -func (b errorBackend) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { - rd, err := b.Backend.Load(h, length, offset) +func (b errorBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { + rd, err := b.Backend.Load(ctx, h, length, offset) if err != nil { return rd, err } @@ -303,17 +305,17 @@ func TestCheckerModifiedData(t *testing.T) { defer cleanup() arch := archiver.New(repo) - _, id, err := arch.Snapshot(nil, []string{"."}, nil, "localhost", nil) + _, id, err := arch.Snapshot(context.TODO(), nil, []string{"."}, nil, "localhost", nil) test.OK(t, err) t.Logf("archived as %v", id.Str()) beError := &errorBackend{Backend: repo.Backend()} checkRepo := repository.New(beError) - test.OK(t, checkRepo.SearchKey(test.TestPassword, 5)) + test.OK(t, checkRepo.SearchKey(context.TODO(), test.TestPassword, 5)) chkr := checker.New(checkRepo) - hints, errs := chkr.LoadIndex() + hints, errs := chkr.LoadIndex(context.TODO()) if len(errs) > 0 { t.Fatalf("expected no errors, got %v: %v", len(errs), errs) } @@ -349,7 +351,7 @@ func BenchmarkChecker(t *testing.B) { repo := repository.TestOpenLocal(t, repodir) chkr := checker.New(repo) - hints, errs := chkr.LoadIndex() + hints, errs := chkr.LoadIndex(context.TODO()) if len(errs) > 0 { t.Fatalf("expected no errors, got %v: %v", len(errs), errs) } diff --git a/src/restic/checker/testing.go b/src/restic/checker/testing.go index 7b642dea1..26a213f39 100644 --- a/src/restic/checker/testing.go +++ b/src/restic/checker/testing.go @@ -1,6 +1,7 @@ package checker import ( + "context" "restic" "testing" ) @@ -9,7 +10,7 @@ import ( func TestCheckRepo(t testing.TB, repo restic.Repository) { chkr := New(repo) - hints, errs := chkr.LoadIndex() + hints, errs := chkr.LoadIndex(context.TODO()) if len(errs) != 0 { t.Fatalf("errors loading index: %v", errs) } @@ -18,12 +19,9 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) { t.Fatalf("errors loading index: %v", hints) } - done := make(chan struct{}) - defer close(done) - // packs errChan := make(chan error) - go chkr.Packs(errChan, done) + go chkr.Packs(context.TODO(), errChan) for err := range errChan { t.Error(err) @@ -31,7 +29,7 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) { // structure errChan = make(chan error) - go chkr.Structure(errChan, done) + go chkr.Structure(context.TODO(), errChan) for err := range errChan { t.Error(err) @@ -45,7 +43,7 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) { // read data errChan = make(chan error) - go chkr.ReadData(nil, errChan, done) + go chkr.ReadData(context.TODO(), nil, errChan) for err := range errChan { t.Error(err) diff --git a/src/restic/config.go b/src/restic/config.go index 582fd2c7c..dbd6fd9cf 100644 --- a/src/restic/config.go +++ b/src/restic/config.go @@ -1,6 +1,7 @@ package restic import ( + "context" "testing" "restic/errors" @@ -23,7 +24,7 @@ const RepoVersion = 1 // JSONUnpackedLoader loads unpacked JSON. type JSONUnpackedLoader interface { - LoadJSONUnpacked(FileType, ID, interface{}) error + LoadJSONUnpacked(context.Context, FileType, ID, interface{}) error } // CreateConfig creates a config file with a randomly selected polynomial and @@ -57,12 +58,12 @@ func TestCreateConfig(t testing.TB, pol chunker.Pol) (cfg Config) { } // LoadConfig returns loads, checks and returns the config for a repository. -func LoadConfig(r JSONUnpackedLoader) (Config, error) { +func LoadConfig(ctx context.Context, r JSONUnpackedLoader) (Config, error) { var ( cfg Config ) - err := r.LoadJSONUnpacked(ConfigFile, ID{}, &cfg) + err := r.LoadJSONUnpacked(ctx, ConfigFile, ID{}, &cfg) if err != nil { return Config{}, err } diff --git a/src/restic/config_test.go b/src/restic/config_test.go index c5d2166e3..c287ae196 100644 --- a/src/restic/config_test.go +++ b/src/restic/config_test.go @@ -1,6 +1,7 @@ package restic_test import ( + "context" "restic" "testing" @@ -13,10 +14,10 @@ func (s saver) SaveJSONUnpacked(t restic.FileType, arg interface{}) (restic.ID, return s(t, arg) } -type loader func(restic.FileType, restic.ID, interface{}) error +type loader func(context.Context, restic.FileType, restic.ID, interface{}) error -func (l loader) LoadJSONUnpacked(t restic.FileType, id restic.ID, arg interface{}) error { - return l(t, id, arg) +func (l loader) LoadJSONUnpacked(ctx context.Context, t restic.FileType, id restic.ID, arg interface{}) error { + return l(ctx, t, id, arg) } func TestConfig(t *testing.T) { @@ -36,7 +37,7 @@ func TestConfig(t *testing.T) { _, err = saver(save).SaveJSONUnpacked(restic.ConfigFile, cfg1) - load := func(tpe restic.FileType, id restic.ID, arg interface{}) error { + load := func(ctx context.Context, tpe restic.FileType, id restic.ID, arg interface{}) error { Assert(t, tpe == restic.ConfigFile, "wrong backend type: got %v, wanted %v", tpe, restic.ConfigFile) @@ -46,7 +47,7 @@ func TestConfig(t *testing.T) { return nil } - cfg2, err := restic.LoadConfig(loader(load)) + cfg2, err := restic.LoadConfig(context.TODO(), loader(load)) OK(t, err) Assert(t, cfg1 == cfg2, diff --git a/src/restic/find.go b/src/restic/find.go index dcc9d0251..4b118abb0 100644 --- a/src/restic/find.go +++ b/src/restic/find.go @@ -1,12 +1,14 @@ package restic +import "context" + // FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data // blobs) to the set blobs. The tree blobs in the `seen` BlobSet will not be visited // again. -func FindUsedBlobs(repo Repository, treeID ID, blobs BlobSet, seen BlobSet) error { +func FindUsedBlobs(ctx context.Context, repo Repository, treeID ID, blobs BlobSet, seen BlobSet) error { blobs.Insert(BlobHandle{ID: treeID, Type: TreeBlob}) - tree, err := repo.LoadTree(treeID) + tree, err := repo.LoadTree(ctx, treeID) if err != nil { return err } @@ -26,7 +28,7 @@ func FindUsedBlobs(repo Repository, treeID ID, blobs BlobSet, seen BlobSet) erro seen.Insert(h) - err := FindUsedBlobs(repo, subtreeID, blobs, seen) + err := FindUsedBlobs(ctx, repo, subtreeID, blobs, seen) if err != nil { return err } diff --git a/src/restic/find_test.go b/src/restic/find_test.go index f4d7266ee..272472ffa 100644 --- a/src/restic/find_test.go +++ b/src/restic/find_test.go @@ -2,6 +2,7 @@ package restic_test import ( "bufio" + "context" "encoding/json" "flag" "fmt" @@ -92,7 +93,7 @@ func TestFindUsedBlobs(t *testing.T) { for i, sn := range snapshots { usedBlobs := restic.NewBlobSet() - err := restic.FindUsedBlobs(repo, *sn.Tree, usedBlobs, restic.NewBlobSet()) + err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, usedBlobs, restic.NewBlobSet()) if err != nil { t.Errorf("FindUsedBlobs returned error: %v", err) continue @@ -128,7 +129,7 @@ func BenchmarkFindUsedBlobs(b *testing.B) { for i := 0; i < b.N; i++ { seen := restic.NewBlobSet() blobs := restic.NewBlobSet() - err := restic.FindUsedBlobs(repo, *sn.Tree, blobs, seen) + err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, blobs, seen) if err != nil { b.Error(err) } diff --git a/src/restic/fuse/dir.go b/src/restic/fuse/dir.go index 7968a2ee5..49210bb90 100644 --- a/src/restic/fuse/dir.go +++ b/src/restic/fuse/dir.go @@ -26,9 +26,9 @@ type dir struct { ownerIsRoot bool } -func newDir(repo restic.Repository, node *restic.Node, ownerIsRoot bool) (*dir, error) { +func newDir(ctx context.Context, repo restic.Repository, node *restic.Node, ownerIsRoot bool) (*dir, error) { debug.Log("new dir for %v (%v)", node.Name, node.Subtree.Str()) - tree, err := repo.LoadTree(*node.Subtree) + tree, err := repo.LoadTree(ctx, *node.Subtree) if err != nil { debug.Log(" error loading tree %v: %v", node.Subtree.Str(), err) return nil, err @@ -49,7 +49,7 @@ func newDir(repo restic.Repository, node *restic.Node, ownerIsRoot bool) (*dir, // replaceSpecialNodes replaces nodes with name "." and "/" by their contents. // Otherwise, the node is returned. -func replaceSpecialNodes(repo restic.Repository, node *restic.Node) ([]*restic.Node, error) { +func replaceSpecialNodes(ctx context.Context, repo restic.Repository, node *restic.Node) ([]*restic.Node, error) { if node.Type != "dir" || node.Subtree == nil { return []*restic.Node{node}, nil } @@ -58,7 +58,7 @@ func replaceSpecialNodes(repo restic.Repository, node *restic.Node) ([]*restic.N return []*restic.Node{node}, nil } - tree, err := repo.LoadTree(*node.Subtree) + tree, err := repo.LoadTree(ctx, *node.Subtree) if err != nil { return nil, err } @@ -66,16 +66,16 @@ func replaceSpecialNodes(repo restic.Repository, node *restic.Node) ([]*restic.N return tree.Nodes, nil } -func newDirFromSnapshot(repo restic.Repository, snapshot SnapshotWithId, ownerIsRoot bool) (*dir, error) { +func newDirFromSnapshot(ctx context.Context, repo restic.Repository, snapshot SnapshotWithId, ownerIsRoot bool) (*dir, error) { debug.Log("new dir for snapshot %v (%v)", snapshot.ID.Str(), snapshot.Tree.Str()) - tree, err := repo.LoadTree(*snapshot.Tree) + tree, err := repo.LoadTree(ctx, *snapshot.Tree) if err != nil { debug.Log(" loadTree(%v) failed: %v", snapshot.ID.Str(), err) return nil, err } items := make(map[string]*restic.Node) for _, n := range tree.Nodes { - nodes, err := replaceSpecialNodes(repo, n) + nodes, err := replaceSpecialNodes(ctx, repo, n) if err != nil { debug.Log(" replaceSpecialNodes(%v) failed: %v", n, err) return nil, err @@ -167,7 +167,7 @@ func (d *dir) Lookup(ctx context.Context, name string) (fs.Node, error) { } switch node.Type { case "dir": - return newDir(d.repo, node, d.ownerIsRoot) + return newDir(ctx, d.repo, node, d.ownerIsRoot) case "file": return newFile(d.repo, node, d.ownerIsRoot) case "symlink": diff --git a/src/restic/fuse/file.go b/src/restic/fuse/file.go index 1e943212a..ae3ba5a7d 100644 --- a/src/restic/fuse/file.go +++ b/src/restic/fuse/file.go @@ -9,6 +9,8 @@ import ( "restic" "restic/debug" + scontext "context" + "bazil.org/fuse" "bazil.org/fuse/fs" "golang.org/x/net/context" @@ -25,7 +27,7 @@ var _ = fs.HandleReleaser(&file{}) // for fuse operations. type BlobLoader interface { LookupBlobSize(restic.ID, restic.BlobType) (uint, error) - LoadBlob(restic.BlobType, restic.ID, []byte) (int, error) + LoadBlob(scontext.Context, restic.BlobType, restic.ID, []byte) (int, error) } type file struct { @@ -88,7 +90,7 @@ func (f *file) Attr(ctx context.Context, a *fuse.Attr) error { } -func (f *file) getBlobAt(i int) (blob []byte, err error) { +func (f *file) getBlobAt(ctx context.Context, i int) (blob []byte, err error) { debug.Log("getBlobAt(%v, %v)", f.node.Name, i) if f.blobs[i] != nil { return f.blobs[i], nil @@ -100,7 +102,7 @@ func (f *file) getBlobAt(i int) (blob []byte, err error) { } buf := restic.NewBlobBuffer(f.sizes[i]) - n, err := f.repo.LoadBlob(restic.DataBlob, f.node.Content[i], buf) + n, err := f.repo.LoadBlob(ctx, restic.DataBlob, f.node.Content[i], buf) if err != nil { debug.Log("LoadBlob(%v, %v) failed: %v", f.node.Name, f.node.Content[i], err) return nil, err @@ -137,7 +139,7 @@ func (f *file) Read(ctx context.Context, req *fuse.ReadRequest, resp *fuse.ReadR readBytes := 0 remainingBytes := req.Size for i := startContent; remainingBytes > 0 && i < len(f.sizes); i++ { - blob, err := f.getBlobAt(i) + blob, err := f.getBlobAt(ctx, i) if err != nil { return err } diff --git a/src/restic/fuse/file_test.go b/src/restic/fuse/file_test.go index 9b2e0982d..dcb959fec 100644 --- a/src/restic/fuse/file_test.go +++ b/src/restic/fuse/file_test.go @@ -34,9 +34,7 @@ func testRead(t testing.TB, f *file, offset, length int, data []byte) { } func firstSnapshotID(t testing.TB, repo restic.Repository) (first restic.ID) { - done := make(chan struct{}) - defer close(done) - for id := range repo.List(restic.SnapshotFile, done) { + for id := range repo.List(context.TODO(), restic.SnapshotFile) { if first.IsNull() { first = id } @@ -46,13 +44,13 @@ func firstSnapshotID(t testing.TB, repo restic.Repository) (first restic.ID) { func loadFirstSnapshot(t testing.TB, repo restic.Repository) *restic.Snapshot { id := firstSnapshotID(t, repo) - sn, err := restic.LoadSnapshot(repo, id) + sn, err := restic.LoadSnapshot(context.TODO(), repo, id) OK(t, err) return sn } func loadTree(t testing.TB, repo restic.Repository, id restic.ID) *restic.Tree { - tree, err := repo.LoadTree(id) + tree, err := repo.LoadTree(context.TODO(), id) OK(t, err) return tree } @@ -87,7 +85,7 @@ func TestFuseFile(t *testing.T) { filesize += uint64(size) buf := restic.NewBlobBuffer(int(size)) - n, err := repo.LoadBlob(restic.DataBlob, id, buf) + n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf) OK(t, err) if uint(n) != size { diff --git a/src/restic/fuse/snapshot.go b/src/restic/fuse/snapshot.go index 2a654397b..4057301f8 100644 --- a/src/restic/fuse/snapshot.go +++ b/src/restic/fuse/snapshot.go @@ -73,14 +73,14 @@ func (sn *SnapshotsDir) updateCache(ctx context.Context) error { sn.Lock() defer sn.Unlock() - for id := range sn.repo.List(restic.SnapshotFile, ctx.Done()) { + for id := range sn.repo.List(ctx, restic.SnapshotFile) { if sn.processed.Has(id) { debug.Log("skipping snapshot %v, already in list", id.Str()) continue } debug.Log("found snapshot id %v", id.Str()) - snapshot, err := restic.LoadSnapshot(sn.repo, id) + snapshot, err := restic.LoadSnapshot(ctx, sn.repo, id) if err != nil { return err } @@ -158,5 +158,5 @@ func (sn *SnapshotsDir) Lookup(ctx context.Context, name string) (fs.Node, error } } - return newDirFromSnapshot(sn.repo, snapshot, sn.ownerIsRoot) + return newDirFromSnapshot(ctx, sn.repo, snapshot, sn.ownerIsRoot) } diff --git a/src/restic/index/index.go b/src/restic/index/index.go index f1c41b79f..ab1ebafa4 100644 --- a/src/restic/index/index.go +++ b/src/restic/index/index.go @@ -2,6 +2,7 @@ package index import ( + "context" "fmt" "os" "restic" @@ -33,15 +34,12 @@ func newIndex() *Index { } // New creates a new index for repo from scratch. -func New(repo restic.Repository, p *restic.Progress) (*Index, error) { - done := make(chan struct{}) - defer close(done) - +func New(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Index, error) { p.Start() defer p.Done() ch := make(chan worker.Job) - go list.AllPacks(repo, ch, done) + go list.AllPacks(ctx, repo, ch) idx := newIndex() @@ -84,11 +82,11 @@ type indexJSON struct { Packs []*packJSON `json:"packs"` } -func loadIndexJSON(repo restic.Repository, id restic.ID) (*indexJSON, error) { +func loadIndexJSON(ctx context.Context, repo restic.Repository, id restic.ID) (*indexJSON, error) { debug.Log("process index %v\n", id.Str()) var idx indexJSON - err := repo.LoadJSONUnpacked(restic.IndexFile, id, &idx) + err := repo.LoadJSONUnpacked(ctx, restic.IndexFile, id, &idx) if err != nil { return nil, err } @@ -97,25 +95,22 @@ func loadIndexJSON(repo restic.Repository, id restic.ID) (*indexJSON, error) { } // Load creates an index by loading all index files from the repo. -func Load(repo restic.Repository, p *restic.Progress) (*Index, error) { +func Load(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Index, error) { debug.Log("loading indexes") p.Start() defer p.Done() - done := make(chan struct{}) - defer close(done) - supersedes := make(map[restic.ID]restic.IDSet) results := make(map[restic.ID]map[restic.ID]Pack) index := newIndex() - for id := range repo.List(restic.IndexFile, done) { + for id := range repo.List(ctx, restic.IndexFile) { p.Report(restic.Stat{Blobs: 1}) debug.Log("Load index %v", id.Str()) - idx, err := loadIndexJSON(repo, id) + idx, err := loadIndexJSON(ctx, repo, id) if err != nil { return nil, err } @@ -250,17 +245,17 @@ func (idx *Index) FindBlob(h restic.BlobHandle) (result []Location, err error) { } // Save writes the complete index to the repo. -func (idx *Index) Save(repo restic.Repository, supersedes restic.IDs) (restic.ID, error) { +func (idx *Index) Save(ctx context.Context, repo restic.Repository, supersedes restic.IDs) (restic.ID, error) { packs := make(map[restic.ID][]restic.Blob, len(idx.Packs)) for id, p := range idx.Packs { packs[id] = p.Entries } - return Save(repo, packs, supersedes) + return Save(ctx, repo, packs, supersedes) } // Save writes a new index containing the given packs. -func Save(repo restic.Repository, packs map[restic.ID][]restic.Blob, supersedes restic.IDs) (restic.ID, error) { +func Save(ctx context.Context, repo restic.Repository, packs map[restic.ID][]restic.Blob, supersedes restic.IDs) (restic.ID, error) { idx := &indexJSON{ Supersedes: supersedes, Packs: make([]*packJSON, 0, len(packs)), @@ -285,5 +280,5 @@ func Save(repo restic.Repository, packs map[restic.ID][]restic.Blob, supersedes idx.Packs = append(idx.Packs, p) } - return repo.SaveJSONUnpacked(restic.IndexFile, idx) + return repo.SaveJSONUnpacked(ctx, restic.IndexFile, idx) } diff --git a/src/restic/index/index_test.go b/src/restic/index/index_test.go index 1984c2cb6..11d0cc08a 100644 --- a/src/restic/index/index_test.go +++ b/src/restic/index/index_test.go @@ -1,6 +1,7 @@ package index import ( + "context" "math/rand" "restic" "restic/checker" @@ -26,7 +27,7 @@ func createFilledRepo(t testing.TB, snapshots int, dup float32) (restic.Reposito } func validateIndex(t testing.TB, repo restic.Repository, idx *Index) { - for id := range repo.List(restic.DataFile, nil) { + for id := range repo.List(context.TODO(), restic.DataFile) { p, ok := idx.Packs[id] if !ok { t.Errorf("pack %v missing from index", id.Str()) @@ -42,7 +43,7 @@ func TestIndexNew(t *testing.T) { repo, cleanup := createFilledRepo(t, 3, 0) defer cleanup() - idx, err := New(repo, nil) + idx, err := New(context.TODO(), repo, nil) if err != nil { t.Fatalf("New() returned error %v", err) } @@ -58,7 +59,7 @@ func TestIndexLoad(t *testing.T) { repo, cleanup := createFilledRepo(t, 3, 0) defer cleanup() - loadIdx, err := Load(repo, nil) + loadIdx, err := Load(context.TODO(), repo, nil) if err != nil { t.Fatalf("Load() returned error %v", err) } @@ -69,7 +70,7 @@ func TestIndexLoad(t *testing.T) { validateIndex(t, repo, loadIdx) - newIdx, err := New(repo, nil) + newIdx, err := New(context.TODO(), repo, nil) if err != nil { t.Fatalf("New() returned error %v", err) } @@ -133,7 +134,7 @@ func BenchmarkIndexNew(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - idx, err := New(repo, nil) + idx, err := New(context.TODO(), repo, nil) if err != nil { b.Fatalf("New() returned error %v", err) @@ -150,7 +151,7 @@ func BenchmarkIndexSave(b *testing.B) { repo, cleanup := repository.TestRepository(b) defer cleanup() - idx, err := New(repo, nil) + idx, err := New(context.TODO(), repo, nil) test.OK(b, err) for i := 0; i < 8000; i++ { @@ -170,7 +171,7 @@ func BenchmarkIndexSave(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - id, err := idx.Save(repo, nil) + id, err := idx.Save(context.TODO(), repo, nil) if err != nil { b.Fatalf("New() returned error %v", err) } @@ -183,7 +184,7 @@ func TestIndexDuplicateBlobs(t *testing.T) { repo, cleanup := createFilledRepo(t, 3, 0.01) defer cleanup() - idx, err := New(repo, nil) + idx, err := New(context.TODO(), repo, nil) if err != nil { t.Fatal(err) } @@ -202,7 +203,7 @@ func TestIndexDuplicateBlobs(t *testing.T) { } func loadIndex(t testing.TB, repo restic.Repository) *Index { - idx, err := Load(repo, nil) + idx, err := Load(context.TODO(), repo, nil) if err != nil { t.Fatalf("Load() returned error %v", err) } @@ -225,7 +226,7 @@ func TestSave(t *testing.T) { t.Logf("save %d/%d packs in a new index\n", len(packs), len(idx.Packs)) - id, err := Save(repo, packs, idx.IndexIDs.List()) + id, err := Save(context.TODO(), repo, packs, idx.IndexIDs.List()) if err != nil { t.Fatalf("unable to save new index: %v", err) } @@ -235,7 +236,7 @@ func TestSave(t *testing.T) { for id := range idx.IndexIDs { t.Logf("remove index %v", id.Str()) h := restic.Handle{Type: restic.IndexFile, Name: id.String()} - err = repo.Backend().Remove(h) + err = repo.Backend().Remove(context.TODO(), h) if err != nil { t.Errorf("error removing index %v: %v", id, err) } @@ -267,7 +268,7 @@ func TestIndexSave(t *testing.T) { idx := loadIndex(t, repo) - id, err := idx.Save(repo, idx.IndexIDs.List()) + id, err := idx.Save(context.TODO(), repo, idx.IndexIDs.List()) if err != nil { t.Fatalf("unable to save new index: %v", err) } @@ -277,7 +278,7 @@ func TestIndexSave(t *testing.T) { for id := range idx.IndexIDs { t.Logf("remove index %v", id.Str()) h := restic.Handle{Type: restic.IndexFile, Name: id.String()} - err = repo.Backend().Remove(h) + err = repo.Backend().Remove(context.TODO(), h) if err != nil { t.Errorf("error removing index %v: %v", id, err) } @@ -287,7 +288,7 @@ func TestIndexSave(t *testing.T) { t.Logf("load new index with %d packs", len(idx2.Packs)) checker := checker.New(repo) - hints, errs := checker.LoadIndex() + hints, errs := checker.LoadIndex(context.TODO()) for _, h := range hints { t.Logf("hint: %v\n", h) } @@ -301,15 +302,12 @@ func TestIndexAddRemovePack(t *testing.T) { repo, cleanup := createFilledRepo(t, 3, 0) defer cleanup() - idx, err := Load(repo, nil) + idx, err := Load(context.TODO(), repo, nil) if err != nil { t.Fatalf("Load() returned error %v", err) } - done := make(chan struct{}) - defer close(done) - - packID := <-repo.List(restic.DataFile, done) + packID := <-repo.List(context.TODO(), restic.DataFile) t.Logf("selected pack %v", packID.Str()) @@ -367,7 +365,7 @@ func TestIndexLoadDocReference(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() - id, err := repo.SaveUnpacked(restic.IndexFile, docExample) + id, err := repo.SaveUnpacked(context.TODO(), restic.IndexFile, docExample) if err != nil { t.Fatalf("SaveUnpacked() returned error %v", err) } diff --git a/src/restic/list/list.go b/src/restic/list/list.go index 70b3c7d9c..292ba8475 100644 --- a/src/restic/list/list.go +++ b/src/restic/list/list.go @@ -1,6 +1,7 @@ package list import ( + "context" "restic" "restic/worker" ) @@ -9,8 +10,8 @@ const listPackWorkers = 10 // Lister combines lists packs in a repo and blobs in a pack. type Lister interface { - List(restic.FileType, <-chan struct{}) <-chan restic.ID - ListPack(restic.ID) ([]restic.Blob, int64, error) + List(context.Context, restic.FileType) <-chan restic.ID + ListPack(context.Context, restic.ID) ([]restic.Blob, int64, error) } // Result is returned in the channel from LoadBlobsFromAllPacks. @@ -36,10 +37,10 @@ func (l Result) Entries() []restic.Blob { } // AllPacks sends the contents of all packs to ch. -func AllPacks(repo Lister, ch chan<- worker.Job, done <-chan struct{}) { - f := func(job worker.Job, done <-chan struct{}) (interface{}, error) { +func AllPacks(ctx context.Context, repo Lister, ch chan<- worker.Job) { + f := func(ctx context.Context, job worker.Job) (interface{}, error) { packID := job.Data.(restic.ID) - entries, size, err := repo.ListPack(packID) + entries, size, err := repo.ListPack(ctx, packID) return Result{ packID: packID, @@ -49,14 +50,14 @@ func AllPacks(repo Lister, ch chan<- worker.Job, done <-chan struct{}) { } jobCh := make(chan worker.Job) - wp := worker.New(listPackWorkers, f, jobCh, ch) + wp := worker.New(ctx, listPackWorkers, f, jobCh, ch) go func() { defer close(jobCh) - for id := range repo.List(restic.DataFile, done) { + for id := range repo.List(ctx, restic.DataFile) { select { case jobCh <- worker.Job{Data: id}: - case <-done: + case <-ctx.Done(): return } } diff --git a/src/restic/lock.go b/src/restic/lock.go index 97f2d652e..a036ec4b8 100644 --- a/src/restic/lock.go +++ b/src/restic/lock.go @@ -1,6 +1,7 @@ package restic import ( + "context" "fmt" "os" "os/signal" @@ -58,15 +59,15 @@ func IsAlreadyLocked(err error) bool { // NewLock returns a new, non-exclusive lock for the repository. If an // exclusive lock is already held by another process, ErrAlreadyLocked is // returned. -func NewLock(repo Repository) (*Lock, error) { - return newLock(repo, false) +func NewLock(ctx context.Context, repo Repository) (*Lock, error) { + return newLock(ctx, repo, false) } // NewExclusiveLock returns a new, exclusive lock for the repository. If // another lock (normal and exclusive) is already held by another process, // ErrAlreadyLocked is returned. -func NewExclusiveLock(repo Repository) (*Lock, error) { - return newLock(repo, true) +func NewExclusiveLock(ctx context.Context, repo Repository) (*Lock, error) { + return newLock(ctx, repo, true) } var waitBeforeLockCheck = 200 * time.Millisecond @@ -77,7 +78,7 @@ func TestSetLockTimeout(t testing.TB, d time.Duration) { waitBeforeLockCheck = d } -func newLock(repo Repository, excl bool) (*Lock, error) { +func newLock(ctx context.Context, repo Repository, excl bool) (*Lock, error) { lock := &Lock{ Time: time.Now(), PID: os.Getpid(), @@ -94,11 +95,11 @@ func newLock(repo Repository, excl bool) (*Lock, error) { return nil, err } - if err = lock.checkForOtherLocks(); err != nil { + if err = lock.checkForOtherLocks(ctx); err != nil { return nil, err } - lockID, err := lock.createLock() + lockID, err := lock.createLock(ctx) if err != nil { return nil, err } @@ -107,7 +108,7 @@ func newLock(repo Repository, excl bool) (*Lock, error) { time.Sleep(waitBeforeLockCheck) - if err = lock.checkForOtherLocks(); err != nil { + if err = lock.checkForOtherLocks(ctx); err != nil { lock.Unlock() return nil, err } @@ -132,8 +133,8 @@ func (l *Lock) fillUserInfo() error { // if there are any other locks, regardless if exclusive or not. If a // non-exclusive lock is to be created, an error is only returned when an // exclusive lock is found. -func (l *Lock) checkForOtherLocks() error { - return eachLock(l.repo, func(id ID, lock *Lock, err error) error { +func (l *Lock) checkForOtherLocks(ctx context.Context) error { + return eachLock(ctx, l.repo, func(id ID, lock *Lock, err error) error { if l.lockID != nil && id.Equal(*l.lockID) { return nil } @@ -155,12 +156,9 @@ func (l *Lock) checkForOtherLocks() error { }) } -func eachLock(repo Repository, f func(ID, *Lock, error) error) error { - done := make(chan struct{}) - defer close(done) - - for id := range repo.List(LockFile, done) { - lock, err := LoadLock(repo, id) +func eachLock(ctx context.Context, repo Repository, f func(ID, *Lock, error) error) error { + for id := range repo.List(ctx, LockFile) { + lock, err := LoadLock(ctx, repo, id) err = f(id, lock, err) if err != nil { return err @@ -171,8 +169,8 @@ func eachLock(repo Repository, f func(ID, *Lock, error) error) error { } // createLock acquires the lock by creating a file in the repository. -func (l *Lock) createLock() (ID, error) { - id, err := l.repo.SaveJSONUnpacked(LockFile, l) +func (l *Lock) createLock(ctx context.Context) (ID, error) { + id, err := l.repo.SaveJSONUnpacked(ctx, LockFile, l) if err != nil { return ID{}, err } @@ -186,7 +184,7 @@ func (l *Lock) Unlock() error { return nil } - return l.repo.Backend().Remove(Handle{Type: LockFile, Name: l.lockID.String()}) + return l.repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: l.lockID.String()}) } var staleTimeout = 30 * time.Minute @@ -227,14 +225,14 @@ func (l *Lock) Stale() bool { // Refresh refreshes the lock by creating a new file in the backend with a new // timestamp. Afterwards the old lock is removed. -func (l *Lock) Refresh() error { +func (l *Lock) Refresh(ctx context.Context) error { debug.Log("refreshing lock %v", l.lockID.Str()) - id, err := l.createLock() + id, err := l.createLock(ctx) if err != nil { return err } - err = l.repo.Backend().Remove(Handle{Type: LockFile, Name: l.lockID.String()}) + err = l.repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: l.lockID.String()}) if err != nil { return err } @@ -270,9 +268,9 @@ func init() { } // LoadLock loads and unserializes a lock from a repository. -func LoadLock(repo Repository, id ID) (*Lock, error) { +func LoadLock(ctx context.Context, repo Repository, id ID) (*Lock, error) { lock := &Lock{} - if err := repo.LoadJSONUnpacked(LockFile, id, lock); err != nil { + if err := repo.LoadJSONUnpacked(ctx, LockFile, id, lock); err != nil { return nil, err } lock.lockID = &id @@ -281,15 +279,15 @@ func LoadLock(repo Repository, id ID) (*Lock, error) { } // RemoveStaleLocks deletes all locks detected as stale from the repository. -func RemoveStaleLocks(repo Repository) error { - return eachLock(repo, func(id ID, lock *Lock, err error) error { +func RemoveStaleLocks(ctx context.Context, repo Repository) error { + return eachLock(ctx, repo, func(id ID, lock *Lock, err error) error { // ignore locks that cannot be loaded if err != nil { return nil } if lock.Stale() { - return repo.Backend().Remove(Handle{Type: LockFile, Name: id.String()}) + return repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: id.String()}) } return nil @@ -297,8 +295,8 @@ func RemoveStaleLocks(repo Repository) error { } // RemoveAllLocks removes all locks forcefully. -func RemoveAllLocks(repo Repository) error { - return eachLock(repo, func(id ID, lock *Lock, err error) error { - return repo.Backend().Remove(Handle{Type: LockFile, Name: id.String()}) +func RemoveAllLocks(ctx context.Context, repo Repository) error { + return eachLock(ctx, repo, func(id ID, lock *Lock, err error) error { + return repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: id.String()}) }) } diff --git a/src/restic/lock_test.go b/src/restic/lock_test.go index b8288d6cd..d5fd179a1 100644 --- a/src/restic/lock_test.go +++ b/src/restic/lock_test.go @@ -1,6 +1,7 @@ package restic_test import ( + "context" "os" "testing" "time" @@ -14,7 +15,7 @@ func TestLock(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() - lock, err := restic.NewLock(repo) + lock, err := restic.NewLock(context.TODO(), repo) OK(t, err) OK(t, lock.Unlock()) @@ -24,7 +25,7 @@ func TestDoubleUnlock(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() - lock, err := restic.NewLock(repo) + lock, err := restic.NewLock(context.TODO(), repo) OK(t, err) OK(t, lock.Unlock()) @@ -38,10 +39,10 @@ func TestMultipleLock(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() - lock1, err := restic.NewLock(repo) + lock1, err := restic.NewLock(context.TODO(), repo) OK(t, err) - lock2, err := restic.NewLock(repo) + lock2, err := restic.NewLock(context.TODO(), repo) OK(t, err) OK(t, lock1.Unlock()) @@ -52,7 +53,7 @@ func TestLockExclusive(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() - elock, err := restic.NewExclusiveLock(repo) + elock, err := restic.NewExclusiveLock(context.TODO(), repo) OK(t, err) OK(t, elock.Unlock()) } @@ -61,10 +62,10 @@ func TestLockOnExclusiveLockedRepo(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() - elock, err := restic.NewExclusiveLock(repo) + elock, err := restic.NewExclusiveLock(context.TODO(), repo) OK(t, err) - lock, err := restic.NewLock(repo) + lock, err := restic.NewLock(context.TODO(), repo) Assert(t, err != nil, "create normal lock with exclusively locked repo didn't return an error") Assert(t, restic.IsAlreadyLocked(err), @@ -78,10 +79,10 @@ func TestExclusiveLockOnLockedRepo(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() - elock, err := restic.NewLock(repo) + elock, err := restic.NewLock(context.TODO(), repo) OK(t, err) - lock, err := restic.NewExclusiveLock(repo) + lock, err := restic.NewExclusiveLock(context.TODO(), repo) Assert(t, err != nil, "create normal lock with exclusively locked repo didn't return an error") Assert(t, restic.IsAlreadyLocked(err), @@ -98,12 +99,12 @@ func createFakeLock(repo restic.Repository, t time.Time, pid int) (restic.ID, er } newLock := &restic.Lock{Time: t, PID: pid, Hostname: hostname} - return repo.SaveJSONUnpacked(restic.LockFile, &newLock) + return repo.SaveJSONUnpacked(context.TODO(), restic.LockFile, &newLock) } func removeLock(repo restic.Repository, id restic.ID) error { h := restic.Handle{Type: restic.LockFile, Name: id.String()} - return repo.Backend().Remove(h) + return repo.Backend().Remove(context.TODO(), h) } var staleLockTests = []struct { @@ -164,7 +165,7 @@ func TestLockStale(t *testing.T) { func lockExists(repo restic.Repository, t testing.TB, id restic.ID) bool { h := restic.Handle{Type: restic.LockFile, Name: id.String()} - exists, err := repo.Backend().Test(h) + exists, err := repo.Backend().Test(context.TODO(), h) OK(t, err) return exists @@ -183,7 +184,7 @@ func TestLockWithStaleLock(t *testing.T) { id3, err := createFakeLock(repo, time.Now().Add(-time.Minute), os.Getpid()+500000) OK(t, err) - OK(t, restic.RemoveStaleLocks(repo)) + OK(t, restic.RemoveStaleLocks(context.TODO(), repo)) Assert(t, lockExists(repo, t, id1) == false, "stale lock still exists after RemoveStaleLocks was called") @@ -208,7 +209,7 @@ func TestRemoveAllLocks(t *testing.T) { id3, err := createFakeLock(repo, time.Now().Add(-time.Minute), os.Getpid()+500000) OK(t, err) - OK(t, restic.RemoveAllLocks(repo)) + OK(t, restic.RemoveAllLocks(context.TODO(), repo)) Assert(t, lockExists(repo, t, id1) == false, "lock still exists after RemoveAllLocks was called") @@ -222,21 +223,21 @@ func TestLockRefresh(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() - lock, err := restic.NewLock(repo) + lock, err := restic.NewLock(context.TODO(), repo) OK(t, err) var lockID *restic.ID - for id := range repo.List(restic.LockFile, nil) { + for id := range repo.List(context.TODO(), restic.LockFile) { if lockID != nil { t.Error("more than one lock found") } lockID = &id } - OK(t, lock.Refresh()) + OK(t, lock.Refresh(context.TODO())) var lockID2 *restic.ID - for id := range repo.List(restic.LockFile, nil) { + for id := range repo.List(context.TODO(), restic.LockFile) { if lockID2 != nil { t.Error("more than one lock found") } diff --git a/src/restic/mock/backend.go b/src/restic/mock/backend.go index 704e87150..10effe045 100644 --- a/src/restic/mock/backend.go +++ b/src/restic/mock/backend.go @@ -1,6 +1,7 @@ package mock import ( + "context" "io" "restic" @@ -10,13 +11,13 @@ import ( // Backend implements a mock backend. type Backend struct { CloseFn func() error - SaveFn func(h restic.Handle, rd io.Reader) error - LoadFn func(h restic.Handle, length int, offset int64) (io.ReadCloser, error) - StatFn func(h restic.Handle) (restic.FileInfo, error) - ListFn func(restic.FileType, <-chan struct{}) <-chan string - RemoveFn func(h restic.Handle) error - TestFn func(h restic.Handle) (bool, error) - DeleteFn func() error + SaveFn func(ctx context.Context, h restic.Handle, rd io.Reader) error + LoadFn func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) + StatFn func(ctx context.Context, h restic.Handle) (restic.FileInfo, error) + ListFn func(ctx context.Context, t restic.FileType) <-chan string + RemoveFn func(ctx context.Context, h restic.Handle) error + TestFn func(ctx context.Context, h restic.Handle) (bool, error) + DeleteFn func(ctx context.Context) error LocationFn func() string } @@ -39,68 +40,68 @@ func (m *Backend) Location() string { } // Save data in the backend. -func (m *Backend) Save(h restic.Handle, rd io.Reader) error { +func (m *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { if m.SaveFn == nil { return errors.New("not implemented") } - return m.SaveFn(h, rd) + return m.SaveFn(ctx, h, rd) } // Load loads data from the backend. -func (m *Backend) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { +func (m *Backend) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { if m.LoadFn == nil { return nil, errors.New("not implemented") } - return m.LoadFn(h, length, offset) + return m.LoadFn(ctx, h, length, offset) } // Stat an object in the backend. -func (m *Backend) Stat(h restic.Handle) (restic.FileInfo, error) { +func (m *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) { if m.StatFn == nil { return restic.FileInfo{}, errors.New("not implemented") } - return m.StatFn(h) + return m.StatFn(ctx, h) } // List items of type t. -func (m *Backend) List(t restic.FileType, done <-chan struct{}) <-chan string { +func (m *Backend) List(ctx context.Context, t restic.FileType) <-chan string { if m.ListFn == nil { ch := make(chan string) close(ch) return ch } - return m.ListFn(t, done) + return m.ListFn(ctx, t) } // Remove data from the backend. -func (m *Backend) Remove(h restic.Handle) error { +func (m *Backend) Remove(ctx context.Context, h restic.Handle) error { if m.RemoveFn == nil { return errors.New("not implemented") } - return m.RemoveFn(h) + return m.RemoveFn(ctx, h) } // Test for the existence of a specific item. -func (m *Backend) Test(h restic.Handle) (bool, error) { +func (m *Backend) Test(ctx context.Context, h restic.Handle) (bool, error) { if m.TestFn == nil { return false, errors.New("not implemented") } - return m.TestFn(h) + return m.TestFn(ctx, h) } // Delete all data. -func (m *Backend) Delete() error { +func (m *Backend) Delete(ctx context.Context) error { if m.DeleteFn == nil { return errors.New("not implemented") } - return m.DeleteFn() + return m.DeleteFn(ctx) } // Make sure that Backend implements the backend interface. diff --git a/src/restic/node.go b/src/restic/node.go index 25e51aa48..982b64472 100644 --- a/src/restic/node.go +++ b/src/restic/node.go @@ -1,6 +1,7 @@ package restic import ( + "context" "encoding/json" "fmt" "os" @@ -116,7 +117,7 @@ func (node Node) GetExtendedAttribute(a string) []byte { } // CreateAt creates the node at the given path and restores all the meta data. -func (node *Node) CreateAt(path string, repo Repository, idx *HardlinkIndex) error { +func (node *Node) CreateAt(ctx context.Context, path string, repo Repository, idx *HardlinkIndex) error { debug.Log("create node %v at %v", node.Name, path) switch node.Type { @@ -125,7 +126,7 @@ func (node *Node) CreateAt(path string, repo Repository, idx *HardlinkIndex) err return err } case "file": - if err := node.createFileAt(path, repo, idx); err != nil { + if err := node.createFileAt(ctx, path, repo, idx); err != nil { return err } case "symlink": @@ -228,7 +229,7 @@ func (node Node) createDirAt(path string) error { return nil } -func (node Node) createFileAt(path string, repo Repository, idx *HardlinkIndex) error { +func (node Node) createFileAt(ctx context.Context, path string, repo Repository, idx *HardlinkIndex) error { if node.Links > 1 && idx.Has(node.Inode, node.DeviceID) { if err := fs.Remove(path); !os.IsNotExist(err) { return errors.Wrap(err, "RemoveCreateHardlink") @@ -259,7 +260,7 @@ func (node Node) createFileAt(path string, repo Repository, idx *HardlinkIndex) buf = NewBlobBuffer(int(size)) } - n, err := repo.LoadBlob(DataBlob, id, buf) + n, err := repo.LoadBlob(ctx, DataBlob, id, buf) if err != nil { return err } diff --git a/src/restic/node_test.go b/src/restic/node_test.go index f0c16b368..f357adab2 100644 --- a/src/restic/node_test.go +++ b/src/restic/node_test.go @@ -1,6 +1,7 @@ package restic_test import ( + "context" "io/ioutil" "os" "path/filepath" @@ -180,7 +181,7 @@ func TestNodeRestoreAt(t *testing.T) { for _, test := range nodeTests { nodePath := filepath.Join(tempdir, test.Name) - OK(t, test.CreateAt(nodePath, nil, idx)) + OK(t, test.CreateAt(context.TODO(), nodePath, nil, idx)) if test.Type == "symlink" && runtime.GOOS == "windows" { continue diff --git a/src/restic/pack/pack_test.go b/src/restic/pack/pack_test.go index 39cdbba66..c16996158 100644 --- a/src/restic/pack/pack_test.go +++ b/src/restic/pack/pack_test.go @@ -2,6 +2,7 @@ package pack_test import ( "bytes" + "context" "crypto/rand" "crypto/sha256" "encoding/binary" @@ -126,7 +127,7 @@ func TestUnpackReadSeeker(t *testing.T) { id := restic.Hash(packData) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - OK(t, b.Save(handle, bytes.NewReader(packData))) + OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData))) verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize) } @@ -139,6 +140,6 @@ func TestShortPack(t *testing.T) { id := restic.Hash(packData) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - OK(t, b.Save(handle, bytes.NewReader(packData))) + OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData))) verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize) } diff --git a/src/restic/pipe/pipe.go b/src/restic/pipe/pipe.go index 132938b6b..682bb5a2e 100644 --- a/src/restic/pipe/pipe.go +++ b/src/restic/pipe/pipe.go @@ -1,6 +1,7 @@ package pipe import ( + "context" "fmt" "os" "path/filepath" @@ -78,7 +79,7 @@ func readDirNames(dirname string) ([]string, error) { // dirs). If false is returned, files are ignored and dirs are not even walked. type SelectFunc func(item string, fi os.FileInfo) bool -func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs chan<- Job, res chan<- Result) (excluded bool) { +func walk(ctx context.Context, basedir, dir string, selectFunc SelectFunc, jobs chan<- Job, res chan<- Result) (excluded bool) { debug.Log("start on %q, basedir %q", dir, basedir) relpath, err := filepath.Rel(basedir, dir) @@ -92,7 +93,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs debug.Log("error for %v: %v, res %p", dir, err, res) select { case jobs <- Dir{basedir: basedir, path: relpath, info: info, error: err, result: res}: - case <-done: + case <-ctx.Done(): } return } @@ -107,7 +108,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs debug.Log("sending file job for %v, res %p", dir, res) select { case jobs <- Entry{info: info, basedir: basedir, path: relpath, result: res}: - case <-done: + case <-ctx.Done(): } return } @@ -117,7 +118,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs if err != nil { debug.Log("Readdirnames(%v) returned error: %v, res %p", dir, err, res) select { - case <-done: + case <-ctx.Done(): case jobs <- Dir{basedir: basedir, path: relpath, info: info, error: err, result: res}: } return @@ -146,7 +147,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs debug.Log("sending file job for %v, err %v, res %p", subpath, err, res) select { case jobs <- Entry{info: fi, error: statErr, basedir: basedir, path: filepath.Join(relpath, name), result: ch}: - case <-done: + case <-ctx.Done(): return } continue @@ -156,13 +157,13 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs // between walk and open debug.RunHook("pipe.walk2", filepath.Join(relpath, name)) - walk(basedir, subpath, selectFunc, done, jobs, ch) + walk(ctx, basedir, subpath, selectFunc, jobs, ch) } debug.Log("sending dirjob for %q, basedir %q, res %p", dir, basedir, res) select { case jobs <- Dir{basedir: basedir, path: relpath, info: info, Entries: entries, result: res}: - case <-done: + case <-ctx.Done(): } return @@ -191,7 +192,7 @@ func cleanupPath(path string) ([]string, error) { // Walk sends a Job for each file and directory it finds below the paths. When // the channel done is closed, processing stops. -func Walk(walkPaths []string, selectFunc SelectFunc, done chan struct{}, jobs chan<- Job, res chan<- Result) { +func Walk(ctx context.Context, walkPaths []string, selectFunc SelectFunc, jobs chan<- Job, res chan<- Result) { var paths []string for _, p := range walkPaths { @@ -215,7 +216,7 @@ func Walk(walkPaths []string, selectFunc SelectFunc, done chan struct{}, jobs ch for _, path := range paths { debug.Log("start walker for %v", path) ch := make(chan Result, 1) - excluded := walk(filepath.Dir(path), path, selectFunc, done, jobs, ch) + excluded := walk(ctx, filepath.Dir(path), path, selectFunc, jobs, ch) if excluded { debug.Log("walker for %v done, it was excluded by the filter", path) @@ -228,7 +229,7 @@ func Walk(walkPaths []string, selectFunc SelectFunc, done chan struct{}, jobs ch debug.Log("sending root node, res %p", res) select { - case <-done: + case <-ctx.Done(): return case jobs <- Dir{Entries: entries, result: res}: } diff --git a/src/restic/pipe/pipe_test.go b/src/restic/pipe/pipe_test.go index dd2c5e02d..197a1c428 100644 --- a/src/restic/pipe/pipe_test.go +++ b/src/restic/pipe/pipe_test.go @@ -1,6 +1,7 @@ package pipe_test import ( + "context" "io/ioutil" "os" "path/filepath" @@ -127,7 +128,7 @@ func TestPipelineWalkerWithSplit(t *testing.T) { }() resCh := make(chan pipe.Result, 1) - pipe.Walk([]string{TestWalkerPath}, acceptAll, done, jobs, resCh) + pipe.Walk(context.TODO(), []string{TestWalkerPath}, acceptAll, jobs, resCh) // wait for all workers to terminate wg.Wait() @@ -146,6 +147,9 @@ func TestPipelineWalker(t *testing.T) { t.Skipf("walkerpath not set, skipping TestPipelineWalker") } + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + var err error if !filepath.IsAbs(TestWalkerPath) { TestWalkerPath, err = filepath.Abs(TestWalkerPath) @@ -164,7 +168,7 @@ func TestPipelineWalker(t *testing.T) { after := stats{} m := sync.Mutex{} - worker := func(wg *sync.WaitGroup, done <-chan struct{}, jobs <-chan pipe.Job) { + worker := func(ctx context.Context, wg *sync.WaitGroup, jobs <-chan pipe.Job) { defer wg.Done() for { select { @@ -195,7 +199,7 @@ func TestPipelineWalker(t *testing.T) { j.Result() <- true } - case <-done: + case <-ctx.Done(): // pipeline was cancelled return } @@ -203,16 +207,15 @@ func TestPipelineWalker(t *testing.T) { } var wg sync.WaitGroup - done := make(chan struct{}) jobs := make(chan pipe.Job) for i := 0; i < maxWorkers; i++ { wg.Add(1) - go worker(&wg, done, jobs) + go worker(ctx, &wg, jobs) } resCh := make(chan pipe.Result, 1) - pipe.Walk([]string{TestWalkerPath}, acceptAll, done, jobs, resCh) + pipe.Walk(ctx, []string{TestWalkerPath}, acceptAll, jobs, resCh) // wait for all workers to terminate wg.Wait() @@ -286,11 +289,12 @@ func TestPipeWalkerError(t *testing.T) { OK(t, os.RemoveAll(testdir)) }) - done := make(chan struct{}) + ctx, cancel := context.WithCancel(context.TODO()) + ch := make(chan pipe.Job) resCh := make(chan pipe.Result, 1) - go pipe.Walk([]string{dir}, acceptAll, done, ch, resCh) + go pipe.Walk(ctx, []string{dir}, acceptAll, ch, resCh) i := 0 for job := range ch { @@ -321,7 +325,7 @@ func TestPipeWalkerError(t *testing.T) { t.Errorf("expected %d jobs, got %d", len(testjobs), i) } - close(done) + cancel() Assert(t, ranHook, "hook did not run") OK(t, os.RemoveAll(dir)) @@ -335,7 +339,7 @@ func BenchmarkPipelineWalker(b *testing.B) { var max time.Duration m := sync.Mutex{} - fileWorker := func(wg *sync.WaitGroup, done <-chan struct{}, ch <-chan pipe.Entry) { + fileWorker := func(ctx context.Context, wg *sync.WaitGroup, ch <-chan pipe.Entry) { defer wg.Done() for { select { @@ -349,14 +353,14 @@ func BenchmarkPipelineWalker(b *testing.B) { //time.Sleep(10 * time.Millisecond) e.Result() <- true - case <-done: + case <-ctx.Done(): // pipeline was cancelled return } } } - dirWorker := func(wg *sync.WaitGroup, done <-chan struct{}, ch <-chan pipe.Dir) { + dirWorker := func(ctx context.Context, wg *sync.WaitGroup, ch <-chan pipe.Dir) { defer wg.Done() for { select { @@ -381,16 +385,18 @@ func BenchmarkPipelineWalker(b *testing.B) { m.Unlock() dir.Result() <- true - case <-done: + case <-ctx.Done(): // pipeline was cancelled return } } } + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + for i := 0; i < b.N; i++ { max = 0 - done := make(chan struct{}) entCh := make(chan pipe.Entry, 200) dirCh := make(chan pipe.Dir, 200) @@ -398,8 +404,8 @@ func BenchmarkPipelineWalker(b *testing.B) { b.Logf("starting %d workers", maxWorkers) for i := 0; i < maxWorkers; i++ { wg.Add(2) - go dirWorker(&wg, done, dirCh) - go fileWorker(&wg, done, entCh) + go dirWorker(ctx, &wg, dirCh) + go fileWorker(ctx, &wg, entCh) } jobs := make(chan pipe.Job, 200) @@ -412,7 +418,7 @@ func BenchmarkPipelineWalker(b *testing.B) { }() resCh := make(chan pipe.Result, 1) - pipe.Walk([]string{TestWalkerPath}, acceptAll, done, jobs, resCh) + pipe.Walk(ctx, []string{TestWalkerPath}, acceptAll, jobs, resCh) // wait for all workers to terminate wg.Wait() @@ -429,6 +435,9 @@ func TestPipelineWalkerMultiple(t *testing.T) { t.Skipf("walkerpath not set, skipping TestPipelineWalker") } + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + paths, err := filepath.Glob(filepath.Join(TestWalkerPath, "*")) OK(t, err) @@ -441,7 +450,7 @@ func TestPipelineWalkerMultiple(t *testing.T) { after := stats{} m := sync.Mutex{} - worker := func(wg *sync.WaitGroup, done <-chan struct{}, jobs <-chan pipe.Job) { + worker := func(ctx context.Context, wg *sync.WaitGroup, jobs <-chan pipe.Job) { defer wg.Done() for { select { @@ -472,7 +481,7 @@ func TestPipelineWalkerMultiple(t *testing.T) { j.Result() <- true } - case <-done: + case <-ctx.Done(): // pipeline was cancelled return } @@ -480,16 +489,15 @@ func TestPipelineWalkerMultiple(t *testing.T) { } var wg sync.WaitGroup - done := make(chan struct{}) jobs := make(chan pipe.Job) for i := 0; i < maxWorkers; i++ { wg.Add(1) - go worker(&wg, done, jobs) + go worker(ctx, &wg, jobs) } resCh := make(chan pipe.Result, 1) - pipe.Walk(paths, acceptAll, done, jobs, resCh) + pipe.Walk(ctx, paths, acceptAll, jobs, resCh) // wait for all workers to terminate wg.Wait() @@ -547,9 +555,6 @@ func testPipeWalkerRootWithPath(path string, t *testing.T) { t.Logf("paths in %v (pattern %q) expanded to %v items", path, pattern, len(rootPaths)) - done := make(chan struct{}) - defer close(done) - jobCh := make(chan pipe.Job) var jobs []pipe.Job @@ -571,7 +576,7 @@ func testPipeWalkerRootWithPath(path string, t *testing.T) { } resCh := make(chan pipe.Result, 1) - pipe.Walk([]string{path}, filter, done, jobCh, resCh) + pipe.Walk(context.TODO(), []string{path}, filter, jobCh, resCh) wg.Wait() diff --git a/src/restic/readerat.go b/src/restic/readerat.go index a57974473..69169e3d4 100644 --- a/src/restic/readerat.go +++ b/src/restic/readerat.go @@ -1,6 +1,7 @@ package restic import ( + "context" "io" "restic/debug" ) @@ -11,7 +12,7 @@ type backendReaderAt struct { } func (brd backendReaderAt) ReadAt(p []byte, offset int64) (n int, err error) { - return ReadAt(brd.be, brd.h, offset, p) + return ReadAt(context.TODO(), brd.be, brd.h, offset, p) } // ReaderAt returns an io.ReaderAt for a file in the backend. @@ -20,9 +21,9 @@ func ReaderAt(be Backend, h Handle) io.ReaderAt { } // ReadAt reads from the backend handle h at the given position. -func ReadAt(be Backend, h Handle, offset int64, p []byte) (n int, err error) { +func ReadAt(ctx context.Context, be Backend, h Handle, offset int64, p []byte) (n int, err error) { debug.Log("ReadAt(%v) at %v, len %v", h, offset, len(p)) - rd, err := be.Load(h, len(p), offset) + rd, err := be.Load(ctx, h, len(p), offset) if err != nil { return 0, err } diff --git a/src/restic/repository.go b/src/restic/repository.go index 959c0bd3c..81f217a37 100644 --- a/src/restic/repository.go +++ b/src/restic/repository.go @@ -1,6 +1,9 @@ package restic -import "restic/crypto" +import ( + "context" + "restic/crypto" +) // Repository stores data in a backend. It provides high-level functions and // transparently encrypts/decrypts data. @@ -14,40 +17,40 @@ type Repository interface { SetIndex(Index) Index() Index - SaveFullIndex() error - SaveIndex() error - LoadIndex() error + SaveFullIndex(context.Context) error + SaveIndex(context.Context) error + LoadIndex(context.Context) error Config() Config LookupBlobSize(ID, BlobType) (uint, error) - List(FileType, <-chan struct{}) <-chan ID - ListPack(ID) ([]Blob, int64, error) + List(context.Context, FileType) <-chan ID + ListPack(context.Context, ID) ([]Blob, int64, error) Flush() error - SaveUnpacked(FileType, []byte) (ID, error) - SaveJSONUnpacked(FileType, interface{}) (ID, error) + SaveUnpacked(context.Context, FileType, []byte) (ID, error) + SaveJSONUnpacked(context.Context, FileType, interface{}) (ID, error) - LoadJSONUnpacked(FileType, ID, interface{}) error - LoadAndDecrypt(FileType, ID) ([]byte, error) + LoadJSONUnpacked(context.Context, FileType, ID, interface{}) error + LoadAndDecrypt(context.Context, FileType, ID) ([]byte, error) - LoadBlob(BlobType, ID, []byte) (int, error) - SaveBlob(BlobType, []byte, ID) (ID, error) + LoadBlob(context.Context, BlobType, ID, []byte) (int, error) + SaveBlob(context.Context, BlobType, []byte, ID) (ID, error) - LoadTree(ID) (*Tree, error) - SaveTree(t *Tree) (ID, error) + LoadTree(context.Context, ID) (*Tree, error) + SaveTree(context.Context, *Tree) (ID, error) } // Deleter removes all data stored in a backend/repo. type Deleter interface { - Delete() error + Delete(context.Context) error } // Lister allows listing files in a backend. type Lister interface { - List(FileType, <-chan struct{}) <-chan string + List(context.Context, FileType) <-chan string } // Index keeps track of the blobs are stored within files. diff --git a/src/restic/repository/index.go b/src/restic/repository/index.go index 0db683215..6e5dac25b 100644 --- a/src/restic/repository/index.go +++ b/src/restic/repository/index.go @@ -1,6 +1,7 @@ package repository import ( + "context" "encoding/json" "io" "restic" @@ -519,10 +520,10 @@ func DecodeOldIndex(buf []byte) (idx *Index, err error) { } // LoadIndexWithDecoder loads the index and decodes it with fn. -func LoadIndexWithDecoder(repo restic.Repository, id restic.ID, fn func([]byte) (*Index, error)) (idx *Index, err error) { +func LoadIndexWithDecoder(ctx context.Context, repo restic.Repository, id restic.ID, fn func([]byte) (*Index, error)) (idx *Index, err error) { debug.Log("Loading index %v", id.Str()) - buf, err := repo.LoadAndDecrypt(restic.IndexFile, id) + buf, err := repo.LoadAndDecrypt(ctx, restic.IndexFile, id) if err != nil { return nil, err } diff --git a/src/restic/repository/key.go b/src/restic/repository/key.go index 7ce9757aa..68714db6c 100644 --- a/src/restic/repository/key.go +++ b/src/restic/repository/key.go @@ -2,6 +2,7 @@ package repository import ( "bytes" + "context" "encoding/json" "fmt" "os" @@ -58,12 +59,12 @@ var ( // createMasterKey creates a new master key in the given backend and encrypts // it with the password. func createMasterKey(s *Repository, password string) (*Key, error) { - return AddKey(s, password, nil) + return AddKey(context.TODO(), s, password, nil) } // OpenKey tries do decrypt the key specified by name with the given password. -func OpenKey(s *Repository, name string, password string) (*Key, error) { - k, err := LoadKey(s, name) +func OpenKey(ctx context.Context, s *Repository, name string, password string) (*Key, error) { + k, err := LoadKey(ctx, s, name) if err != nil { debug.Log("LoadKey(%v) returned error %v", name[:12], err) return nil, err @@ -113,19 +114,17 @@ func OpenKey(s *Repository, name string, password string) (*Key, error) { // given password. If none could be found, ErrNoKeyFound is returned. When // maxKeys is reached, ErrMaxKeysReached is returned. When setting maxKeys to // zero, all keys in the repo are checked. -func SearchKey(s *Repository, password string, maxKeys int) (*Key, error) { +func SearchKey(ctx context.Context, s *Repository, password string, maxKeys int) (*Key, error) { checked := 0 // try at most maxKeysForSearch keys in repo - done := make(chan struct{}) - defer close(done) - for name := range s.Backend().List(restic.KeyFile, done) { + for name := range s.Backend().List(ctx, restic.KeyFile) { if maxKeys > 0 && checked > maxKeys { return nil, ErrMaxKeysReached } debug.Log("trying key %v", name[:12]) - key, err := OpenKey(s, name, password) + key, err := OpenKey(ctx, s, name, password) if err != nil { debug.Log("key %v returned error %v", name[:12], err) @@ -145,9 +144,9 @@ func SearchKey(s *Repository, password string, maxKeys int) (*Key, error) { } // LoadKey loads a key from the backend. -func LoadKey(s *Repository, name string) (k *Key, err error) { +func LoadKey(ctx context.Context, s *Repository, name string) (k *Key, err error) { h := restic.Handle{Type: restic.KeyFile, Name: name} - data, err := backend.LoadAll(s.be, h) + data, err := backend.LoadAll(ctx, s.be, h) if err != nil { return nil, err } @@ -162,7 +161,7 @@ func LoadKey(s *Repository, name string) (k *Key, err error) { } // AddKey adds a new key to an already existing repository. -func AddKey(s *Repository, password string, template *crypto.Key) (*Key, error) { +func AddKey(ctx context.Context, s *Repository, password string, template *crypto.Key) (*Key, error) { // make sure we have valid KDF parameters if KDFParams == nil { p, err := crypto.Calibrate(KDFTimeout, KDFMemory) @@ -233,7 +232,7 @@ func AddKey(s *Repository, password string, template *crypto.Key) (*Key, error) Name: restic.Hash(buf).String(), } - err = s.be.Save(h, bytes.NewReader(buf)) + err = s.be.Save(ctx, h, bytes.NewReader(buf)) if err != nil { return nil, err } diff --git a/src/restic/repository/packer_manager.go b/src/restic/repository/packer_manager.go index 07f051ba5..697cbfdcc 100644 --- a/src/restic/repository/packer_manager.go +++ b/src/restic/repository/packer_manager.go @@ -1,6 +1,7 @@ package repository import ( + "context" "crypto/sha256" "io" "os" @@ -18,7 +19,7 @@ import ( // Saver implements saving data in a backend. type Saver interface { - Save(restic.Handle, io.Reader) error + Save(context.Context, restic.Handle, io.Reader) error } // Packer holds a pack.Packer together with a hash writer. @@ -118,7 +119,7 @@ func (r *Repository) savePacker(p *Packer) error { id := restic.IDFromHash(p.hw.Sum(nil)) h := restic.Handle{Type: restic.DataFile, Name: id.String()} - err = r.be.Save(h, p.tmpfile) + err = r.be.Save(context.TODO(), h, p.tmpfile) if err != nil { debug.Log("Save(%v) error: %v", h, err) return err diff --git a/src/restic/repository/packer_manager_test.go b/src/restic/repository/packer_manager_test.go index 2ca44de0b..3b49655bd 100644 --- a/src/restic/repository/packer_manager_test.go +++ b/src/restic/repository/packer_manager_test.go @@ -1,6 +1,7 @@ package repository import ( + "context" "io" "math/rand" "os" @@ -52,7 +53,7 @@ func saveFile(t testing.TB, be Saver, f *os.File, id restic.ID) { h := restic.Handle{Type: restic.DataFile, Name: id.String()} t.Logf("save file %v", h) - if err := be.Save(h, f); err != nil { + if err := be.Save(context.TODO(), h, f); err != nil { t.Fatal(err) } @@ -145,7 +146,7 @@ func BenchmarkPackerManager(t *testing.B) { rnd := newRandReader(rand.NewSource(23)) be := &mock.Backend{ - SaveFn: func(restic.Handle, io.Reader) error { return nil }, + SaveFn: func(context.Context, restic.Handle, io.Reader) error { return nil }, } blobBuf := make([]byte, maxBlobSize) diff --git a/src/restic/repository/parallel.go b/src/restic/repository/parallel.go index 6d2154bed..7797dacae 100644 --- a/src/restic/repository/parallel.go +++ b/src/restic/repository/parallel.go @@ -1,6 +1,7 @@ package repository import ( + "context" "restic" "sync" @@ -18,24 +19,19 @@ func closeIfOpen(ch chan struct{}) { } // ParallelWorkFunc gets one file ID to work on. If an error is returned, -// processing stops. If done is closed, the function should return. -type ParallelWorkFunc func(id string, done <-chan struct{}) error +// processing stops. When the contect is cancelled the function should return. +type ParallelWorkFunc func(ctx context.Context, id string) error // ParallelIDWorkFunc gets one restic.ID to work on. If an error is returned, -// processing stops. If done is closed, the function should return. -type ParallelIDWorkFunc func(id restic.ID, done <-chan struct{}) error +// processing stops. When the context is cancelled the function should return. +type ParallelIDWorkFunc func(ctx context.Context, id restic.ID) error // FilesInParallel runs n workers of f in parallel, on the IDs that // repo.List(t) yield. If f returns an error, the process is aborted and the // first error is returned. -func FilesInParallel(repo restic.Lister, t restic.FileType, n uint, f ParallelWorkFunc) error { - done := make(chan struct{}) - defer closeIfOpen(done) - +func FilesInParallel(ctx context.Context, repo restic.Lister, t restic.FileType, n uint, f ParallelWorkFunc) error { wg := &sync.WaitGroup{} - - ch := repo.List(t, done) - + ch := repo.List(ctx, t) errors := make(chan error, n) for i := 0; uint(i) < n; i++ { @@ -50,13 +46,12 @@ func FilesInParallel(repo restic.Lister, t restic.FileType, n uint, f ParallelWo return } - err := f(id, done) + err := f(ctx, id) if err != nil { - closeIfOpen(done) errors <- err return } - case <-done: + case <-ctx.Done(): return } } @@ -79,13 +74,13 @@ func FilesInParallel(repo restic.Lister, t restic.FileType, n uint, f ParallelWo // function that takes a string. Filenames that do not parse as a restic.ID // are ignored. func ParallelWorkFuncParseID(f ParallelIDWorkFunc) ParallelWorkFunc { - return func(s string, done <-chan struct{}) error { + return func(ctx context.Context, s string) error { id, err := restic.ParseID(s) if err != nil { debug.Log("invalid ID %q: %v", id, err) return err } - return f(id, done) + return f(ctx, id) } } diff --git a/src/restic/repository/parallel_test.go b/src/restic/repository/parallel_test.go index cfa384a01..aa15e79e7 100644 --- a/src/restic/repository/parallel_test.go +++ b/src/restic/repository/parallel_test.go @@ -1,6 +1,7 @@ package repository_test import ( + "context" "math/rand" "restic" "testing" @@ -73,7 +74,7 @@ var lister = testIDs{ "34dd044c228727f2226a0c9c06a3e5ceb5e30e31cb7854f8fa1cde846b395a58", } -func (tests testIDs) List(t restic.FileType, done <-chan struct{}) <-chan string { +func (tests testIDs) List(ctx context.Context, t restic.FileType) <-chan string { ch := make(chan string) go func() { @@ -83,7 +84,7 @@ func (tests testIDs) List(t restic.FileType, done <-chan struct{}) <-chan string for _, id := range tests { select { case ch <- id: - case <-done: + case <-ctx.Done(): return } } @@ -94,13 +95,13 @@ func (tests testIDs) List(t restic.FileType, done <-chan struct{}) <-chan string } func TestFilesInParallel(t *testing.T) { - f := func(id string, done <-chan struct{}) error { + f := func(ctx context.Context, id string) error { time.Sleep(1 * time.Millisecond) return nil } for n := uint(1); n < 5; n++ { - err := repository.FilesInParallel(lister, restic.DataFile, n*100, f) + err := repository.FilesInParallel(context.TODO(), lister, restic.DataFile, n*100, f) OK(t, err) } } @@ -109,7 +110,7 @@ var errTest = errors.New("test error") func TestFilesInParallelWithError(t *testing.T) { - f := func(id string, done <-chan struct{}) error { + f := func(ctx context.Context, id string) error { time.Sleep(1 * time.Millisecond) if rand.Float32() < 0.01 { @@ -120,7 +121,7 @@ func TestFilesInParallelWithError(t *testing.T) { } for n := uint(1); n < 5; n++ { - err := repository.FilesInParallel(lister, restic.DataFile, n*100, f) + err := repository.FilesInParallel(context.TODO(), lister, restic.DataFile, n*100, f) Equals(t, errTest, err) } } diff --git a/src/restic/repository/repack.go b/src/restic/repository/repack.go index b049a4255..36a000783 100644 --- a/src/restic/repository/repack.go +++ b/src/restic/repository/repack.go @@ -1,6 +1,7 @@ package repository import ( + "context" "crypto/sha256" "io" "restic" @@ -17,7 +18,7 @@ import ( // these packs. Each pack is loaded and the blobs listed in keepBlobs is saved // into a new pack. Afterwards, the packs are removed. This operation requires // an exclusive lock on the repo. -func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet, p *restic.Progress) (err error) { +func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet, p *restic.Progress) (err error) { debug.Log("repacking %d packs while keeping %d blobs", len(packs), len(keepBlobs)) for packID := range packs { @@ -29,7 +30,7 @@ func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet return errors.Wrap(err, "TempFile") } - beRd, err := repo.Backend().Load(h, 0, 0) + beRd, err := repo.Backend().Load(ctx, h, 0, 0) if err != nil { return err } @@ -100,7 +101,7 @@ func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet h, tempfile.Name(), id) } - _, err = repo.SaveBlob(entry.Type, buf, entry.ID) + _, err = repo.SaveBlob(ctx, entry.Type, buf, entry.ID) if err != nil { return err } @@ -128,7 +129,7 @@ func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet for packID := range packs { h := restic.Handle{Type: restic.DataFile, Name: packID.String()} - err := repo.Backend().Remove(h) + err := repo.Backend().Remove(ctx, h) if err != nil { debug.Log("error removing pack %v: %v", packID.Str(), err) return err diff --git a/src/restic/repository/repack_test.go b/src/restic/repository/repack_test.go index 622b3ba52..d339cf2b8 100644 --- a/src/restic/repository/repack_test.go +++ b/src/restic/repository/repack_test.go @@ -1,6 +1,7 @@ package repository_test import ( + "context" "io" "math/rand" "restic" @@ -47,7 +48,7 @@ func createRandomBlobs(t testing.TB, repo restic.Repository, blobs int, pData fl continue } - _, err := repo.SaveBlob(tpe, buf, id) + _, err := repo.SaveBlob(context.TODO(), tpe, buf, id) if err != nil { t.Fatalf("SaveFrom() error %v", err) } @@ -67,16 +68,13 @@ func createRandomBlobs(t testing.TB, repo restic.Repository, blobs int, pData fl // selectBlobs splits the list of all blobs randomly into two lists. A blob // will be contained in the firstone ith probability p. func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2 restic.BlobSet) { - done := make(chan struct{}) - defer close(done) - list1 = restic.NewBlobSet() list2 = restic.NewBlobSet() blobs := restic.NewBlobSet() - for id := range repo.List(restic.DataFile, done) { - entries, _, err := repo.ListPack(id) + for id := range repo.List(context.TODO(), restic.DataFile) { + entries, _, err := repo.ListPack(context.TODO(), id) if err != nil { t.Fatalf("error listing pack %v: %v", id, err) } @@ -102,11 +100,8 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2 } func listPacks(t *testing.T, repo restic.Repository) restic.IDSet { - done := make(chan struct{}) - defer close(done) - list := restic.NewIDSet() - for id := range repo.List(restic.DataFile, done) { + for id := range repo.List(context.TODO(), restic.DataFile) { list.Insert(id) } @@ -132,35 +127,36 @@ func findPacksForBlobs(t *testing.T, repo restic.Repository, blobs restic.BlobSe } func repack(t *testing.T, repo restic.Repository, packs restic.IDSet, blobs restic.BlobSet) { - err := repository.Repack(repo, packs, blobs, nil) + err := repository.Repack(context.TODO(), repo, packs, blobs, nil) if err != nil { t.Fatal(err) } } func saveIndex(t *testing.T, repo restic.Repository) { - if err := repo.SaveIndex(); err != nil { + if err := repo.SaveIndex(context.TODO()); err != nil { t.Fatalf("repo.SaveIndex() %v", err) } } func rebuildIndex(t *testing.T, repo restic.Repository) { - idx, err := index.New(repo, nil) + idx, err := index.New(context.TODO(), repo, nil) if err != nil { t.Fatal(err) } - for id := range repo.List(restic.IndexFile, nil) { - err = repo.Backend().Remove(restic.Handle{ + for id := range repo.List(context.TODO(), restic.IndexFile) { + h := restic.Handle{ Type: restic.IndexFile, Name: id.String(), - }) + } + err = repo.Backend().Remove(context.TODO(), h) if err != nil { t.Fatal(err) } } - _, err = idx.Save(repo, nil) + _, err = idx.Save(context.TODO(), repo, nil) if err != nil { t.Fatal(err) } @@ -168,7 +164,7 @@ func rebuildIndex(t *testing.T, repo restic.Repository) { func reloadIndex(t *testing.T, repo restic.Repository) { repo.SetIndex(repository.NewMasterIndex()) - if err := repo.LoadIndex(); err != nil { + if err := repo.LoadIndex(context.TODO()); err != nil { t.Fatalf("error loading new index: %v", err) } } diff --git a/src/restic/repository/repository.go b/src/restic/repository/repository.go index 4b93db721..1c158be88 100644 --- a/src/restic/repository/repository.go +++ b/src/restic/repository/repository.go @@ -2,6 +2,7 @@ package repository import ( "bytes" + "context" "encoding/json" "fmt" "os" @@ -50,11 +51,11 @@ func (r *Repository) PrefixLength(t restic.FileType) (int, error) { // LoadAndDecrypt loads and decrypts data identified by t and id from the // backend. -func (r *Repository) LoadAndDecrypt(t restic.FileType, id restic.ID) ([]byte, error) { +func (r *Repository) LoadAndDecrypt(ctx context.Context, t restic.FileType, id restic.ID) ([]byte, error) { debug.Log("load %v with id %v", t, id.Str()) h := restic.Handle{Type: t, Name: id.String()} - buf, err := backend.LoadAll(r.be, h) + buf, err := backend.LoadAll(ctx, r.be, h) if err != nil { debug.Log("error loading %v: %v", h, err) return nil, err @@ -76,7 +77,7 @@ func (r *Repository) LoadAndDecrypt(t restic.FileType, id restic.ID) ([]byte, er // loadBlob tries to load and decrypt content identified by t and id from a // pack from the backend, the result is stored in plaintextBuf, which must be // large enough to hold the complete blob. -func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []byte) (int, error) { +func (r *Repository) loadBlob(ctx context.Context, id restic.ID, t restic.BlobType, plaintextBuf []byte) (int, error) { debug.Log("load %v with id %v (buf len %v, cap %d)", t, id.Str(), len(plaintextBuf), cap(plaintextBuf)) // lookup packs @@ -103,7 +104,7 @@ func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []by plaintextBuf = plaintextBuf[:blob.Length] - n, err := restic.ReadAt(r.be, h, int64(blob.Offset), plaintextBuf) + n, err := restic.ReadAt(ctx, r.be, h, int64(blob.Offset), plaintextBuf) if err != nil { debug.Log("error loading blob %v: %v", blob, err) lastError = err @@ -143,8 +144,8 @@ func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []by // LoadJSONUnpacked decrypts the data and afterwards calls json.Unmarshal on // the item. -func (r *Repository) LoadJSONUnpacked(t restic.FileType, id restic.ID, item interface{}) (err error) { - buf, err := r.LoadAndDecrypt(t, id) +func (r *Repository) LoadJSONUnpacked(ctx context.Context, t restic.FileType, id restic.ID, item interface{}) (err error) { + buf, err := r.LoadAndDecrypt(ctx, t, id) if err != nil { return err } @@ -159,7 +160,7 @@ func (r *Repository) LookupBlobSize(id restic.ID, tpe restic.BlobType) (uint, er // SaveAndEncrypt encrypts data and stores it to the backend as type t. If data // is small enough, it will be packed together with other small blobs. -func (r *Repository) SaveAndEncrypt(t restic.BlobType, data []byte, id *restic.ID) (restic.ID, error) { +func (r *Repository) SaveAndEncrypt(ctx context.Context, t restic.BlobType, data []byte, id *restic.ID) (restic.ID, error) { if id == nil { // compute plaintext hash hashedID := restic.Hash(data) @@ -204,19 +205,19 @@ func (r *Repository) SaveAndEncrypt(t restic.BlobType, data []byte, id *restic.I // SaveJSONUnpacked serialises item as JSON and encrypts and saves it in the // backend as type t, without a pack. It returns the storage hash. -func (r *Repository) SaveJSONUnpacked(t restic.FileType, item interface{}) (restic.ID, error) { +func (r *Repository) SaveJSONUnpacked(ctx context.Context, t restic.FileType, item interface{}) (restic.ID, error) { debug.Log("save new blob %v", t) plaintext, err := json.Marshal(item) if err != nil { return restic.ID{}, errors.Wrap(err, "json.Marshal") } - return r.SaveUnpacked(t, plaintext) + return r.SaveUnpacked(ctx, t, plaintext) } // SaveUnpacked encrypts data and stores it in the backend. Returned is the // storage hash. -func (r *Repository) SaveUnpacked(t restic.FileType, p []byte) (id restic.ID, err error) { +func (r *Repository) SaveUnpacked(ctx context.Context, t restic.FileType, p []byte) (id restic.ID, err error) { ciphertext := restic.NewBlobBuffer(len(p)) ciphertext, err = r.Encrypt(ciphertext, p) if err != nil { @@ -226,7 +227,7 @@ func (r *Repository) SaveUnpacked(t restic.FileType, p []byte) (id restic.ID, er id = restic.Hash(ciphertext) h := restic.Handle{Type: t, Name: id.String()} - err = r.be.Save(h, bytes.NewReader(ciphertext)) + err = r.be.Save(ctx, h, bytes.NewReader(ciphertext)) if err != nil { debug.Log("error saving blob %v: %v", h, err) return restic.ID{}, err @@ -269,7 +270,7 @@ func (r *Repository) SetIndex(i restic.Index) { } // SaveIndex saves an index in the repository. -func SaveIndex(repo restic.Repository, index *Index) (restic.ID, error) { +func SaveIndex(ctx context.Context, repo restic.Repository, index *Index) (restic.ID, error) { buf := bytes.NewBuffer(nil) err := index.Finalize(buf) @@ -277,15 +278,15 @@ func SaveIndex(repo restic.Repository, index *Index) (restic.ID, error) { return restic.ID{}, err } - return repo.SaveUnpacked(restic.IndexFile, buf.Bytes()) + return repo.SaveUnpacked(ctx, restic.IndexFile, buf.Bytes()) } // saveIndex saves all indexes in the backend. -func (r *Repository) saveIndex(indexes ...*Index) error { +func (r *Repository) saveIndex(ctx context.Context, indexes ...*Index) error { for i, idx := range indexes { debug.Log("Saving index %d", i) - sid, err := SaveIndex(r, idx) + sid, err := SaveIndex(ctx, r, idx) if err != nil { return err } @@ -297,34 +298,34 @@ func (r *Repository) saveIndex(indexes ...*Index) error { } // SaveIndex saves all new indexes in the backend. -func (r *Repository) SaveIndex() error { - return r.saveIndex(r.idx.NotFinalIndexes()...) +func (r *Repository) SaveIndex(ctx context.Context) error { + return r.saveIndex(ctx, r.idx.NotFinalIndexes()...) } // SaveFullIndex saves all full indexes in the backend. -func (r *Repository) SaveFullIndex() error { - return r.saveIndex(r.idx.FullIndexes()...) +func (r *Repository) SaveFullIndex(ctx context.Context) error { + return r.saveIndex(ctx, r.idx.FullIndexes()...) } const loadIndexParallelism = 20 // LoadIndex loads all index files from the backend in parallel and stores them // in the master index. The first error that occurred is returned. -func (r *Repository) LoadIndex() error { +func (r *Repository) LoadIndex(ctx context.Context) error { debug.Log("Loading index") errCh := make(chan error, 1) indexes := make(chan *Index) - worker := func(id restic.ID, done <-chan struct{}) error { - idx, err := LoadIndex(r, id) + worker := func(ctx context.Context, id restic.ID) error { + idx, err := LoadIndex(ctx, r, id) if err != nil { return err } select { case indexes <- idx: - case <-done: + case <-ctx.Done(): } return nil @@ -332,7 +333,7 @@ func (r *Repository) LoadIndex() error { go func() { defer close(indexes) - errCh <- FilesInParallel(r.be, restic.IndexFile, loadIndexParallelism, + errCh <- FilesInParallel(ctx, r.be, restic.IndexFile, loadIndexParallelism, ParallelWorkFuncParseID(worker)) }() @@ -348,15 +349,15 @@ func (r *Repository) LoadIndex() error { } // LoadIndex loads the index id from backend and returns it. -func LoadIndex(repo restic.Repository, id restic.ID) (*Index, error) { - idx, err := LoadIndexWithDecoder(repo, id, DecodeIndex) +func LoadIndex(ctx context.Context, repo restic.Repository, id restic.ID) (*Index, error) { + idx, err := LoadIndexWithDecoder(ctx, repo, id, DecodeIndex) if err == nil { return idx, nil } if errors.Cause(err) == ErrOldIndexFormat { fmt.Fprintf(os.Stderr, "index %v has old format\n", id.Str()) - return LoadIndexWithDecoder(repo, id, DecodeOldIndex) + return LoadIndexWithDecoder(ctx, repo, id, DecodeOldIndex) } return nil, err @@ -364,8 +365,8 @@ func LoadIndex(repo restic.Repository, id restic.ID) (*Index, error) { // SearchKey finds a key with the supplied password, afterwards the config is // read and parsed. It tries at most maxKeys key files in the repo. -func (r *Repository) SearchKey(password string, maxKeys int) error { - key, err := SearchKey(r, password, maxKeys) +func (r *Repository) SearchKey(ctx context.Context, password string, maxKeys int) error { + key, err := SearchKey(ctx, r, password, maxKeys) if err != nil { return err } @@ -373,14 +374,14 @@ func (r *Repository) SearchKey(password string, maxKeys int) error { r.key = key.master r.packerManager.key = key.master r.keyName = key.Name() - r.cfg, err = restic.LoadConfig(r) + r.cfg, err = restic.LoadConfig(ctx, r) return err } // Init creates a new master key with the supplied password, initializes and // saves the repository config. -func (r *Repository) Init(password string) error { - has, err := r.be.Test(restic.Handle{Type: restic.ConfigFile}) +func (r *Repository) Init(ctx context.Context, password string) error { + has, err := r.be.Test(ctx, restic.Handle{Type: restic.ConfigFile}) if err != nil { return err } @@ -393,12 +394,12 @@ func (r *Repository) Init(password string) error { return err } - return r.init(password, cfg) + return r.init(ctx, password, cfg) } // init creates a new master key with the supplied password and uses it to save // the config into the repo. -func (r *Repository) init(password string, cfg restic.Config) error { +func (r *Repository) init(ctx context.Context, password string, cfg restic.Config) error { key, err := createMasterKey(r, password) if err != nil { return err @@ -408,7 +409,7 @@ func (r *Repository) init(password string, cfg restic.Config) error { r.packerManager.key = key.master r.keyName = key.Name() r.cfg = cfg - _, err = r.SaveJSONUnpacked(restic.ConfigFile, cfg) + _, err = r.SaveJSONUnpacked(ctx, restic.ConfigFile, cfg) return err } @@ -443,15 +444,15 @@ func (r *Repository) KeyName() string { } // List returns a channel that yields all IDs of type t in the backend. -func (r *Repository) List(t restic.FileType, done <-chan struct{}) <-chan restic.ID { +func (r *Repository) List(ctx context.Context, t restic.FileType) <-chan restic.ID { out := make(chan restic.ID) go func() { defer close(out) - for strID := range r.be.List(t, done) { + for strID := range r.be.List(ctx, t) { if id, err := restic.ParseID(strID); err == nil { select { case out <- id: - case <-done: + case <-ctx.Done(): return } } @@ -462,10 +463,10 @@ func (r *Repository) List(t restic.FileType, done <-chan struct{}) <-chan restic // ListPack returns the list of blobs saved in the pack id and the length of // the file as stored in the backend. -func (r *Repository) ListPack(id restic.ID) ([]restic.Blob, int64, error) { +func (r *Repository) ListPack(ctx context.Context, id restic.ID) ([]restic.Blob, int64, error) { h := restic.Handle{Type: restic.DataFile, Name: id.String()} - blobInfo, err := r.Backend().Stat(h) + blobInfo, err := r.Backend().Stat(ctx, h) if err != nil { return nil, 0, err } @@ -480,9 +481,9 @@ func (r *Repository) ListPack(id restic.ID) ([]restic.Blob, int64, error) { // Delete calls backend.Delete() if implemented, and returns an error // otherwise. -func (r *Repository) Delete() error { +func (r *Repository) Delete(ctx context.Context) error { if b, ok := r.be.(restic.Deleter); ok { - return b.Delete() + return b.Delete(ctx) } return errors.New("Delete() called for backend that does not implement this method") @@ -496,7 +497,7 @@ func (r *Repository) Close() error { // LoadBlob loads a blob of type t from the repository to the buffer. buf must // be large enough to hold the encrypted blob, since it is used as scratch // space. -func (r *Repository) LoadBlob(t restic.BlobType, id restic.ID, buf []byte) (int, error) { +func (r *Repository) LoadBlob(ctx context.Context, t restic.BlobType, id restic.ID, buf []byte) (int, error) { debug.Log("load blob %v into buf (len %v, cap %v)", id.Str(), len(buf), cap(buf)) size, err := r.idx.LookupSize(id, t) if err != nil { @@ -507,7 +508,7 @@ func (r *Repository) LoadBlob(t restic.BlobType, id restic.ID, buf []byte) (int, return 0, errors.Errorf("buffer is too small for data blob (%d < %d)", cap(buf), restic.CiphertextLength(int(size))) } - n, err := r.loadBlob(id, t, buf) + n, err := r.loadBlob(ctx, id, t, buf) if err != nil { return 0, err } @@ -520,16 +521,16 @@ func (r *Repository) LoadBlob(t restic.BlobType, id restic.ID, buf []byte) (int, // SaveBlob saves a blob of type t into the repository. If id is the null id, it // will be computed and returned. -func (r *Repository) SaveBlob(t restic.BlobType, buf []byte, id restic.ID) (restic.ID, error) { +func (r *Repository) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID) (restic.ID, error) { var i *restic.ID if !id.IsNull() { i = &id } - return r.SaveAndEncrypt(t, buf, i) + return r.SaveAndEncrypt(ctx, t, buf, i) } // LoadTree loads a tree from the repository. -func (r *Repository) LoadTree(id restic.ID) (*restic.Tree, error) { +func (r *Repository) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree, error) { debug.Log("load tree %v", id.Str()) size, err := r.idx.LookupSize(id, restic.TreeBlob) @@ -540,7 +541,7 @@ func (r *Repository) LoadTree(id restic.ID) (*restic.Tree, error) { debug.Log("size is %d, create buffer", size) buf := restic.NewBlobBuffer(int(size)) - n, err := r.loadBlob(id, restic.TreeBlob, buf) + n, err := r.loadBlob(ctx, id, restic.TreeBlob, buf) if err != nil { return nil, err } @@ -558,7 +559,7 @@ func (r *Repository) LoadTree(id restic.ID) (*restic.Tree, error) { // SaveTree stores a tree into the repository and returns the ID. The ID is // checked against the index. The tree is only stored when the index does not // contain the ID. -func (r *Repository) SaveTree(t *restic.Tree) (restic.ID, error) { +func (r *Repository) SaveTree(ctx context.Context, t *restic.Tree) (restic.ID, error) { buf, err := json.Marshal(t) if err != nil { return restic.ID{}, errors.Wrap(err, "MarshalJSON") @@ -573,6 +574,6 @@ func (r *Repository) SaveTree(t *restic.Tree) (restic.ID, error) { return id, nil } - _, err = r.SaveBlob(restic.TreeBlob, buf, id) + _, err = r.SaveBlob(ctx, restic.TreeBlob, buf, id) return id, err } diff --git a/src/restic/repository/repository_test.go b/src/restic/repository/repository_test.go index 6ee99f265..cc9e0ab64 100644 --- a/src/restic/repository/repository_test.go +++ b/src/restic/repository/repository_test.go @@ -2,6 +2,7 @@ package repository_test import ( "bytes" + "context" "crypto/sha256" "io" "math/rand" @@ -31,7 +32,7 @@ func TestSave(t *testing.T) { id := restic.Hash(data) // save - sid, err := repo.SaveBlob(restic.DataBlob, data, restic.ID{}) + sid, err := repo.SaveBlob(context.TODO(), restic.DataBlob, data, restic.ID{}) OK(t, err) Equals(t, id, sid) @@ -41,7 +42,7 @@ func TestSave(t *testing.T) { // read back buf := restic.NewBlobBuffer(size) - n, err := repo.LoadBlob(restic.DataBlob, id, buf) + n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf) OK(t, err) Equals(t, len(buf), n) @@ -67,7 +68,7 @@ func TestSaveFrom(t *testing.T) { id := restic.Hash(data) // save - id2, err := repo.SaveBlob(restic.DataBlob, data, id) + id2, err := repo.SaveBlob(context.TODO(), restic.DataBlob, data, id) OK(t, err) Equals(t, id, id2) @@ -75,7 +76,7 @@ func TestSaveFrom(t *testing.T) { // read back buf := restic.NewBlobBuffer(size) - n, err := repo.LoadBlob(restic.DataBlob, id, buf) + n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf) OK(t, err) Equals(t, len(buf), n) @@ -106,7 +107,7 @@ func BenchmarkSaveAndEncrypt(t *testing.B) { for i := 0; i < t.N; i++ { // save - _, err = repo.SaveBlob(restic.DataBlob, data, id) + _, err = repo.SaveBlob(context.TODO(), restic.DataBlob, data, id) OK(t, err) } } @@ -123,7 +124,7 @@ func TestLoadTree(t *testing.T) { sn := archiver.TestSnapshot(t, repo, BenchArchiveDirectory, nil) OK(t, repo.Flush()) - _, err := repo.LoadTree(*sn.Tree) + _, err := repo.LoadTree(context.TODO(), *sn.Tree) OK(t, err) } @@ -142,7 +143,7 @@ func BenchmarkLoadTree(t *testing.B) { t.ResetTimer() for i := 0; i < t.N; i++ { - _, err := repo.LoadTree(*sn.Tree) + _, err := repo.LoadTree(context.TODO(), *sn.Tree) OK(t, err) } } @@ -156,14 +157,14 @@ func TestLoadBlob(t *testing.T) { _, err := io.ReadFull(rnd, buf) OK(t, err) - id, err := repo.SaveBlob(restic.DataBlob, buf, restic.ID{}) + id, err := repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{}) OK(t, err) OK(t, repo.Flush()) // first, test with buffers that are too small for _, testlength := range []int{length - 20, length, restic.CiphertextLength(length) - 1} { buf = make([]byte, 0, testlength) - n, err := repo.LoadBlob(restic.DataBlob, id, buf) + n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf) if err == nil { t.Errorf("LoadBlob() did not return an error for a buffer that is too small to hold the blob") continue @@ -179,7 +180,7 @@ func TestLoadBlob(t *testing.T) { base := restic.CiphertextLength(length) for _, testlength := range []int{base, base + 7, base + 15, base + 1000} { buf = make([]byte, 0, testlength) - n, err := repo.LoadBlob(restic.DataBlob, id, buf) + n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf) if err != nil { t.Errorf("LoadBlob() returned an error for buffer size %v: %v", testlength, err) continue @@ -201,7 +202,7 @@ func BenchmarkLoadBlob(b *testing.B) { _, err := io.ReadFull(rnd, buf) OK(b, err) - id, err := repo.SaveBlob(restic.DataBlob, buf, restic.ID{}) + id, err := repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{}) OK(b, err) OK(b, repo.Flush()) @@ -209,7 +210,7 @@ func BenchmarkLoadBlob(b *testing.B) { b.SetBytes(int64(length)) for i := 0; i < b.N; i++ { - n, err := repo.LoadBlob(restic.DataBlob, id, buf) + n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf) OK(b, err) if n != length { b.Errorf("wanted %d bytes, got %d", length, n) @@ -233,7 +234,7 @@ func BenchmarkLoadAndDecrypt(b *testing.B) { dataID := restic.Hash(buf) - storageID, err := repo.SaveUnpacked(restic.DataFile, buf) + storageID, err := repo.SaveUnpacked(context.TODO(), restic.DataFile, buf) OK(b, err) // OK(b, repo.Flush()) @@ -241,7 +242,7 @@ func BenchmarkLoadAndDecrypt(b *testing.B) { b.SetBytes(int64(length)) for i := 0; i < b.N; i++ { - data, err := repo.LoadAndDecrypt(restic.DataFile, storageID) + data, err := repo.LoadAndDecrypt(context.TODO(), restic.DataFile, storageID) OK(b, err) if len(data) != length { b.Errorf("wanted %d bytes, got %d", length, len(data)) @@ -267,13 +268,13 @@ func TestLoadJSONUnpacked(t *testing.T) { sn.Hostname = "foobar" sn.Username = "test!" - id, err := repo.SaveJSONUnpacked(restic.SnapshotFile, &sn) + id, err := repo.SaveJSONUnpacked(context.TODO(), restic.SnapshotFile, &sn) OK(t, err) var sn2 restic.Snapshot // restore - err = repo.LoadJSONUnpacked(restic.SnapshotFile, id, &sn2) + err = repo.LoadJSONUnpacked(context.TODO(), restic.SnapshotFile, id, &sn2) OK(t, err) Equals(t, sn.Hostname, sn2.Hostname) @@ -287,7 +288,7 @@ func TestRepositoryLoadIndex(t *testing.T) { defer cleanup() repo := repository.TestOpenLocal(t, repodir) - OK(t, repo.LoadIndex()) + OK(t, repo.LoadIndex(context.TODO())) } func BenchmarkLoadIndex(b *testing.B) { @@ -310,18 +311,18 @@ func BenchmarkLoadIndex(b *testing.B) { }) } - id, err := repository.SaveIndex(repo, idx) + id, err := repository.SaveIndex(context.TODO(), repo, idx) OK(b, err) b.Logf("index saved as %v (%v entries)", id.Str(), idx.Count(restic.DataBlob)) - fi, err := repo.Backend().Stat(restic.Handle{Type: restic.IndexFile, Name: id.String()}) + fi, err := repo.Backend().Stat(context.TODO(), restic.Handle{Type: restic.IndexFile, Name: id.String()}) OK(b, err) b.Logf("filesize is %v", fi.Size) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := repository.LoadIndex(repo, id) + _, err := repository.LoadIndex(context.TODO(), repo, id) OK(b, err) } } @@ -335,7 +336,7 @@ func saveRandomDataBlobs(t testing.TB, repo restic.Repository, num int, sizeMax _, err := io.ReadFull(rnd, buf) OK(t, err) - _, err = repo.SaveBlob(restic.DataBlob, buf, restic.ID{}) + _, err = repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{}) OK(t, err) } } @@ -354,7 +355,7 @@ func TestRepositoryIncrementalIndex(t *testing.T) { OK(t, repo.Flush()) } - OK(t, repo.SaveFullIndex()) + OK(t, repo.SaveFullIndex(context.TODO())) } // add another 5 packs @@ -364,12 +365,12 @@ func TestRepositoryIncrementalIndex(t *testing.T) { } // save final index - OK(t, repo.SaveIndex()) + OK(t, repo.SaveIndex(context.TODO())) packEntries := make(map[restic.ID]map[restic.ID]struct{}) - for id := range repo.List(restic.IndexFile, nil) { - idx, err := repository.LoadIndex(repo, id) + for id := range repo.List(context.TODO(), restic.IndexFile) { + idx, err := repository.LoadIndex(context.TODO(), repo, id) OK(t, err) for pb := range idx.Each(nil) { diff --git a/src/restic/repository/testing.go b/src/restic/repository/testing.go index a24912257..ab78bdad3 100644 --- a/src/restic/repository/testing.go +++ b/src/restic/repository/testing.go @@ -1,6 +1,7 @@ package repository import ( + "context" "os" "restic" "restic/backend/local" @@ -50,7 +51,7 @@ func TestRepositoryWithBackend(t testing.TB, be restic.Backend) (r restic.Reposi repo := New(be) cfg := restic.TestCreateConfig(t, testChunkerPol) - err := repo.init(test.TestPassword, cfg) + err := repo.init(context.TODO(), test.TestPassword, cfg) if err != nil { t.Fatalf("TestRepository(): initialize repo failed: %v", err) } @@ -94,7 +95,7 @@ func TestOpenLocal(t testing.TB, dir string) (r restic.Repository) { } repo := New(be) - err = repo.SearchKey(test.TestPassword, 10) + err = repo.SearchKey(context.TODO(), test.TestPassword, 10) if err != nil { t.Fatal(err) } diff --git a/src/restic/restorer.go b/src/restic/restorer.go index 56916f3ce..3b7f8fc83 100644 --- a/src/restic/restorer.go +++ b/src/restic/restorer.go @@ -1,6 +1,7 @@ package restic import ( + "context" "os" "path/filepath" @@ -30,7 +31,7 @@ func NewRestorer(repo Repository, id ID) (*Restorer, error) { var err error - r.sn, err = LoadSnapshot(repo, id) + r.sn, err = LoadSnapshot(context.TODO(), repo, id) if err != nil { return nil, err } @@ -38,8 +39,8 @@ func NewRestorer(repo Repository, id ID) (*Restorer, error) { return r, nil } -func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkIndex) error { - tree, err := res.repo.LoadTree(treeID) +func (res *Restorer) restoreTo(ctx context.Context, dst string, dir string, treeID ID, idx *HardlinkIndex) error { + tree, err := res.repo.LoadTree(ctx, treeID) if err != nil { return res.Error(dir, nil, err) } @@ -50,7 +51,7 @@ func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkI debug.Log("SelectForRestore returned %v", selectedForRestore) if selectedForRestore { - err := res.restoreNodeTo(node, dir, dst, idx) + err := res.restoreNodeTo(ctx, node, dir, dst, idx) if err != nil { return err } @@ -62,7 +63,7 @@ func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkI } subp := filepath.Join(dir, node.Name) - err = res.restoreTo(dst, subp, *node.Subtree, idx) + err = res.restoreTo(ctx, dst, subp, *node.Subtree, idx) if err != nil { err = res.Error(subp, node, err) if err != nil { @@ -83,11 +84,11 @@ func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkI return nil } -func (res *Restorer) restoreNodeTo(node *Node, dir string, dst string, idx *HardlinkIndex) error { +func (res *Restorer) restoreNodeTo(ctx context.Context, node *Node, dir string, dst string, idx *HardlinkIndex) error { debug.Log("node %v, dir %v, dst %v", node.Name, dir, dst) dstPath := filepath.Join(dst, dir, node.Name) - err := node.CreateAt(dstPath, res.repo, idx) + err := node.CreateAt(ctx, dstPath, res.repo, idx) if err != nil { debug.Log("node.CreateAt(%s) error %v", dstPath, err) } @@ -99,7 +100,7 @@ func (res *Restorer) restoreNodeTo(node *Node, dir string, dst string, idx *Hard // Create parent directories and retry err = fs.MkdirAll(filepath.Dir(dstPath), 0700) if err == nil || os.IsExist(errors.Cause(err)) { - err = node.CreateAt(dstPath, res.repo, idx) + err = node.CreateAt(ctx, dstPath, res.repo, idx) } } @@ -118,9 +119,9 @@ func (res *Restorer) restoreNodeTo(node *Node, dir string, dst string, idx *Hard // RestoreTo creates the directories and files in the snapshot below dst. // Before an item is created, res.Filter is called. -func (res *Restorer) RestoreTo(dst string) error { +func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { idx := NewHardlinkIndex() - return res.restoreTo(dst, string(filepath.Separator), *res.sn.Tree, idx) + return res.restoreTo(ctx, dst, string(filepath.Separator), *res.sn.Tree, idx) } // Snapshot returns the snapshot this restorer is configured to use. diff --git a/src/restic/snapshot.go b/src/restic/snapshot.go index 343dfe411..f56d0bdd0 100644 --- a/src/restic/snapshot.go +++ b/src/restic/snapshot.go @@ -1,6 +1,7 @@ package restic import ( + "context" "fmt" "os/user" "path/filepath" @@ -51,9 +52,9 @@ func NewSnapshot(paths []string, tags []string, hostname string) (*Snapshot, err } // LoadSnapshot loads the snapshot with the id and returns it. -func LoadSnapshot(repo Repository, id ID) (*Snapshot, error) { +func LoadSnapshot(ctx context.Context, repo Repository, id ID) (*Snapshot, error) { sn := &Snapshot{id: &id} - err := repo.LoadJSONUnpacked(SnapshotFile, id, sn) + err := repo.LoadJSONUnpacked(ctx, SnapshotFile, id, sn) if err != nil { return nil, err } @@ -62,12 +63,9 @@ func LoadSnapshot(repo Repository, id ID) (*Snapshot, error) { } // LoadAllSnapshots returns a list of all snapshots in the repo. -func LoadAllSnapshots(repo Repository) (snapshots []*Snapshot, err error) { - done := make(chan struct{}) - defer close(done) - - for id := range repo.List(SnapshotFile, done) { - sn, err := LoadSnapshot(repo, id) +func LoadAllSnapshots(ctx context.Context, repo Repository) (snapshots []*Snapshot, err error) { + for id := range repo.List(ctx, SnapshotFile) { + sn, err := LoadSnapshot(ctx, repo, id) if err != nil { return nil, err } @@ -178,15 +176,15 @@ func (sn *Snapshot) SamePaths(paths []string) bool { var ErrNoSnapshotFound = errors.New("no snapshot found") // FindLatestSnapshot finds latest snapshot with optional target/directory, tags and hostname filters. -func FindLatestSnapshot(repo Repository, targets []string, tags []string, hostname string) (ID, error) { +func FindLatestSnapshot(ctx context.Context, repo Repository, targets []string, tags []string, hostname string) (ID, error) { var ( latest time.Time latestID ID found bool ) - for snapshotID := range repo.List(SnapshotFile, make(chan struct{})) { - snapshot, err := LoadSnapshot(repo, snapshotID) + for snapshotID := range repo.List(ctx, SnapshotFile) { + snapshot, err := LoadSnapshot(ctx, repo, snapshotID) if err != nil { return ID{}, errors.Errorf("Error listing snapshot: %v", err) } diff --git a/src/restic/testing.go b/src/restic/testing.go index 144f53bd1..af0a81233 100644 --- a/src/restic/testing.go +++ b/src/restic/testing.go @@ -1,6 +1,7 @@ package restic import ( + "context" "encoding/json" "fmt" "io" @@ -29,7 +30,7 @@ type fakeFileSystem struct { // saveFile reads from rd and saves the blobs in the repository. The list of // IDs is returned. -func (fs *fakeFileSystem) saveFile(rd io.Reader) (blobs IDs) { +func (fs *fakeFileSystem) saveFile(ctx context.Context, rd io.Reader) (blobs IDs) { if fs.buf == nil { fs.buf = make([]byte, chunker.MaxSize) } @@ -53,7 +54,7 @@ func (fs *fakeFileSystem) saveFile(rd io.Reader) (blobs IDs) { id := Hash(chunk.Data) if !fs.blobIsKnown(id, DataBlob) { - _, err := fs.repo.SaveBlob(DataBlob, chunk.Data, id) + _, err := fs.repo.SaveBlob(ctx, DataBlob, chunk.Data, id) if err != nil { fs.t.Fatalf("error saving chunk: %v", err) } @@ -103,7 +104,7 @@ func (fs *fakeFileSystem) blobIsKnown(id ID, t BlobType) bool { } // saveTree saves a tree of fake files in the repo and returns the ID. -func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID { +func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) ID { rnd := rand.NewSource(seed) numNodes := int(rnd.Int63() % maxNodes) @@ -113,7 +114,7 @@ func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID { // randomly select the type of the node, either tree (p = 1/4) or file (p = 3/4). if depth > 1 && rnd.Int63()%4 == 0 { treeSeed := rnd.Int63() % maxSeed - id := fs.saveTree(treeSeed, depth-1) + id := fs.saveTree(ctx, treeSeed, depth-1) node := &Node{ Name: fmt.Sprintf("dir-%v", treeSeed), @@ -136,7 +137,7 @@ func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID { Size: uint64(fileSize), } - node.Content = fs.saveFile(fakeFile(fs.t, fileSeed, fileSize)) + node.Content = fs.saveFile(ctx, fakeFile(fs.t, fileSeed, fileSize)) tree.Nodes = append(tree.Nodes, node) } @@ -145,7 +146,7 @@ func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID { return id } - _, err := fs.repo.SaveBlob(TreeBlob, buf, id) + _, err := fs.repo.SaveBlob(ctx, TreeBlob, buf, id) if err != nil { fs.t.Fatal(err) } @@ -176,10 +177,10 @@ func TestCreateSnapshot(t testing.TB, repo Repository, at time.Time, depth int, duplication: duplication, } - treeID := fs.saveTree(seed, depth) + treeID := fs.saveTree(context.TODO(), seed, depth) snapshot.Tree = &treeID - id, err := repo.SaveJSONUnpacked(SnapshotFile, snapshot) + id, err := repo.SaveJSONUnpacked(context.TODO(), SnapshotFile, snapshot) if err != nil { t.Fatal(err) } @@ -193,7 +194,7 @@ func TestCreateSnapshot(t testing.TB, repo Repository, at time.Time, depth int, t.Fatal(err) } - err = repo.SaveIndex() + err = repo.SaveIndex(context.TODO()) if err != nil { t.Fatal(err) } diff --git a/src/restic/testing_test.go b/src/restic/testing_test.go index 86b18a001..29b426623 100644 --- a/src/restic/testing_test.go +++ b/src/restic/testing_test.go @@ -1,6 +1,7 @@ package restic_test import ( + "context" "restic" "restic/checker" "restic/repository" @@ -23,7 +24,7 @@ func TestCreateSnapshot(t *testing.T) { restic.TestCreateSnapshot(t, repo, testSnapshotTime.Add(time.Duration(i)*time.Second), testDepth, 0) } - snapshots, err := restic.LoadAllSnapshots(repo) + snapshots, err := restic.LoadAllSnapshots(context.TODO(), repo) if err != nil { t.Fatal(err) } diff --git a/src/restic/tree_test.go b/src/restic/tree_test.go index 0bf7cfddc..dbdd20d20 100644 --- a/src/restic/tree_test.go +++ b/src/restic/tree_test.go @@ -1,6 +1,7 @@ package restic_test import ( + "context" "encoding/json" "io/ioutil" "os" @@ -98,14 +99,14 @@ func TestLoadTree(t *testing.T) { // save tree tree := restic.NewTree() - id, err := repo.SaveTree(tree) + id, err := repo.SaveTree(context.TODO(), tree) OK(t, err) // save packs OK(t, repo.Flush()) // load tree again - tree2, err := repo.LoadTree(id) + tree2, err := repo.LoadTree(context.TODO(), id) OK(t, err) Assert(t, tree.Equals(tree2), diff --git a/src/restic/walk/walk.go b/src/restic/walk/walk.go index 8e8e4b536..6ebcb6629 100644 --- a/src/restic/walk/walk.go +++ b/src/restic/walk/walk.go @@ -1,6 +1,7 @@ package walk import ( + "context" "fmt" "os" "path/filepath" @@ -34,7 +35,7 @@ func NewTreeWalker(ch chan<- loadTreeJob, out chan<- TreeJob) *TreeWalker { // Walk starts walking the tree given by id. When the channel done is closed, // processing stops. -func (tw *TreeWalker) Walk(path string, id restic.ID, done chan struct{}) { +func (tw *TreeWalker) Walk(ctx context.Context, path string, id restic.ID) { debug.Log("starting on tree %v for %v", id.Str(), path) defer debug.Log("done walking tree %v for %v", id.Str(), path) @@ -48,22 +49,22 @@ func (tw *TreeWalker) Walk(path string, id restic.ID, done chan struct{}) { if res.err != nil { select { case tw.out <- TreeJob{Path: path, Error: res.err}: - case <-done: + case <-ctx.Done(): return } return } - tw.walk(path, res.tree, done) + tw.walk(ctx, path, res.tree) select { case tw.out <- TreeJob{Path: path, Tree: res.tree}: - case <-done: + case <-ctx.Done(): return } } -func (tw *TreeWalker) walk(path string, tree *restic.Tree, done chan struct{}) { +func (tw *TreeWalker) walk(ctx context.Context, path string, tree *restic.Tree) { debug.Log("start on %q", path) defer debug.Log("done for %q", path) @@ -94,7 +95,7 @@ func (tw *TreeWalker) walk(path string, tree *restic.Tree, done chan struct{}) { res := <-results[i] if res.err == nil { - tw.walk(p, res.tree, done) + tw.walk(ctx, p, res.tree) } else { fmt.Fprintf(os.Stderr, "error loading tree: %v\n", res.err) } @@ -106,7 +107,7 @@ func (tw *TreeWalker) walk(path string, tree *restic.Tree, done chan struct{}) { select { case tw.out <- job: - case <-done: + case <-ctx.Done(): return } } @@ -124,14 +125,14 @@ type loadTreeJob struct { type treeLoader func(restic.ID) (*restic.Tree, error) -func loadTreeWorker(wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader, done <-chan struct{}) { +func loadTreeWorker(ctx context.Context, wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader) { debug.Log("start") defer debug.Log("exit") defer wg.Done() for { select { - case <-done: + case <-ctx.Done(): debug.Log("done channel closed") return case job, ok := <-in: @@ -148,7 +149,7 @@ func loadTreeWorker(wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader, select { case job.res <- loadTreeResult{tree, err}: debug.Log("job result sent") - case <-done: + case <-ctx.Done(): debug.Log("done channel closed before result could be sent") return } @@ -158,7 +159,7 @@ func loadTreeWorker(wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader, // TreeLoader loads tree objects. type TreeLoader interface { - LoadTree(restic.ID) (*restic.Tree, error) + LoadTree(context.Context, restic.ID) (*restic.Tree, error) } const loadTreeWorkers = 10 @@ -166,11 +167,11 @@ const loadTreeWorkers = 10 // Tree walks the tree specified by id recursively and sends a job for each // file and directory it finds. When the channel done is closed, processing // stops. -func Tree(repo TreeLoader, id restic.ID, done chan struct{}, jobCh chan<- TreeJob) { +func Tree(ctx context.Context, repo TreeLoader, id restic.ID, jobCh chan<- TreeJob) { debug.Log("start on %v, start workers", id.Str()) load := func(id restic.ID) (*restic.Tree, error) { - tree, err := repo.LoadTree(id) + tree, err := repo.LoadTree(ctx, id) if err != nil { return nil, err } @@ -182,11 +183,11 @@ func Tree(repo TreeLoader, id restic.ID, done chan struct{}, jobCh chan<- TreeJo var wg sync.WaitGroup for i := 0; i < loadTreeWorkers; i++ { wg.Add(1) - go loadTreeWorker(&wg, ch, load, done) + go loadTreeWorker(ctx, &wg, ch, load) } tw := NewTreeWalker(ch, jobCh) - tw.Walk("", id, done) + tw.Walk(ctx, "", id) close(jobCh) close(ch) diff --git a/src/restic/walk/walk_test.go b/src/restic/walk/walk_test.go index 6a824827d..bdb5cf459 100644 --- a/src/restic/walk/walk_test.go +++ b/src/restic/walk/walk_test.go @@ -1,6 +1,7 @@ package walk_test import ( + "context" "os" "path/filepath" "strings" @@ -24,17 +25,15 @@ func TestWalkTree(t *testing.T) { // archive a few files arch := archiver.New(repo) - sn, _, err := arch.Snapshot(nil, dirs, nil, "localhost", nil) + sn, _, err := arch.Snapshot(context.TODO(), nil, dirs, nil, "localhost", nil) OK(t, err) // flush repo, write all packs OK(t, repo.Flush()) - done := make(chan struct{}) - // start tree walker treeJobs := make(chan walk.TreeJob) - go walk.Tree(repo, *sn.Tree, done, treeJobs) + go walk.Tree(context.TODO(), repo, *sn.Tree, treeJobs) // start filesystem walker fsJobs := make(chan pipe.Job) @@ -43,7 +42,7 @@ func TestWalkTree(t *testing.T) { f := func(string, os.FileInfo) bool { return true } - go pipe.Walk(dirs, f, done, fsJobs, resCh) + go pipe.Walk(context.TODO(), dirs, f, fsJobs, resCh) for { // receive fs job @@ -95,9 +94,9 @@ type delayRepo struct { delay time.Duration } -func (d delayRepo) LoadTree(id restic.ID) (*restic.Tree, error) { +func (d delayRepo) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree, error) { time.Sleep(d.delay) - return d.repo.LoadTree(id) + return d.repo.LoadTree(ctx, id) } var repoFixture = filepath.Join("testdata", "walktree-test-repo.tar.gz") @@ -1345,7 +1344,7 @@ func TestDelayedWalkTree(t *testing.T) { defer cleanup() repo := repository.TestOpenLocal(t, repodir) - OK(t, repo.LoadIndex()) + OK(t, repo.LoadIndex(context.TODO())) root, err := restic.ParseID("937a2f64f736c64ee700c6ab06f840c68c94799c288146a0e81e07f4c94254da") OK(t, err) @@ -1354,7 +1353,7 @@ func TestDelayedWalkTree(t *testing.T) { // start tree walker treeJobs := make(chan walk.TreeJob) - go walk.Tree(dr, root, nil, treeJobs) + go walk.Tree(context.TODO(), dr, root, treeJobs) i := 0 for job := range treeJobs { @@ -1375,7 +1374,7 @@ func BenchmarkDelayedWalkTree(t *testing.B) { defer cleanup() repo := repository.TestOpenLocal(t, repodir) - OK(t, repo.LoadIndex()) + OK(t, repo.LoadIndex(context.TODO())) root, err := restic.ParseID("937a2f64f736c64ee700c6ab06f840c68c94799c288146a0e81e07f4c94254da") OK(t, err) @@ -1387,7 +1386,7 @@ func BenchmarkDelayedWalkTree(t *testing.B) { for i := 0; i < t.N; i++ { // start tree walker treeJobs := make(chan walk.TreeJob) - go walk.Tree(dr, root, nil, treeJobs) + go walk.Tree(context.TODO(), dr, root, treeJobs) for range treeJobs { } diff --git a/src/restic/worker/pool.go b/src/restic/worker/pool.go index 2268ef8f8..870548378 100644 --- a/src/restic/worker/pool.go +++ b/src/restic/worker/pool.go @@ -1,5 +1,7 @@ package worker +import "context" + // Job is one unit of work. It is given to a Func, and the returned result and // error are stored in Result and Error. type Job struct { @@ -9,12 +11,11 @@ type Job struct { } // Func does the actual work within a Pool. -type Func func(job Job, done <-chan struct{}) (result interface{}, err error) +type Func func(ctx context.Context, job Job) (result interface{}, err error) // Pool implements a worker pool. type Pool struct { f Func - done chan struct{} jobCh <-chan Job resCh chan<- Job @@ -25,10 +26,9 @@ type Pool struct { // New returns a new worker pool with n goroutines, each running the function // f. The workers are started immediately. -func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool { +func New(ctx context.Context, n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool { p := &Pool{ f: f, - done: make(chan struct{}), workersExit: make(chan struct{}), allWorkersDone: make(chan struct{}), numWorkers: n, @@ -37,7 +37,7 @@ func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool { } for i := 0; i < n; i++ { - go p.runWorker(i) + go p.runWorker(ctx, i) } go p.waitForExit() @@ -58,7 +58,7 @@ func (p *Pool) waitForExit() { } // runWorker runs a worker function. -func (p *Pool) runWorker(numWorker int) { +func (p *Pool) runWorker(ctx context.Context, numWorker int) { defer func() { p.workersExit <- struct{}{} }() @@ -75,7 +75,7 @@ func (p *Pool) runWorker(numWorker int) { for { select { - case <-p.done: + case <-ctx.Done(): return case job, ok = <-inCh: @@ -83,7 +83,7 @@ func (p *Pool) runWorker(numWorker int) { return } - job.Result, job.Error = p.f(job, p.done) + job.Result, job.Error = p.f(ctx, job) inCh = nil outCh = p.resCh diff --git a/src/restic/worker/pool_test.go b/src/restic/worker/pool_test.go index 9d6159b89..f36a98c2e 100644 --- a/src/restic/worker/pool_test.go +++ b/src/restic/worker/pool_test.go @@ -1,6 +1,7 @@ package worker_test import ( + "context" "testing" "restic/errors" @@ -12,7 +13,7 @@ const concurrency = 10 var errTooLarge = errors.New("too large") -func square(job worker.Job, done <-chan struct{}) (interface{}, error) { +func square(ctx context.Context, job worker.Job) (interface{}, error) { n := job.Data.(int) if n > 2000 { return nil, errTooLarge @@ -20,15 +21,15 @@ func square(job worker.Job, done <-chan struct{}) (interface{}, error) { return n * n, nil } -func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) { +func newBufferedPool(ctx context.Context, bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) { inCh := make(chan worker.Job, bufsize) outCh := make(chan worker.Job, bufsize) - return inCh, outCh, worker.New(n, f, inCh, outCh) + return inCh, outCh, worker.New(ctx, n, f, inCh, outCh) } func TestPool(t *testing.T) { - inCh, outCh, p := newBufferedPool(200, concurrency, square) + inCh, outCh, p := newBufferedPool(context.TODO(), 200, concurrency, square) for i := 0; i < 150; i++ { inCh <- worker.Job{Data: i} @@ -53,7 +54,7 @@ func TestPool(t *testing.T) { } func TestPoolErrors(t *testing.T) { - inCh, outCh, p := newBufferedPool(200, concurrency, square) + inCh, outCh, p := newBufferedPool(context.TODO(), 200, concurrency, square) for i := 0; i < 150; i++ { inCh <- worker.Job{Data: i + 1900}