2
2
mirror of https://github.com/octoleo/restic.git synced 2024-11-29 16:23:59 +00:00

Add context to restic packages

This commit is contained in:
Alexander Neumann 2017-06-04 11:16:55 +02:00
parent 16fcd07110
commit cf497c2728
50 changed files with 432 additions and 422 deletions

View File

@ -2,6 +2,7 @@ package main
import ( import (
"bufio" "bufio"
"context"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -263,7 +264,7 @@ func readBackupFromStdin(opts BackupOptions, gopts GlobalOptions, args []string)
return err return err
} }
err = repo.LoadIndex() err = repo.LoadIndex(context.TODO())
if err != nil { if err != nil {
return err return err
} }
@ -274,7 +275,7 @@ func readBackupFromStdin(opts BackupOptions, gopts GlobalOptions, args []string)
Hostname: opts.Hostname, Hostname: opts.Hostname,
} }
_, id, err := r.Archive(opts.StdinFilename, os.Stdin, newArchiveStdinProgress(gopts)) _, id, err := r.Archive(context.TODO(), opts.StdinFilename, os.Stdin, newArchiveStdinProgress(gopts))
if err != nil { if err != nil {
return err return err
} }
@ -372,7 +373,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error {
return err return err
} }
err = repo.LoadIndex() err = repo.LoadIndex(context.TODO())
if err != nil { if err != nil {
return err return err
} }
@ -391,7 +392,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error {
// Find last snapshot to set it as parent, if not already set // Find last snapshot to set it as parent, if not already set
if !opts.Force && parentSnapshotID == nil { if !opts.Force && parentSnapshotID == nil {
id, err := restic.FindLatestSnapshot(repo, target, opts.Tags, opts.Hostname) id, err := restic.FindLatestSnapshot(context.TODO(), repo, target, opts.Tags, opts.Hostname)
if err == nil { if err == nil {
parentSnapshotID = &id parentSnapshotID = &id
} else if err != restic.ErrNoSnapshotFound { } else if err != restic.ErrNoSnapshotFound {
@ -489,7 +490,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error {
Warnf("%s\rwarning for %s: %v\n", ClearLine(), dir, err) Warnf("%s\rwarning for %s: %v\n", ClearLine(), dir, err)
} }
_, id, err := arch.Snapshot(newArchiveProgress(gopts, stat), target, opts.Tags, opts.Hostname, parentSnapshotID) _, id, err := arch.Snapshot(context.TODO(), newArchiveProgress(gopts, stat), target, opts.Tags, opts.Hostname, parentSnapshotID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -73,7 +74,7 @@ func runCat(gopts GlobalOptions, args []string) error {
fmt.Println(string(buf)) fmt.Println(string(buf))
return nil return nil
case "index": case "index":
buf, err := repo.LoadAndDecrypt(restic.IndexFile, id) buf, err := repo.LoadAndDecrypt(context.TODO(), restic.IndexFile, id)
if err != nil { if err != nil {
return err return err
} }
@ -83,7 +84,7 @@ func runCat(gopts GlobalOptions, args []string) error {
case "snapshot": case "snapshot":
sn := &restic.Snapshot{} sn := &restic.Snapshot{}
err = repo.LoadJSONUnpacked(restic.SnapshotFile, id, sn) err = repo.LoadJSONUnpacked(context.TODO(), restic.SnapshotFile, id, sn)
if err != nil { if err != nil {
return err return err
} }
@ -98,7 +99,7 @@ func runCat(gopts GlobalOptions, args []string) error {
return nil return nil
case "key": case "key":
h := restic.Handle{Type: restic.KeyFile, Name: id.String()} h := restic.Handle{Type: restic.KeyFile, Name: id.String()}
buf, err := backend.LoadAll(repo.Backend(), h) buf, err := backend.LoadAll(context.TODO(), repo.Backend(), h)
if err != nil { if err != nil {
return err return err
} }
@ -125,7 +126,7 @@ func runCat(gopts GlobalOptions, args []string) error {
fmt.Println(string(buf)) fmt.Println(string(buf))
return nil return nil
case "lock": case "lock":
lock, err := restic.LoadLock(repo, id) lock, err := restic.LoadLock(context.TODO(), repo, id)
if err != nil { if err != nil {
return err return err
} }
@ -141,7 +142,7 @@ func runCat(gopts GlobalOptions, args []string) error {
} }
// load index, handle all the other types // load index, handle all the other types
err = repo.LoadIndex() err = repo.LoadIndex(context.TODO())
if err != nil { if err != nil {
return err return err
} }
@ -149,7 +150,7 @@ func runCat(gopts GlobalOptions, args []string) error {
switch tpe { switch tpe {
case "pack": case "pack":
h := restic.Handle{Type: restic.DataFile, Name: id.String()} h := restic.Handle{Type: restic.DataFile, Name: id.String()}
buf, err := backend.LoadAll(repo.Backend(), h) buf, err := backend.LoadAll(context.TODO(), repo.Backend(), h)
if err != nil { if err != nil {
return err return err
} }
@ -171,7 +172,7 @@ func runCat(gopts GlobalOptions, args []string) error {
blob := list[0] blob := list[0]
buf := make([]byte, blob.Length) buf := make([]byte, blob.Length)
n, err := repo.LoadBlob(t, id, buf) n, err := repo.LoadBlob(context.TODO(), t, id, buf)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"time" "time"
@ -92,7 +93,7 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error {
chkr := checker.New(repo) chkr := checker.New(repo)
Verbosef("Load indexes\n") Verbosef("Load indexes\n")
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
dupFound := false dupFound := false
for _, hint := range hints { for _, hint := range hints {
@ -113,14 +114,11 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error {
return errors.Fatal("LoadIndex returned errors") return errors.Fatal("LoadIndex returned errors")
} }
done := make(chan struct{})
defer close(done)
errorsFound := false errorsFound := false
errChan := make(chan error) errChan := make(chan error)
Verbosef("Check all packs\n") Verbosef("Check all packs\n")
go chkr.Packs(errChan, done) go chkr.Packs(context.TODO(), errChan)
for err := range errChan { for err := range errChan {
errorsFound = true errorsFound = true
@ -129,7 +127,7 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error {
Verbosef("Check snapshots, trees and blobs\n") Verbosef("Check snapshots, trees and blobs\n")
errChan = make(chan error) errChan = make(chan error)
go chkr.Structure(errChan, done) go chkr.Structure(context.TODO(), errChan)
for err := range errChan { for err := range errChan {
errorsFound = true errorsFound = true
@ -156,7 +154,7 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error {
p := newReadProgress(gopts, restic.Stat{Blobs: chkr.CountPacks()}) p := newReadProgress(gopts, restic.Stat{Blobs: chkr.CountPacks()})
errChan := make(chan error) errChan := make(chan error)
go chkr.ReadData(p, errChan, done) go chkr.ReadData(context.TODO(), p, errChan)
for err := range errChan { for err := range errChan {
errorsFound = true errorsFound = true

View File

@ -187,7 +187,7 @@ func (f *Finder) findInTree(treeID restic.ID, prefix string) error {
debug.Log("%v checking tree %v\n", prefix, treeID.Str()) debug.Log("%v checking tree %v\n", prefix, treeID.Str())
tree, err := f.repo.LoadTree(treeID) tree, err := f.repo.LoadTree(context.TODO(), treeID)
if err != nil { if err != nil {
return err return err
} }
@ -283,7 +283,7 @@ func runFind(opts FindOptions, gopts GlobalOptions, args []string) error {
} }
} }
if err = repo.LoadIndex(); err != nil { if err = repo.LoadIndex(context.TODO()); err != nil {
return err return err
} }

View File

@ -97,7 +97,7 @@ func runForget(opts ForgetOptions, gopts GlobalOptions, args []string) error {
// When explicit snapshots args are given, remove them immediately. // When explicit snapshots args are given, remove them immediately.
if !opts.DryRun { if !opts.DryRun {
h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()} h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()}
if err = repo.Backend().Remove(h); err != nil { if err = repo.Backend().Remove(context.TODO(), h); err != nil {
return err return err
} }
Verbosef("removed snapshot %v\n", sn.ID().Str()) Verbosef("removed snapshot %v\n", sn.ID().Str())
@ -167,7 +167,7 @@ func runForget(opts ForgetOptions, gopts GlobalOptions, args []string) error {
if !opts.DryRun { if !opts.DryRun {
for _, sn := range remove { for _, sn := range remove {
h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()} h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()}
err = repo.Backend().Remove(h) err = repo.Backend().Remove(context.TODO(), h)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"restic/errors" "restic/errors"
"restic/repository" "restic/repository"
@ -43,7 +44,7 @@ func runInit(gopts GlobalOptions, args []string) error {
s := repository.New(be) s := repository.New(be)
err = s.Init(gopts.password) err = s.Init(context.TODO(), gopts.password)
if err != nil { if err != nil {
return errors.Fatalf("create key in backend at %s failed: %v\n", gopts.Repo, err) return errors.Fatalf("create key in backend at %s failed: %v\n", gopts.Repo, err)
} }

View File

@ -30,8 +30,8 @@ func listKeys(ctx context.Context, s *repository.Repository) error {
tab.Header = fmt.Sprintf(" %-10s %-10s %-10s %s", "ID", "User", "Host", "Created") tab.Header = fmt.Sprintf(" %-10s %-10s %-10s %s", "ID", "User", "Host", "Created")
tab.RowFormat = "%s%-10s %-10s %-10s %s" tab.RowFormat = "%s%-10s %-10s %-10s %s"
for id := range s.List(restic.KeyFile, ctx.Done()) { for id := range s.List(ctx, restic.KeyFile) {
k, err := repository.LoadKey(s, id.String()) k, err := repository.LoadKey(ctx, s, id.String())
if err != nil { if err != nil {
Warnf("LoadKey() failed: %v\n", err) Warnf("LoadKey() failed: %v\n", err)
continue continue
@ -69,7 +69,7 @@ func addKey(gopts GlobalOptions, repo *repository.Repository) error {
return err return err
} }
id, err := repository.AddKey(repo, pw, repo.Key()) id, err := repository.AddKey(context.TODO(), repo, pw, repo.Key())
if err != nil { if err != nil {
return errors.Fatalf("creating new key failed: %v\n", err) return errors.Fatalf("creating new key failed: %v\n", err)
} }
@ -85,7 +85,7 @@ func deleteKey(repo *repository.Repository, name string) error {
} }
h := restic.Handle{Type: restic.KeyFile, Name: name} h := restic.Handle{Type: restic.KeyFile, Name: name}
err := repo.Backend().Remove(h) err := repo.Backend().Remove(context.TODO(), h)
if err != nil { if err != nil {
return err return err
} }
@ -100,13 +100,13 @@ func changePassword(gopts GlobalOptions, repo *repository.Repository) error {
return err return err
} }
id, err := repository.AddKey(repo, pw, repo.Key()) id, err := repository.AddKey(context.TODO(), repo, pw, repo.Key())
if err != nil { if err != nil {
return errors.Fatalf("creating new key failed: %v\n", err) return errors.Fatalf("creating new key failed: %v\n", err)
} }
h := restic.Handle{Type: restic.KeyFile, Name: repo.KeyName()} h := restic.Handle{Type: restic.KeyFile, Name: repo.KeyName()}
err = repo.Backend().Remove(h) err = repo.Backend().Remove(context.TODO(), h)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"restic" "restic"
"restic/errors" "restic/errors"
@ -55,7 +56,7 @@ func runList(opts GlobalOptions, args []string) error {
case "locks": case "locks":
t = restic.LockFile t = restic.LockFile
case "blobs": case "blobs":
idx, err := index.Load(repo, nil) idx, err := index.Load(context.TODO(), repo, nil)
if err != nil { if err != nil {
return err return err
} }
@ -71,7 +72,7 @@ func runList(opts GlobalOptions, args []string) error {
return errors.Fatal("invalid type") return errors.Fatal("invalid type")
} }
for id := range repo.List(t, nil) { for id := range repo.List(context.TODO(), t) {
Printf("%s\n", id) Printf("%s\n", id)
} }

View File

@ -46,7 +46,7 @@ func init() {
} }
func printTree(repo *repository.Repository, id *restic.ID, prefix string) error { func printTree(repo *repository.Repository, id *restic.ID, prefix string) error {
tree, err := repo.LoadTree(*id) tree, err := repo.LoadTree(context.TODO(), *id)
if err != nil { if err != nil {
return err return err
} }
@ -74,7 +74,7 @@ func runLs(opts LsOptions, gopts GlobalOptions, args []string) error {
return err return err
} }
if err = repo.LoadIndex(); err != nil { if err = repo.LoadIndex(context.TODO()); err != nil {
return err return err
} }

View File

@ -4,6 +4,7 @@
package main package main
import ( import (
"context"
"os" "os"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -64,7 +65,7 @@ func mount(opts MountOptions, gopts GlobalOptions, mountpoint string) error {
return err return err
} }
err = repo.LoadIndex() err = repo.LoadIndex(context.TODO())
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"restic" "restic"
"restic/debug" "restic/debug"
@ -76,14 +75,13 @@ func runPrune(gopts GlobalOptions) error {
} }
func pruneRepository(gopts GlobalOptions, repo restic.Repository) error { func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
err := repo.LoadIndex() ctx := gopts.ctx
err := repo.LoadIndex(ctx)
if err != nil { if err != nil {
return err return err
} }
ctx, cancel := context.WithCancel(gopts.ctx)
defer cancel()
var stats struct { var stats struct {
blobs int blobs int
packs int packs int
@ -92,14 +90,14 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
} }
Verbosef("counting files in repo\n") Verbosef("counting files in repo\n")
for range repo.List(restic.DataFile, ctx.Done()) { for range repo.List(ctx, restic.DataFile) {
stats.packs++ stats.packs++
} }
Verbosef("building new index for repo\n") Verbosef("building new index for repo\n")
bar := newProgressMax(!gopts.Quiet, uint64(stats.packs), "packs") bar := newProgressMax(!gopts.Quiet, uint64(stats.packs), "packs")
idx, err := index.New(repo, bar) idx, err := index.New(ctx, repo, bar)
if err != nil { if err != nil {
return err return err
} }
@ -135,7 +133,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
Verbosef("load all snapshots\n") Verbosef("load all snapshots\n")
// find referenced blobs // find referenced blobs
snapshots, err := restic.LoadAllSnapshots(repo) snapshots, err := restic.LoadAllSnapshots(ctx, repo)
if err != nil { if err != nil {
return err return err
} }
@ -152,7 +150,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
for _, sn := range snapshots { for _, sn := range snapshots {
debug.Log("process snapshot %v", sn.ID().Str()) debug.Log("process snapshot %v", sn.ID().Str())
err = restic.FindUsedBlobs(repo, *sn.Tree, usedBlobs, seenBlobs) err = restic.FindUsedBlobs(ctx, repo, *sn.Tree, usedBlobs, seenBlobs)
if err != nil { if err != nil {
return err return err
} }
@ -217,7 +215,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
if len(rewritePacks) != 0 { if len(rewritePacks) != 0 {
bar = newProgressMax(!gopts.Quiet, uint64(len(rewritePacks)), "packs rewritten") bar = newProgressMax(!gopts.Quiet, uint64(len(rewritePacks)), "packs rewritten")
bar.Start() bar.Start()
err = repository.Repack(repo, rewritePacks, usedBlobs, bar) err = repository.Repack(ctx, repo, rewritePacks, usedBlobs, bar)
if err != nil { if err != nil {
return err return err
} }
@ -229,7 +227,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
bar.Start() bar.Start()
for packID := range removePacks { for packID := range removePacks {
h := restic.Handle{Type: restic.DataFile, Name: packID.String()} h := restic.Handle{Type: restic.DataFile, Name: packID.String()}
err = repo.Backend().Remove(h) err = repo.Backend().Remove(ctx, h)
if err != nil { if err != nil {
Warnf("unable to remove file %v from the repository\n", packID.Str()) Warnf("unable to remove file %v from the repository\n", packID.Str())
} }

View File

@ -45,12 +45,12 @@ func rebuildIndex(ctx context.Context, repo restic.Repository) error {
Verbosef("counting files in repo\n") Verbosef("counting files in repo\n")
var packs uint64 var packs uint64
for range repo.List(restic.DataFile, ctx.Done()) { for range repo.List(ctx, restic.DataFile) {
packs++ packs++
} }
bar := newProgressMax(!globalOptions.Quiet, packs, "packs") bar := newProgressMax(!globalOptions.Quiet, packs, "packs")
idx, err := index.New(repo, bar) idx, err := index.New(ctx, repo, bar)
if err != nil { if err != nil {
return err return err
} }
@ -58,11 +58,11 @@ func rebuildIndex(ctx context.Context, repo restic.Repository) error {
Verbosef("finding old index files\n") Verbosef("finding old index files\n")
var supersedes restic.IDs var supersedes restic.IDs
for id := range repo.List(restic.IndexFile, ctx.Done()) { for id := range repo.List(ctx, restic.IndexFile) {
supersedes = append(supersedes, id) supersedes = append(supersedes, id)
} }
id, err := idx.Save(repo, supersedes) id, err := idx.Save(ctx, repo, supersedes)
if err != nil { if err != nil {
return err return err
} }
@ -72,7 +72,7 @@ func rebuildIndex(ctx context.Context, repo restic.Repository) error {
Verbosef("remove %d old index files\n", len(supersedes)) Verbosef("remove %d old index files\n", len(supersedes))
for _, id := range supersedes { for _, id := range supersedes {
if err := repo.Backend().Remove(restic.Handle{ if err := repo.Backend().Remove(ctx, restic.Handle{
Type: restic.IndexFile, Type: restic.IndexFile,
Name: id.String(), Name: id.String(),
}); err != nil { }); err != nil {

View File

@ -50,6 +50,8 @@ func init() {
} }
func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error { func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error {
ctx := gopts.ctx
if len(args) != 1 { if len(args) != 1 {
return errors.Fatal("no snapshot ID specified") return errors.Fatal("no snapshot ID specified")
} }
@ -79,7 +81,7 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error {
} }
} }
err = repo.LoadIndex() err = repo.LoadIndex(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -87,7 +89,7 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error {
var id restic.ID var id restic.ID
if snapshotIDString == "latest" { if snapshotIDString == "latest" {
id, err = restic.FindLatestSnapshot(repo, opts.Paths, opts.Tags, opts.Host) id, err = restic.FindLatestSnapshot(ctx, repo, opts.Paths, opts.Tags, opts.Host)
if err != nil { if err != nil {
Exitf(1, "latest snapshot for criteria not found: %v Paths:%v Host:%v", err, opts.Paths, opts.Host) Exitf(1, "latest snapshot for criteria not found: %v Paths:%v Host:%v", err, opts.Paths, opts.Host)
} }
@ -136,7 +138,7 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error {
Verbosef("restoring %s to %s\n", res.Snapshot(), opts.Target) Verbosef("restoring %s to %s\n", res.Snapshot(), opts.Target)
err = res.RestoreTo(opts.Target) err = res.RestoreTo(ctx, opts.Target)
if totalErrors > 0 { if totalErrors > 0 {
Printf("There were %d errors\n", totalErrors) Printf("There were %d errors\n", totalErrors)
} }

View File

@ -76,7 +76,7 @@ func changeTags(repo *repository.Repository, sn *restic.Snapshot, setTags, addTa
} }
// Save the new snapshot. // Save the new snapshot.
id, err := repo.SaveJSONUnpacked(restic.SnapshotFile, sn) id, err := repo.SaveJSONUnpacked(context.TODO(), restic.SnapshotFile, sn)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -89,7 +89,7 @@ func changeTags(repo *repository.Repository, sn *restic.Snapshot, setTags, addTa
// Remove the old snapshot. // Remove the old snapshot.
h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()} h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()}
if err = repo.Backend().Remove(h); err != nil { if err = repo.Backend().Remove(context.TODO(), h); err != nil {
return false, err return false, err
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"restic" "restic"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -41,7 +42,7 @@ func runUnlock(opts UnlockOptions, gopts GlobalOptions) error {
fn = restic.RemoveAllLocks fn = restic.RemoveAllLocks
} }
err = fn(repo) err = fn(context.TODO(), repo)
if err != nil { if err != nil {
return err return err
} }

View File

@ -22,7 +22,7 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos
// Process all snapshot IDs given as arguments. // Process all snapshot IDs given as arguments.
for _, s := range snapshotIDs { for _, s := range snapshotIDs {
if s == "latest" { if s == "latest" {
id, err = restic.FindLatestSnapshot(repo, paths, tags, host) id, err = restic.FindLatestSnapshot(ctx, repo, paths, tags, host)
if err != nil { if err != nil {
Warnf("Ignoring %q, no snapshot matched given filter (Paths:%v Tags:%v Host:%v)\n", s, paths, tags, host) Warnf("Ignoring %q, no snapshot matched given filter (Paths:%v Tags:%v Host:%v)\n", s, paths, tags, host)
usedFilter = true usedFilter = true
@ -44,7 +44,7 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos
} }
for _, id := range ids.Uniq() { for _, id := range ids.Uniq() {
sn, err := restic.LoadSnapshot(repo, id) sn, err := restic.LoadSnapshot(ctx, repo, id)
if err != nil { if err != nil {
Warnf("Ignoring %q, could not load snapshot: %v\n", id, err) Warnf("Ignoring %q, could not load snapshot: %v\n", id, err)
continue continue
@ -58,8 +58,8 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos
return return
} }
for id := range repo.List(restic.SnapshotFile, ctx.Done()) { for id := range repo.List(ctx, restic.SnapshotFile) {
sn, err := restic.LoadSnapshot(repo, id) sn, err := restic.LoadSnapshot(ctx, repo, id)
if err != nil { if err != nil {
Warnf("Ignoring %q, could not load snapshot: %v\n", id, err) Warnf("Ignoring %q, could not load snapshot: %v\n", id, err)
continue continue

View File

@ -310,7 +310,7 @@ func OpenRepository(opts GlobalOptions) (*repository.Repository, error) {
} }
} }
err = s.SearchKey(opts.password, maxKeys) err = s.SearchKey(context.TODO(), opts.password, maxKeys)
if err != nil { if err != nil {
return nil, errors.Fatalf("unable to open repo: %v", err) return nil, errors.Fatalf("unable to open repo: %v", err)
} }
@ -440,7 +440,7 @@ func open(s string, opts options.Options) (restic.Backend, error) {
} }
// check if config is there // check if config is there
fi, err := be.Stat(restic.Handle{Type: restic.ConfigFile}) fi, err := be.Stat(context.TODO(), restic.Handle{Type: restic.ConfigFile})
if err != nil { if err != nil {
return nil, errors.Fatalf("unable to open config file: %v\nIs there a repository at the following location?\n%v", err, s) return nil, errors.Fatalf("unable to open config file: %v\nIs there a repository at the following location?\n%v", err, s)
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sync" "sync"
@ -32,7 +33,7 @@ func lockRepository(repo *repository.Repository, exclusive bool) (*restic.Lock,
lockFn = restic.NewExclusiveLock lockFn = restic.NewExclusiveLock
} }
lock, err := lockFn(repo) lock, err := lockFn(context.TODO(), repo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -75,7 +76,7 @@ func refreshLocks(wg *sync.WaitGroup, done <-chan struct{}) {
debug.Log("refreshing locks") debug.Log("refreshing locks")
globalLocks.Lock() globalLocks.Lock()
for _, lock := range globalLocks.locks { for _, lock := range globalLocks.locks {
err := lock.Refresh() err := lock.Refresh(context.TODO())
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "unable to refresh lock: %v\n", err) fmt.Fprintf(os.Stderr, "unable to refresh lock: %v\n", err)
} }

View File

@ -1,6 +1,7 @@
package archiver package archiver
import ( import (
"context"
"io" "io"
"restic" "restic"
"restic/debug" "restic/debug"
@ -20,7 +21,7 @@ type Reader struct {
} }
// Archive reads data from the reader and saves it to the repo. // Archive reads data from the reader and saves it to the repo.
func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic.Snapshot, restic.ID, error) { func (r *Reader) Archive(ctx context.Context, name string, rd io.Reader, p *restic.Progress) (*restic.Snapshot, restic.ID, error) {
if name == "" { if name == "" {
return nil, restic.ID{}, errors.New("no filename given") return nil, restic.ID{}, errors.New("no filename given")
} }
@ -53,7 +54,7 @@ func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic
id := restic.Hash(chunk.Data) id := restic.Hash(chunk.Data)
if !repo.Index().Has(id, restic.DataBlob) { if !repo.Index().Has(id, restic.DataBlob) {
_, err := repo.SaveBlob(restic.DataBlob, chunk.Data, id) _, err := repo.SaveBlob(ctx, restic.DataBlob, chunk.Data, id)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
@ -87,14 +88,14 @@ func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic
}, },
} }
treeID, err := repo.SaveTree(tree) treeID, err := repo.SaveTree(ctx, tree)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
sn.Tree = &treeID sn.Tree = &treeID
debug.Log("tree saved as %v", treeID.Str()) debug.Log("tree saved as %v", treeID.Str())
id, err := repo.SaveJSONUnpacked(restic.SnapshotFile, sn) id, err := repo.SaveJSONUnpacked(ctx, restic.SnapshotFile, sn)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
@ -106,7 +107,7 @@ func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
err = repo.SaveIndex() err = repo.SaveIndex(ctx)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }

View File

@ -1,6 +1,7 @@
package archiver package archiver
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -92,7 +93,7 @@ func (arch *Archiver) isKnownBlob(id restic.ID, t restic.BlobType) bool {
} }
// Save stores a blob read from rd in the repository. // Save stores a blob read from rd in the repository.
func (arch *Archiver) Save(t restic.BlobType, data []byte, id restic.ID) error { func (arch *Archiver) Save(ctx context.Context, t restic.BlobType, data []byte, id restic.ID) error {
debug.Log("Save(%v, %v)\n", t, id.Str()) debug.Log("Save(%v, %v)\n", t, id.Str())
if arch.isKnownBlob(id, restic.DataBlob) { if arch.isKnownBlob(id, restic.DataBlob) {
@ -100,7 +101,7 @@ func (arch *Archiver) Save(t restic.BlobType, data []byte, id restic.ID) error {
return nil return nil
} }
_, err := arch.repo.SaveBlob(t, data, id) _, err := arch.repo.SaveBlob(ctx, t, data, id)
if err != nil { if err != nil {
debug.Log("Save(%v, %v): error %v\n", t, id.Str(), err) debug.Log("Save(%v, %v): error %v\n", t, id.Str(), err)
return err return err
@ -111,7 +112,7 @@ func (arch *Archiver) Save(t restic.BlobType, data []byte, id restic.ID) error {
} }
// SaveTreeJSON stores a tree in the repository. // SaveTreeJSON stores a tree in the repository.
func (arch *Archiver) SaveTreeJSON(tree *restic.Tree) (restic.ID, error) { func (arch *Archiver) SaveTreeJSON(ctx context.Context, tree *restic.Tree) (restic.ID, error) {
data, err := json.Marshal(tree) data, err := json.Marshal(tree)
if err != nil { if err != nil {
return restic.ID{}, errors.Wrap(err, "Marshal") return restic.ID{}, errors.Wrap(err, "Marshal")
@ -124,7 +125,7 @@ func (arch *Archiver) SaveTreeJSON(tree *restic.Tree) (restic.ID, error) {
return id, nil return id, nil
} }
return arch.repo.SaveBlob(restic.TreeBlob, data, id) return arch.repo.SaveBlob(ctx, restic.TreeBlob, data, id)
} }
func (arch *Archiver) reloadFileIfChanged(node *restic.Node, file fs.File) (*restic.Node, error) { func (arch *Archiver) reloadFileIfChanged(node *restic.Node, file fs.File) (*restic.Node, error) {
@ -153,11 +154,11 @@ type saveResult struct {
bytes uint64 bytes uint64
} }
func (arch *Archiver) saveChunk(chunk chunker.Chunk, p *restic.Progress, token struct{}, file fs.File, resultChannel chan<- saveResult) { func (arch *Archiver) saveChunk(ctx context.Context, chunk chunker.Chunk, p *restic.Progress, token struct{}, file fs.File, resultChannel chan<- saveResult) {
defer freeBuf(chunk.Data) defer freeBuf(chunk.Data)
id := restic.Hash(chunk.Data) id := restic.Hash(chunk.Data)
err := arch.Save(restic.DataBlob, chunk.Data, id) err := arch.Save(ctx, restic.DataBlob, chunk.Data, id)
// TODO handle error // TODO handle error
if err != nil { if err != nil {
panic(err) panic(err)
@ -206,7 +207,7 @@ func updateNodeContent(node *restic.Node, results []saveResult) error {
// SaveFile stores the content of the file on the backend as a Blob by calling // SaveFile stores the content of the file on the backend as a Blob by calling
// Save for each chunk. // Save for each chunk.
func (arch *Archiver) SaveFile(p *restic.Progress, node *restic.Node) (*restic.Node, error) { func (arch *Archiver) SaveFile(ctx context.Context, p *restic.Progress, node *restic.Node) (*restic.Node, error) {
file, err := fs.Open(node.Path) file, err := fs.Open(node.Path)
defer file.Close() defer file.Close()
if err != nil { if err != nil {
@ -234,7 +235,7 @@ func (arch *Archiver) SaveFile(p *restic.Progress, node *restic.Node) (*restic.N
} }
resCh := make(chan saveResult, 1) resCh := make(chan saveResult, 1)
go arch.saveChunk(chunk, p, <-arch.blobToken, file, resCh) go arch.saveChunk(ctx, chunk, p, <-arch.blobToken, file, resCh)
resultChannels = append(resultChannels, resCh) resultChannels = append(resultChannels, resCh)
} }
@ -247,7 +248,7 @@ func (arch *Archiver) SaveFile(p *restic.Progress, node *restic.Node) (*restic.N
return node, err return node, err
} }
func (arch *Archiver) fileWorker(wg *sync.WaitGroup, p *restic.Progress, done <-chan struct{}, entCh <-chan pipe.Entry) { func (arch *Archiver) fileWorker(ctx context.Context, wg *sync.WaitGroup, p *restic.Progress, entCh <-chan pipe.Entry) {
defer func() { defer func() {
debug.Log("done") debug.Log("done")
wg.Done() wg.Done()
@ -305,7 +306,7 @@ func (arch *Archiver) fileWorker(wg *sync.WaitGroup, p *restic.Progress, done <-
// otherwise read file normally // otherwise read file normally
if node.Type == "file" && len(node.Content) == 0 { if node.Type == "file" && len(node.Content) == 0 {
debug.Log(" read and save %v", e.Path()) debug.Log(" read and save %v", e.Path())
node, err = arch.SaveFile(p, node) node, err = arch.SaveFile(ctx, p, node)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "error for %v: %v\n", node.Path, err) fmt.Fprintf(os.Stderr, "error for %v: %v\n", node.Path, err)
arch.Warn(e.Path(), nil, err) arch.Warn(e.Path(), nil, err)
@ -322,14 +323,14 @@ func (arch *Archiver) fileWorker(wg *sync.WaitGroup, p *restic.Progress, done <-
debug.Log(" processed %v, %d blobs", e.Path(), len(node.Content)) debug.Log(" processed %v, %d blobs", e.Path(), len(node.Content))
e.Result() <- node e.Result() <- node
p.Report(restic.Stat{Files: 1}) p.Report(restic.Stat{Files: 1})
case <-done: case <-ctx.Done():
// pipeline was cancelled // pipeline was cancelled
return return
} }
} }
} }
func (arch *Archiver) dirWorker(wg *sync.WaitGroup, p *restic.Progress, done <-chan struct{}, dirCh <-chan pipe.Dir) { func (arch *Archiver) dirWorker(ctx context.Context, wg *sync.WaitGroup, p *restic.Progress, dirCh <-chan pipe.Dir) {
debug.Log("start") debug.Log("start")
defer func() { defer func() {
debug.Log("done") debug.Log("done")
@ -398,7 +399,7 @@ func (arch *Archiver) dirWorker(wg *sync.WaitGroup, p *restic.Progress, done <-c
node.Error = err.Error() node.Error = err.Error()
} }
id, err := arch.SaveTreeJSON(tree) id, err := arch.SaveTreeJSON(ctx, tree)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -415,7 +416,7 @@ func (arch *Archiver) dirWorker(wg *sync.WaitGroup, p *restic.Progress, done <-c
if dir.Path() != "" { if dir.Path() != "" {
p.Report(restic.Stat{Dirs: 1}) p.Report(restic.Stat{Dirs: 1})
} }
case <-done: case <-ctx.Done():
// pipeline was cancelled // pipeline was cancelled
return return
} }
@ -427,7 +428,7 @@ type archivePipe struct {
New <-chan pipe.Job New <-chan pipe.Job
} }
func copyJobs(done <-chan struct{}, in <-chan pipe.Job, out chan<- pipe.Job) { func copyJobs(ctx context.Context, in <-chan pipe.Job, out chan<- pipe.Job) {
var ( var (
// disable sending on the outCh until we received a job // disable sending on the outCh until we received a job
outCh chan<- pipe.Job outCh chan<- pipe.Job
@ -439,7 +440,7 @@ func copyJobs(done <-chan struct{}, in <-chan pipe.Job, out chan<- pipe.Job) {
for { for {
select { select {
case <-done: case <-ctx.Done():
return return
case job, ok = <-inCh: case job, ok = <-inCh:
if !ok { if !ok {
@ -462,7 +463,7 @@ type archiveJob struct {
new pipe.Job new pipe.Job
} }
func (a *archivePipe) compare(done <-chan struct{}, out chan<- pipe.Job) { func (a *archivePipe) compare(ctx context.Context, out chan<- pipe.Job) {
defer func() { defer func() {
close(out) close(out)
debug.Log("done") debug.Log("done")
@ -488,7 +489,7 @@ func (a *archivePipe) compare(done <-chan struct{}, out chan<- pipe.Job) {
out <- archiveJob{new: newJob}.Copy() out <- archiveJob{new: newJob}.Copy()
} }
copyJobs(done, a.New, out) copyJobs(ctx, a.New, out)
return return
} }
@ -585,7 +586,7 @@ func (j archiveJob) Copy() pipe.Job {
const saveIndexTime = 30 * time.Second const saveIndexTime = 30 * time.Second
// saveIndexes regularly queries the master index for full indexes and saves them. // saveIndexes regularly queries the master index for full indexes and saves them.
func (arch *Archiver) saveIndexes(wg *sync.WaitGroup, done <-chan struct{}) { func (arch *Archiver) saveIndexes(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
ticker := time.NewTicker(saveIndexTime) ticker := time.NewTicker(saveIndexTime)
@ -593,11 +594,11 @@ func (arch *Archiver) saveIndexes(wg *sync.WaitGroup, done <-chan struct{}) {
for { for {
select { select {
case <-done: case <-ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
debug.Log("saving full indexes") debug.Log("saving full indexes")
err := arch.repo.SaveFullIndex() err := arch.repo.SaveFullIndex(ctx)
if err != nil { if err != nil {
debug.Log("save indexes returned an error: %v", err) debug.Log("save indexes returned an error: %v", err)
fmt.Fprintf(os.Stderr, "error saving preliminary index: %v\n", err) fmt.Fprintf(os.Stderr, "error saving preliminary index: %v\n", err)
@ -634,7 +635,7 @@ func (p baseNameSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// Snapshot creates a snapshot of the given paths. If parentrestic.ID is set, this is // Snapshot creates a snapshot of the given paths. If parentrestic.ID is set, this is
// used to compare the files to the ones archived at the time this snapshot was // used to compare the files to the ones archived at the time this snapshot was
// taken. // taken.
func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostname string, parentID *restic.ID) (*restic.Snapshot, restic.ID, error) { func (arch *Archiver) Snapshot(ctx context.Context, p *restic.Progress, paths, tags []string, hostname string, parentID *restic.ID) (*restic.Snapshot, restic.ID, error) {
paths = unique(paths) paths = unique(paths)
sort.Sort(baseNameSlice(paths)) sort.Sort(baseNameSlice(paths))
@ -643,7 +644,6 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
debug.RunHook("Archiver.Snapshot", nil) debug.RunHook("Archiver.Snapshot", nil)
// signal the whole pipeline to stop // signal the whole pipeline to stop
done := make(chan struct{})
var err error var err error
p.Start() p.Start()
@ -663,14 +663,14 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
sn.Parent = parentID sn.Parent = parentID
// load parent snapshot // load parent snapshot
parent, err := restic.LoadSnapshot(arch.repo, *parentID) parent, err := restic.LoadSnapshot(ctx, arch.repo, *parentID)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
// start walker on old tree // start walker on old tree
ch := make(chan walk.TreeJob) ch := make(chan walk.TreeJob)
go walk.Tree(arch.repo, *parent.Tree, done, ch) go walk.Tree(ctx, arch.repo, *parent.Tree, ch)
jobs.Old = ch jobs.Old = ch
} else { } else {
// use closed channel // use closed channel
@ -683,13 +683,13 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
pipeCh := make(chan pipe.Job) pipeCh := make(chan pipe.Job)
resCh := make(chan pipe.Result, 1) resCh := make(chan pipe.Result, 1)
go func() { go func() {
pipe.Walk(paths, arch.SelectFilter, done, pipeCh, resCh) pipe.Walk(ctx, paths, arch.SelectFilter, pipeCh, resCh)
debug.Log("pipe.Walk done") debug.Log("pipe.Walk done")
}() }()
jobs.New = pipeCh jobs.New = pipeCh
ch := make(chan pipe.Job) ch := make(chan pipe.Job)
go jobs.compare(done, ch) go jobs.compare(ctx, ch)
var wg sync.WaitGroup var wg sync.WaitGroup
entCh := make(chan pipe.Entry) entCh := make(chan pipe.Entry)
@ -708,22 +708,22 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
// run workers // run workers
for i := 0; i < maxConcurrency; i++ { for i := 0; i < maxConcurrency; i++ {
wg.Add(2) wg.Add(2)
go arch.fileWorker(&wg, p, done, entCh) go arch.fileWorker(ctx, &wg, p, entCh)
go arch.dirWorker(&wg, p, done, dirCh) go arch.dirWorker(ctx, &wg, p, dirCh)
} }
// run index saver // run index saver
var wgIndexSaver sync.WaitGroup var wgIndexSaver sync.WaitGroup
stopIndexSaver := make(chan struct{}) indexCtx, indexCancel := context.WithCancel(ctx)
wgIndexSaver.Add(1) wgIndexSaver.Add(1)
go arch.saveIndexes(&wgIndexSaver, stopIndexSaver) go arch.saveIndexes(indexCtx, &wgIndexSaver)
// wait for all workers to terminate // wait for all workers to terminate
debug.Log("wait for workers") debug.Log("wait for workers")
wg.Wait() wg.Wait()
// stop index saver // stop index saver
close(stopIndexSaver) indexCancel()
wgIndexSaver.Wait() wgIndexSaver.Wait()
debug.Log("workers terminated") debug.Log("workers terminated")
@ -740,7 +740,7 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
sn.Tree = root.Subtree sn.Tree = root.Subtree
// load top-level tree again to see if it is empty // load top-level tree again to see if it is empty
toptree, err := arch.repo.LoadTree(*root.Subtree) toptree, err := arch.repo.LoadTree(ctx, *root.Subtree)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
@ -750,7 +750,7 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
} }
// save index // save index
err = arch.repo.SaveIndex() err = arch.repo.SaveIndex(ctx)
if err != nil { if err != nil {
debug.Log("error saving index: %v", err) debug.Log("error saving index: %v", err)
return nil, restic.ID{}, err return nil, restic.ID{}, err
@ -759,7 +759,7 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
debug.Log("saved indexes") debug.Log("saved indexes")
// save snapshot // save snapshot
id, err := arch.repo.SaveJSONUnpacked(restic.SnapshotFile, sn) id, err := arch.repo.SaveJSONUnpacked(ctx, restic.SnapshotFile, sn)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }

View File

@ -1,6 +1,7 @@
package archiver package archiver
import ( import (
"context"
"restic" "restic"
"testing" "testing"
) )
@ -8,7 +9,7 @@ import (
// TestSnapshot creates a new snapshot of path. // TestSnapshot creates a new snapshot of path.
func TestSnapshot(t testing.TB, repo restic.Repository, path string, parent *restic.ID) *restic.Snapshot { func TestSnapshot(t testing.TB, repo restic.Repository, path string, parent *restic.ID) *restic.Snapshot {
arch := New(repo) arch := New(repo)
sn, _, err := arch.Snapshot(nil, []string{path}, []string{"test"}, "localhost", parent) sn, _, err := arch.Snapshot(context.TODO(), nil, []string{path}, []string{"test"}, "localhost", parent)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -1,6 +1,7 @@
package checker package checker
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"io" "io"
@ -76,7 +77,7 @@ func (err ErrOldIndexFormat) Error() string {
} }
// LoadIndex loads all index files. // LoadIndex loads all index files.
func (c *Checker) LoadIndex() (hints []error, errs []error) { func (c *Checker) LoadIndex(ctx context.Context) (hints []error, errs []error) {
debug.Log("Start") debug.Log("Start")
type indexRes struct { type indexRes struct {
Index *repository.Index Index *repository.Index
@ -86,21 +87,21 @@ func (c *Checker) LoadIndex() (hints []error, errs []error) {
indexCh := make(chan indexRes) indexCh := make(chan indexRes)
worker := func(id restic.ID, done <-chan struct{}) error { worker := func(ctx context.Context, id restic.ID) error {
debug.Log("worker got index %v", id) debug.Log("worker got index %v", id)
idx, err := repository.LoadIndexWithDecoder(c.repo, id, repository.DecodeIndex) idx, err := repository.LoadIndexWithDecoder(ctx, c.repo, id, repository.DecodeIndex)
if errors.Cause(err) == repository.ErrOldIndexFormat { if errors.Cause(err) == repository.ErrOldIndexFormat {
debug.Log("index %v has old format", id.Str()) debug.Log("index %v has old format", id.Str())
hints = append(hints, ErrOldIndexFormat{id}) hints = append(hints, ErrOldIndexFormat{id})
idx, err = repository.LoadIndexWithDecoder(c.repo, id, repository.DecodeOldIndex) idx, err = repository.LoadIndexWithDecoder(ctx, c.repo, id, repository.DecodeOldIndex)
} }
err = errors.Wrapf(err, "error loading index %v", id.Str()) err = errors.Wrapf(err, "error loading index %v", id.Str())
select { select {
case indexCh <- indexRes{Index: idx, ID: id.String(), err: err}: case indexCh <- indexRes{Index: idx, ID: id.String(), err: err}:
case <-done: case <-ctx.Done():
} }
return nil return nil
@ -109,7 +110,7 @@ func (c *Checker) LoadIndex() (hints []error, errs []error) {
go func() { go func() {
defer close(indexCh) defer close(indexCh)
debug.Log("start loading indexes in parallel") debug.Log("start loading indexes in parallel")
err := repository.FilesInParallel(c.repo.Backend(), restic.IndexFile, defaultParallelism, err := repository.FilesInParallel(ctx, c.repo.Backend(), restic.IndexFile, defaultParallelism,
repository.ParallelWorkFuncParseID(worker)) repository.ParallelWorkFuncParseID(worker))
debug.Log("loading indexes finished, error: %v", err) debug.Log("loading indexes finished, error: %v", err)
if err != nil { if err != nil {
@ -183,7 +184,7 @@ func (e PackError) Error() string {
return "pack " + e.ID.String() + ": " + e.Err.Error() return "pack " + e.ID.String() + ": " + e.Err.Error()
} }
func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan<- error, wg *sync.WaitGroup, done <-chan struct{}) { func packIDTester(ctx context.Context, repo restic.Repository, inChan <-chan restic.ID, errChan chan<- error, wg *sync.WaitGroup) {
debug.Log("worker start") debug.Log("worker start")
defer debug.Log("worker done") defer debug.Log("worker done")
@ -191,7 +192,7 @@ func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan<
for id := range inChan { for id := range inChan {
h := restic.Handle{Type: restic.DataFile, Name: id.String()} h := restic.Handle{Type: restic.DataFile, Name: id.String()}
ok, err := repo.Backend().Test(h) ok, err := repo.Backend().Test(ctx, h)
if err != nil { if err != nil {
err = PackError{ID: id, Err: err} err = PackError{ID: id, Err: err}
} else { } else {
@ -203,7 +204,7 @@ func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan<
if err != nil { if err != nil {
debug.Log("error checking for pack %s: %v", id.Str(), err) debug.Log("error checking for pack %s: %v", id.Str(), err)
select { select {
case <-done: case <-ctx.Done():
return return
case errChan <- err: case errChan <- err:
} }
@ -218,7 +219,7 @@ func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan<
// Packs checks that all packs referenced in the index are still available and // Packs checks that all packs referenced in the index are still available and
// there are no packs that aren't in an index. errChan is closed after all // there are no packs that aren't in an index. errChan is closed after all
// packs have been checked. // packs have been checked.
func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) { func (c *Checker) Packs(ctx context.Context, errChan chan<- error) {
defer close(errChan) defer close(errChan)
debug.Log("checking for %d packs", len(c.packs)) debug.Log("checking for %d packs", len(c.packs))
@ -229,7 +230,7 @@ func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) {
IDChan := make(chan restic.ID) IDChan := make(chan restic.ID)
for i := 0; i < defaultParallelism; i++ { for i := 0; i < defaultParallelism; i++ {
workerWG.Add(1) workerWG.Add(1)
go packIDTester(c.repo, IDChan, errChan, &workerWG, done) go packIDTester(ctx, c.repo, IDChan, errChan, &workerWG)
} }
for id := range c.packs { for id := range c.packs {
@ -242,12 +243,12 @@ func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) {
workerWG.Wait() workerWG.Wait()
debug.Log("workers terminated") debug.Log("workers terminated")
for id := range c.repo.List(restic.DataFile, done) { for id := range c.repo.List(ctx, restic.DataFile) {
debug.Log("check data blob %v", id.Str()) debug.Log("check data blob %v", id.Str())
if !seenPacks.Has(id) { if !seenPacks.Has(id) {
c.orphanedPacks = append(c.orphanedPacks, id) c.orphanedPacks = append(c.orphanedPacks, id)
select { select {
case <-done: case <-ctx.Done():
return return
case errChan <- PackError{ID: id, Orphaned: true, Err: errors.New("not referenced in any index")}: case errChan <- PackError{ID: id, Orphaned: true, Err: errors.New("not referenced in any index")}:
} }
@ -277,8 +278,8 @@ func (e Error) Error() string {
return e.Err.Error() return e.Err.Error()
} }
func loadTreeFromSnapshot(repo restic.Repository, id restic.ID) (restic.ID, error) { func loadTreeFromSnapshot(ctx context.Context, repo restic.Repository, id restic.ID) (restic.ID, error) {
sn, err := restic.LoadSnapshot(repo, id) sn, err := restic.LoadSnapshot(ctx, repo, id)
if err != nil { if err != nil {
debug.Log("error loading snapshot %v: %v", id.Str(), err) debug.Log("error loading snapshot %v: %v", id.Str(), err)
return restic.ID{}, err return restic.ID{}, err
@ -293,7 +294,7 @@ func loadTreeFromSnapshot(repo restic.Repository, id restic.ID) (restic.ID, erro
} }
// loadSnapshotTreeIDs loads all snapshots from backend and returns the tree IDs. // loadSnapshotTreeIDs loads all snapshots from backend and returns the tree IDs.
func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) { func loadSnapshotTreeIDs(ctx context.Context, repo restic.Repository) (restic.IDs, []error) {
var trees struct { var trees struct {
IDs restic.IDs IDs restic.IDs
sync.Mutex sync.Mutex
@ -304,7 +305,7 @@ func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) {
sync.Mutex sync.Mutex
} }
snapshotWorker := func(strID string, done <-chan struct{}) error { snapshotWorker := func(ctx context.Context, strID string) error {
id, err := restic.ParseID(strID) id, err := restic.ParseID(strID)
if err != nil { if err != nil {
return err return err
@ -312,7 +313,7 @@ func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) {
debug.Log("load snapshot %v", id.Str()) debug.Log("load snapshot %v", id.Str())
treeID, err := loadTreeFromSnapshot(repo, id) treeID, err := loadTreeFromSnapshot(ctx, repo, id)
if err != nil { if err != nil {
errs.Lock() errs.Lock()
errs.errs = append(errs.errs, err) errs.errs = append(errs.errs, err)
@ -328,7 +329,7 @@ func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) {
return nil return nil
} }
err := repository.FilesInParallel(repo.Backend(), restic.SnapshotFile, defaultParallelism, snapshotWorker) err := repository.FilesInParallel(ctx, repo.Backend(), restic.SnapshotFile, defaultParallelism, snapshotWorker)
if err != nil { if err != nil {
errs.errs = append(errs.errs, err) errs.errs = append(errs.errs, err)
} }
@ -353,9 +354,9 @@ type treeJob struct {
} }
// loadTreeWorker loads trees from repo and sends them to out. // loadTreeWorker loads trees from repo and sends them to out.
func loadTreeWorker(repo restic.Repository, func loadTreeWorker(ctx context.Context, repo restic.Repository,
in <-chan restic.ID, out chan<- treeJob, in <-chan restic.ID, out chan<- treeJob,
done <-chan struct{}, wg *sync.WaitGroup) { wg *sync.WaitGroup) {
defer func() { defer func() {
debug.Log("exiting") debug.Log("exiting")
@ -371,7 +372,7 @@ func loadTreeWorker(repo restic.Repository,
outCh = nil outCh = nil
for { for {
select { select {
case <-done: case <-ctx.Done():
return return
case treeID, ok := <-inCh: case treeID, ok := <-inCh:
@ -380,7 +381,7 @@ func loadTreeWorker(repo restic.Repository,
} }
debug.Log("load tree %v", treeID.Str()) debug.Log("load tree %v", treeID.Str())
tree, err := repo.LoadTree(treeID) tree, err := repo.LoadTree(ctx, treeID)
debug.Log("load tree %v (%v) returned err: %v", tree, treeID.Str(), err) debug.Log("load tree %v (%v) returned err: %v", tree, treeID.Str(), err)
job = treeJob{ID: treeID, error: err, Tree: tree} job = treeJob{ID: treeID, error: err, Tree: tree}
outCh = out outCh = out
@ -395,7 +396,7 @@ func loadTreeWorker(repo restic.Repository,
} }
// checkTreeWorker checks the trees received and sends out errors to errChan. // checkTreeWorker checks the trees received and sends out errors to errChan.
func (c *Checker) checkTreeWorker(in <-chan treeJob, out chan<- error, done <-chan struct{}, wg *sync.WaitGroup) { func (c *Checker) checkTreeWorker(ctx context.Context, in <-chan treeJob, out chan<- error, wg *sync.WaitGroup) {
defer func() { defer func() {
debug.Log("exiting") debug.Log("exiting")
wg.Done() wg.Done()
@ -410,7 +411,7 @@ func (c *Checker) checkTreeWorker(in <-chan treeJob, out chan<- error, done <-ch
outCh = nil outCh = nil
for { for {
select { select {
case <-done: case <-ctx.Done():
debug.Log("done channel closed, exiting") debug.Log("done channel closed, exiting")
return return
@ -458,7 +459,7 @@ func (c *Checker) checkTreeWorker(in <-chan treeJob, out chan<- error, done <-ch
} }
} }
func filterTrees(backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan treeJob, out chan<- treeJob, done <-chan struct{}) { func filterTrees(ctx context.Context, backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan treeJob, out chan<- treeJob) {
defer func() { defer func() {
debug.Log("closing output channels") debug.Log("closing output channels")
close(loaderChan) close(loaderChan)
@ -489,7 +490,7 @@ func filterTrees(backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan tree
} }
select { select {
case <-done: case <-ctx.Done():
return return
case loadCh <- nextTreeID: case loadCh <- nextTreeID:
@ -549,15 +550,15 @@ func filterTrees(backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan tree
// Structure checks that for all snapshots all referenced data blobs and // Structure checks that for all snapshots all referenced data blobs and
// subtrees are available in the index. errChan is closed after all trees have // subtrees are available in the index. errChan is closed after all trees have
// been traversed. // been traversed.
func (c *Checker) Structure(errChan chan<- error, done <-chan struct{}) { func (c *Checker) Structure(ctx context.Context, errChan chan<- error) {
defer close(errChan) defer close(errChan)
trees, errs := loadSnapshotTreeIDs(c.repo) trees, errs := loadSnapshotTreeIDs(ctx, c.repo)
debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs)) debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs))
for _, err := range errs { for _, err := range errs {
select { select {
case <-done: case <-ctx.Done():
return return
case errChan <- err: case errChan <- err:
} }
@ -570,11 +571,11 @@ func (c *Checker) Structure(errChan chan<- error, done <-chan struct{}) {
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < defaultParallelism; i++ { for i := 0; i < defaultParallelism; i++ {
wg.Add(2) wg.Add(2)
go loadTreeWorker(c.repo, treeIDChan, treeJobChan1, done, &wg) go loadTreeWorker(ctx, c.repo, treeIDChan, treeJobChan1, &wg)
go c.checkTreeWorker(treeJobChan2, errChan, done, &wg) go c.checkTreeWorker(ctx, treeJobChan2, errChan, &wg)
} }
filterTrees(trees, treeIDChan, treeJobChan1, treeJobChan2, done) filterTrees(ctx, trees, treeIDChan, treeJobChan1, treeJobChan2)
wg.Wait() wg.Wait()
} }
@ -659,11 +660,11 @@ func (c *Checker) CountPacks() uint64 {
} }
// checkPack reads a pack and checks the integrity of all blobs. // checkPack reads a pack and checks the integrity of all blobs.
func checkPack(r restic.Repository, id restic.ID) error { func checkPack(ctx context.Context, r restic.Repository, id restic.ID) error {
debug.Log("checking pack %v", id.Str()) debug.Log("checking pack %v", id.Str())
h := restic.Handle{Type: restic.DataFile, Name: id.String()} h := restic.Handle{Type: restic.DataFile, Name: id.String()}
rd, err := r.Backend().Load(h, 0, 0) rd, err := r.Backend().Load(ctx, h, 0, 0)
if err != nil { if err != nil {
return err return err
} }
@ -748,7 +749,7 @@ func checkPack(r restic.Repository, id restic.ID) error {
} }
// ReadData loads all data from the repository and checks the integrity. // ReadData loads all data from the repository and checks the integrity.
func (c *Checker) ReadData(p *restic.Progress, errChan chan<- error, done <-chan struct{}) { func (c *Checker) ReadData(ctx context.Context, p *restic.Progress, errChan chan<- error) {
defer close(errChan) defer close(errChan)
p.Start() p.Start()
@ -761,7 +762,7 @@ func (c *Checker) ReadData(p *restic.Progress, errChan chan<- error, done <-chan
var ok bool var ok bool
select { select {
case <-done: case <-ctx.Done():
return return
case id, ok = <-in: case id, ok = <-in:
if !ok { if !ok {
@ -769,21 +770,21 @@ func (c *Checker) ReadData(p *restic.Progress, errChan chan<- error, done <-chan
} }
} }
err := checkPack(c.repo, id) err := checkPack(ctx, c.repo, id)
p.Report(restic.Stat{Blobs: 1}) p.Report(restic.Stat{Blobs: 1})
if err == nil { if err == nil {
continue continue
} }
select { select {
case <-done: case <-ctx.Done():
return return
case errChan <- err: case errChan <- err:
} }
} }
} }
ch := c.repo.List(restic.DataFile, done) ch := c.repo.List(ctx, restic.DataFile)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < defaultParallelism; i++ { for i := 0; i < defaultParallelism; i++ {

View File

@ -1,6 +1,7 @@
package checker package checker
import ( import (
"context"
"restic" "restic"
"testing" "testing"
) )
@ -9,7 +10,7 @@ import (
func TestCheckRepo(t testing.TB, repo restic.Repository) { func TestCheckRepo(t testing.TB, repo restic.Repository) {
chkr := New(repo) chkr := New(repo)
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
if len(errs) != 0 { if len(errs) != 0 {
t.Fatalf("errors loading index: %v", errs) t.Fatalf("errors loading index: %v", errs)
} }
@ -18,12 +19,9 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) {
t.Fatalf("errors loading index: %v", hints) t.Fatalf("errors loading index: %v", hints)
} }
done := make(chan struct{})
defer close(done)
// packs // packs
errChan := make(chan error) errChan := make(chan error)
go chkr.Packs(errChan, done) go chkr.Packs(context.TODO(), errChan)
for err := range errChan { for err := range errChan {
t.Error(err) t.Error(err)
@ -31,7 +29,7 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) {
// structure // structure
errChan = make(chan error) errChan = make(chan error)
go chkr.Structure(errChan, done) go chkr.Structure(context.TODO(), errChan)
for err := range errChan { for err := range errChan {
t.Error(err) t.Error(err)
@ -45,7 +43,7 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) {
// read data // read data
errChan = make(chan error) errChan = make(chan error)
go chkr.ReadData(nil, errChan, done) go chkr.ReadData(context.TODO(), nil, errChan)
for err := range errChan { for err := range errChan {
t.Error(err) t.Error(err)

View File

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"testing" "testing"
"restic/errors" "restic/errors"
@ -23,7 +24,7 @@ const RepoVersion = 1
// JSONUnpackedLoader loads unpacked JSON. // JSONUnpackedLoader loads unpacked JSON.
type JSONUnpackedLoader interface { type JSONUnpackedLoader interface {
LoadJSONUnpacked(FileType, ID, interface{}) error LoadJSONUnpacked(context.Context, FileType, ID, interface{}) error
} }
// CreateConfig creates a config file with a randomly selected polynomial and // CreateConfig creates a config file with a randomly selected polynomial and
@ -57,12 +58,12 @@ func TestCreateConfig(t testing.TB, pol chunker.Pol) (cfg Config) {
} }
// LoadConfig returns loads, checks and returns the config for a repository. // LoadConfig returns loads, checks and returns the config for a repository.
func LoadConfig(r JSONUnpackedLoader) (Config, error) { func LoadConfig(ctx context.Context, r JSONUnpackedLoader) (Config, error) {
var ( var (
cfg Config cfg Config
) )
err := r.LoadJSONUnpacked(ConfigFile, ID{}, &cfg) err := r.LoadJSONUnpacked(ctx, ConfigFile, ID{}, &cfg)
if err != nil { if err != nil {
return Config{}, err return Config{}, err
} }

View File

@ -1,12 +1,14 @@
package restic package restic
import "context"
// FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data // FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data
// blobs) to the set blobs. The tree blobs in the `seen` BlobSet will not be visited // blobs) to the set blobs. The tree blobs in the `seen` BlobSet will not be visited
// again. // again.
func FindUsedBlobs(repo Repository, treeID ID, blobs BlobSet, seen BlobSet) error { func FindUsedBlobs(ctx context.Context, repo Repository, treeID ID, blobs BlobSet, seen BlobSet) error {
blobs.Insert(BlobHandle{ID: treeID, Type: TreeBlob}) blobs.Insert(BlobHandle{ID: treeID, Type: TreeBlob})
tree, err := repo.LoadTree(treeID) tree, err := repo.LoadTree(ctx, treeID)
if err != nil { if err != nil {
return err return err
} }
@ -26,7 +28,7 @@ func FindUsedBlobs(repo Repository, treeID ID, blobs BlobSet, seen BlobSet) erro
seen.Insert(h) seen.Insert(h)
err := FindUsedBlobs(repo, subtreeID, blobs, seen) err := FindUsedBlobs(ctx, repo, subtreeID, blobs, seen)
if err != nil { if err != nil {
return err return err
} }

View File

@ -26,9 +26,9 @@ type dir struct {
ownerIsRoot bool ownerIsRoot bool
} }
func newDir(repo restic.Repository, node *restic.Node, ownerIsRoot bool) (*dir, error) { func newDir(ctx context.Context, repo restic.Repository, node *restic.Node, ownerIsRoot bool) (*dir, error) {
debug.Log("new dir for %v (%v)", node.Name, node.Subtree.Str()) debug.Log("new dir for %v (%v)", node.Name, node.Subtree.Str())
tree, err := repo.LoadTree(*node.Subtree) tree, err := repo.LoadTree(ctx, *node.Subtree)
if err != nil { if err != nil {
debug.Log(" error loading tree %v: %v", node.Subtree.Str(), err) debug.Log(" error loading tree %v: %v", node.Subtree.Str(), err)
return nil, err return nil, err
@ -49,7 +49,7 @@ func newDir(repo restic.Repository, node *restic.Node, ownerIsRoot bool) (*dir,
// replaceSpecialNodes replaces nodes with name "." and "/" by their contents. // replaceSpecialNodes replaces nodes with name "." and "/" by their contents.
// Otherwise, the node is returned. // Otherwise, the node is returned.
func replaceSpecialNodes(repo restic.Repository, node *restic.Node) ([]*restic.Node, error) { func replaceSpecialNodes(ctx context.Context, repo restic.Repository, node *restic.Node) ([]*restic.Node, error) {
if node.Type != "dir" || node.Subtree == nil { if node.Type != "dir" || node.Subtree == nil {
return []*restic.Node{node}, nil return []*restic.Node{node}, nil
} }
@ -58,7 +58,7 @@ func replaceSpecialNodes(repo restic.Repository, node *restic.Node) ([]*restic.N
return []*restic.Node{node}, nil return []*restic.Node{node}, nil
} }
tree, err := repo.LoadTree(*node.Subtree) tree, err := repo.LoadTree(ctx, *node.Subtree)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -66,16 +66,16 @@ func replaceSpecialNodes(repo restic.Repository, node *restic.Node) ([]*restic.N
return tree.Nodes, nil return tree.Nodes, nil
} }
func newDirFromSnapshot(repo restic.Repository, snapshot SnapshotWithId, ownerIsRoot bool) (*dir, error) { func newDirFromSnapshot(ctx context.Context, repo restic.Repository, snapshot SnapshotWithId, ownerIsRoot bool) (*dir, error) {
debug.Log("new dir for snapshot %v (%v)", snapshot.ID.Str(), snapshot.Tree.Str()) debug.Log("new dir for snapshot %v (%v)", snapshot.ID.Str(), snapshot.Tree.Str())
tree, err := repo.LoadTree(*snapshot.Tree) tree, err := repo.LoadTree(ctx, *snapshot.Tree)
if err != nil { if err != nil {
debug.Log(" loadTree(%v) failed: %v", snapshot.ID.Str(), err) debug.Log(" loadTree(%v) failed: %v", snapshot.ID.Str(), err)
return nil, err return nil, err
} }
items := make(map[string]*restic.Node) items := make(map[string]*restic.Node)
for _, n := range tree.Nodes { for _, n := range tree.Nodes {
nodes, err := replaceSpecialNodes(repo, n) nodes, err := replaceSpecialNodes(ctx, repo, n)
if err != nil { if err != nil {
debug.Log(" replaceSpecialNodes(%v) failed: %v", n, err) debug.Log(" replaceSpecialNodes(%v) failed: %v", n, err)
return nil, err return nil, err
@ -167,7 +167,7 @@ func (d *dir) Lookup(ctx context.Context, name string) (fs.Node, error) {
} }
switch node.Type { switch node.Type {
case "dir": case "dir":
return newDir(d.repo, node, d.ownerIsRoot) return newDir(ctx, d.repo, node, d.ownerIsRoot)
case "file": case "file":
return newFile(d.repo, node, d.ownerIsRoot) return newFile(d.repo, node, d.ownerIsRoot)
case "symlink": case "symlink":

View File

@ -9,6 +9,8 @@ import (
"restic" "restic"
"restic/debug" "restic/debug"
scontext "context"
"bazil.org/fuse" "bazil.org/fuse"
"bazil.org/fuse/fs" "bazil.org/fuse/fs"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -25,7 +27,7 @@ var _ = fs.HandleReleaser(&file{})
// for fuse operations. // for fuse operations.
type BlobLoader interface { type BlobLoader interface {
LookupBlobSize(restic.ID, restic.BlobType) (uint, error) LookupBlobSize(restic.ID, restic.BlobType) (uint, error)
LoadBlob(restic.BlobType, restic.ID, []byte) (int, error) LoadBlob(scontext.Context, restic.BlobType, restic.ID, []byte) (int, error)
} }
type file struct { type file struct {
@ -88,7 +90,7 @@ func (f *file) Attr(ctx context.Context, a *fuse.Attr) error {
} }
func (f *file) getBlobAt(i int) (blob []byte, err error) { func (f *file) getBlobAt(ctx context.Context, i int) (blob []byte, err error) {
debug.Log("getBlobAt(%v, %v)", f.node.Name, i) debug.Log("getBlobAt(%v, %v)", f.node.Name, i)
if f.blobs[i] != nil { if f.blobs[i] != nil {
return f.blobs[i], nil return f.blobs[i], nil
@ -100,7 +102,7 @@ func (f *file) getBlobAt(i int) (blob []byte, err error) {
} }
buf := restic.NewBlobBuffer(f.sizes[i]) buf := restic.NewBlobBuffer(f.sizes[i])
n, err := f.repo.LoadBlob(restic.DataBlob, f.node.Content[i], buf) n, err := f.repo.LoadBlob(ctx, restic.DataBlob, f.node.Content[i], buf)
if err != nil { if err != nil {
debug.Log("LoadBlob(%v, %v) failed: %v", f.node.Name, f.node.Content[i], err) debug.Log("LoadBlob(%v, %v) failed: %v", f.node.Name, f.node.Content[i], err)
return nil, err return nil, err
@ -137,7 +139,7 @@ func (f *file) Read(ctx context.Context, req *fuse.ReadRequest, resp *fuse.ReadR
readBytes := 0 readBytes := 0
remainingBytes := req.Size remainingBytes := req.Size
for i := startContent; remainingBytes > 0 && i < len(f.sizes); i++ { for i := startContent; remainingBytes > 0 && i < len(f.sizes); i++ {
blob, err := f.getBlobAt(i) blob, err := f.getBlobAt(ctx, i)
if err != nil { if err != nil {
return err return err
} }

View File

@ -73,14 +73,14 @@ func (sn *SnapshotsDir) updateCache(ctx context.Context) error {
sn.Lock() sn.Lock()
defer sn.Unlock() defer sn.Unlock()
for id := range sn.repo.List(restic.SnapshotFile, ctx.Done()) { for id := range sn.repo.List(ctx, restic.SnapshotFile) {
if sn.processed.Has(id) { if sn.processed.Has(id) {
debug.Log("skipping snapshot %v, already in list", id.Str()) debug.Log("skipping snapshot %v, already in list", id.Str())
continue continue
} }
debug.Log("found snapshot id %v", id.Str()) debug.Log("found snapshot id %v", id.Str())
snapshot, err := restic.LoadSnapshot(sn.repo, id) snapshot, err := restic.LoadSnapshot(ctx, sn.repo, id)
if err != nil { if err != nil {
return err return err
} }
@ -158,5 +158,5 @@ func (sn *SnapshotsDir) Lookup(ctx context.Context, name string) (fs.Node, error
} }
} }
return newDirFromSnapshot(sn.repo, snapshot, sn.ownerIsRoot) return newDirFromSnapshot(ctx, sn.repo, snapshot, sn.ownerIsRoot)
} }

View File

@ -2,6 +2,7 @@
package index package index
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"restic" "restic"
@ -33,15 +34,12 @@ func newIndex() *Index {
} }
// New creates a new index for repo from scratch. // New creates a new index for repo from scratch.
func New(repo restic.Repository, p *restic.Progress) (*Index, error) { func New(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Index, error) {
done := make(chan struct{})
defer close(done)
p.Start() p.Start()
defer p.Done() defer p.Done()
ch := make(chan worker.Job) ch := make(chan worker.Job)
go list.AllPacks(repo, ch, done) go list.AllPacks(ctx, repo, ch)
idx := newIndex() idx := newIndex()
@ -84,11 +82,11 @@ type indexJSON struct {
Packs []*packJSON `json:"packs"` Packs []*packJSON `json:"packs"`
} }
func loadIndexJSON(repo restic.Repository, id restic.ID) (*indexJSON, error) { func loadIndexJSON(ctx context.Context, repo restic.Repository, id restic.ID) (*indexJSON, error) {
debug.Log("process index %v\n", id.Str()) debug.Log("process index %v\n", id.Str())
var idx indexJSON var idx indexJSON
err := repo.LoadJSONUnpacked(restic.IndexFile, id, &idx) err := repo.LoadJSONUnpacked(ctx, restic.IndexFile, id, &idx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -97,25 +95,22 @@ func loadIndexJSON(repo restic.Repository, id restic.ID) (*indexJSON, error) {
} }
// Load creates an index by loading all index files from the repo. // Load creates an index by loading all index files from the repo.
func Load(repo restic.Repository, p *restic.Progress) (*Index, error) { func Load(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Index, error) {
debug.Log("loading indexes") debug.Log("loading indexes")
p.Start() p.Start()
defer p.Done() defer p.Done()
done := make(chan struct{})
defer close(done)
supersedes := make(map[restic.ID]restic.IDSet) supersedes := make(map[restic.ID]restic.IDSet)
results := make(map[restic.ID]map[restic.ID]Pack) results := make(map[restic.ID]map[restic.ID]Pack)
index := newIndex() index := newIndex()
for id := range repo.List(restic.IndexFile, done) { for id := range repo.List(ctx, restic.IndexFile) {
p.Report(restic.Stat{Blobs: 1}) p.Report(restic.Stat{Blobs: 1})
debug.Log("Load index %v", id.Str()) debug.Log("Load index %v", id.Str())
idx, err := loadIndexJSON(repo, id) idx, err := loadIndexJSON(ctx, repo, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -250,17 +245,17 @@ func (idx *Index) FindBlob(h restic.BlobHandle) (result []Location, err error) {
} }
// Save writes the complete index to the repo. // Save writes the complete index to the repo.
func (idx *Index) Save(repo restic.Repository, supersedes restic.IDs) (restic.ID, error) { func (idx *Index) Save(ctx context.Context, repo restic.Repository, supersedes restic.IDs) (restic.ID, error) {
packs := make(map[restic.ID][]restic.Blob, len(idx.Packs)) packs := make(map[restic.ID][]restic.Blob, len(idx.Packs))
for id, p := range idx.Packs { for id, p := range idx.Packs {
packs[id] = p.Entries packs[id] = p.Entries
} }
return Save(repo, packs, supersedes) return Save(ctx, repo, packs, supersedes)
} }
// Save writes a new index containing the given packs. // Save writes a new index containing the given packs.
func Save(repo restic.Repository, packs map[restic.ID][]restic.Blob, supersedes restic.IDs) (restic.ID, error) { func Save(ctx context.Context, repo restic.Repository, packs map[restic.ID][]restic.Blob, supersedes restic.IDs) (restic.ID, error) {
idx := &indexJSON{ idx := &indexJSON{
Supersedes: supersedes, Supersedes: supersedes,
Packs: make([]*packJSON, 0, len(packs)), Packs: make([]*packJSON, 0, len(packs)),
@ -285,5 +280,5 @@ func Save(repo restic.Repository, packs map[restic.ID][]restic.Blob, supersedes
idx.Packs = append(idx.Packs, p) idx.Packs = append(idx.Packs, p)
} }
return repo.SaveJSONUnpacked(restic.IndexFile, idx) return repo.SaveJSONUnpacked(ctx, restic.IndexFile, idx)
} }

View File

@ -1,6 +1,7 @@
package list package list
import ( import (
"context"
"restic" "restic"
"restic/worker" "restic/worker"
) )
@ -9,8 +10,8 @@ const listPackWorkers = 10
// Lister combines lists packs in a repo and blobs in a pack. // Lister combines lists packs in a repo and blobs in a pack.
type Lister interface { type Lister interface {
List(restic.FileType, <-chan struct{}) <-chan restic.ID List(context.Context, restic.FileType) <-chan restic.ID
ListPack(restic.ID) ([]restic.Blob, int64, error) ListPack(context.Context, restic.ID) ([]restic.Blob, int64, error)
} }
// Result is returned in the channel from LoadBlobsFromAllPacks. // Result is returned in the channel from LoadBlobsFromAllPacks.
@ -36,10 +37,10 @@ func (l Result) Entries() []restic.Blob {
} }
// AllPacks sends the contents of all packs to ch. // AllPacks sends the contents of all packs to ch.
func AllPacks(repo Lister, ch chan<- worker.Job, done <-chan struct{}) { func AllPacks(ctx context.Context, repo Lister, ch chan<- worker.Job) {
f := func(job worker.Job, done <-chan struct{}) (interface{}, error) { f := func(ctx context.Context, job worker.Job) (interface{}, error) {
packID := job.Data.(restic.ID) packID := job.Data.(restic.ID)
entries, size, err := repo.ListPack(packID) entries, size, err := repo.ListPack(ctx, packID)
return Result{ return Result{
packID: packID, packID: packID,
@ -49,14 +50,14 @@ func AllPacks(repo Lister, ch chan<- worker.Job, done <-chan struct{}) {
} }
jobCh := make(chan worker.Job) jobCh := make(chan worker.Job)
wp := worker.New(listPackWorkers, f, jobCh, ch) wp := worker.New(ctx, listPackWorkers, f, jobCh, ch)
go func() { go func() {
defer close(jobCh) defer close(jobCh)
for id := range repo.List(restic.DataFile, done) { for id := range repo.List(ctx, restic.DataFile) {
select { select {
case jobCh <- worker.Job{Data: id}: case jobCh <- worker.Job{Data: id}:
case <-done: case <-ctx.Done():
return return
} }
} }

View File

@ -59,15 +59,15 @@ func IsAlreadyLocked(err error) bool {
// NewLock returns a new, non-exclusive lock for the repository. If an // NewLock returns a new, non-exclusive lock for the repository. If an
// exclusive lock is already held by another process, ErrAlreadyLocked is // exclusive lock is already held by another process, ErrAlreadyLocked is
// returned. // returned.
func NewLock(repo Repository) (*Lock, error) { func NewLock(ctx context.Context, repo Repository) (*Lock, error) {
return newLock(repo, false) return newLock(ctx, repo, false)
} }
// NewExclusiveLock returns a new, exclusive lock for the repository. If // NewExclusiveLock returns a new, exclusive lock for the repository. If
// another lock (normal and exclusive) is already held by another process, // another lock (normal and exclusive) is already held by another process,
// ErrAlreadyLocked is returned. // ErrAlreadyLocked is returned.
func NewExclusiveLock(repo Repository) (*Lock, error) { func NewExclusiveLock(ctx context.Context, repo Repository) (*Lock, error) {
return newLock(repo, true) return newLock(ctx, repo, true)
} }
var waitBeforeLockCheck = 200 * time.Millisecond var waitBeforeLockCheck = 200 * time.Millisecond
@ -78,7 +78,7 @@ func TestSetLockTimeout(t testing.TB, d time.Duration) {
waitBeforeLockCheck = d waitBeforeLockCheck = d
} }
func newLock(repo Repository, excl bool) (*Lock, error) { func newLock(ctx context.Context, repo Repository, excl bool) (*Lock, error) {
lock := &Lock{ lock := &Lock{
Time: time.Now(), Time: time.Now(),
PID: os.Getpid(), PID: os.Getpid(),
@ -95,11 +95,11 @@ func newLock(repo Repository, excl bool) (*Lock, error) {
return nil, err return nil, err
} }
if err = lock.checkForOtherLocks(); err != nil { if err = lock.checkForOtherLocks(ctx); err != nil {
return nil, err return nil, err
} }
lockID, err := lock.createLock() lockID, err := lock.createLock(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -108,7 +108,7 @@ func newLock(repo Repository, excl bool) (*Lock, error) {
time.Sleep(waitBeforeLockCheck) time.Sleep(waitBeforeLockCheck)
if err = lock.checkForOtherLocks(); err != nil { if err = lock.checkForOtherLocks(ctx); err != nil {
lock.Unlock() lock.Unlock()
return nil, err return nil, err
} }
@ -133,8 +133,8 @@ func (l *Lock) fillUserInfo() error {
// if there are any other locks, regardless if exclusive or not. If a // if there are any other locks, regardless if exclusive or not. If a
// non-exclusive lock is to be created, an error is only returned when an // non-exclusive lock is to be created, an error is only returned when an
// exclusive lock is found. // exclusive lock is found.
func (l *Lock) checkForOtherLocks() error { func (l *Lock) checkForOtherLocks(ctx context.Context) error {
return eachLock(l.repo, func(id ID, lock *Lock, err error) error { return eachLock(ctx, l.repo, func(id ID, lock *Lock, err error) error {
if l.lockID != nil && id.Equal(*l.lockID) { if l.lockID != nil && id.Equal(*l.lockID) {
return nil return nil
} }
@ -156,12 +156,9 @@ func (l *Lock) checkForOtherLocks() error {
}) })
} }
func eachLock(repo Repository, f func(ID, *Lock, error) error) error { func eachLock(ctx context.Context, repo Repository, f func(ID, *Lock, error) error) error {
done := make(chan struct{}) for id := range repo.List(ctx, LockFile) {
defer close(done) lock, err := LoadLock(ctx, repo, id)
for id := range repo.List(LockFile, done) {
lock, err := LoadLock(repo, id)
err = f(id, lock, err) err = f(id, lock, err)
if err != nil { if err != nil {
return err return err
@ -172,8 +169,8 @@ func eachLock(repo Repository, f func(ID, *Lock, error) error) error {
} }
// createLock acquires the lock by creating a file in the repository. // createLock acquires the lock by creating a file in the repository.
func (l *Lock) createLock() (ID, error) { func (l *Lock) createLock(ctx context.Context) (ID, error) {
id, err := l.repo.SaveJSONUnpacked(LockFile, l) id, err := l.repo.SaveJSONUnpacked(ctx, LockFile, l)
if err != nil { if err != nil {
return ID{}, err return ID{}, err
} }
@ -228,9 +225,9 @@ func (l *Lock) Stale() bool {
// Refresh refreshes the lock by creating a new file in the backend with a new // Refresh refreshes the lock by creating a new file in the backend with a new
// timestamp. Afterwards the old lock is removed. // timestamp. Afterwards the old lock is removed.
func (l *Lock) Refresh() error { func (l *Lock) Refresh(ctx context.Context) error {
debug.Log("refreshing lock %v", l.lockID.Str()) debug.Log("refreshing lock %v", l.lockID.Str())
id, err := l.createLock() id, err := l.createLock(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -271,9 +268,9 @@ func init() {
} }
// LoadLock loads and unserializes a lock from a repository. // LoadLock loads and unserializes a lock from a repository.
func LoadLock(repo Repository, id ID) (*Lock, error) { func LoadLock(ctx context.Context, repo Repository, id ID) (*Lock, error) {
lock := &Lock{} lock := &Lock{}
if err := repo.LoadJSONUnpacked(LockFile, id, lock); err != nil { if err := repo.LoadJSONUnpacked(ctx, LockFile, id, lock); err != nil {
return nil, err return nil, err
} }
lock.lockID = &id lock.lockID = &id
@ -282,8 +279,8 @@ func LoadLock(repo Repository, id ID) (*Lock, error) {
} }
// RemoveStaleLocks deletes all locks detected as stale from the repository. // RemoveStaleLocks deletes all locks detected as stale from the repository.
func RemoveStaleLocks(repo Repository) error { func RemoveStaleLocks(ctx context.Context, repo Repository) error {
return eachLock(repo, func(id ID, lock *Lock, err error) error { return eachLock(ctx, repo, func(id ID, lock *Lock, err error) error {
// ignore locks that cannot be loaded // ignore locks that cannot be loaded
if err != nil { if err != nil {
return nil return nil
@ -298,8 +295,8 @@ func RemoveStaleLocks(repo Repository) error {
} }
// RemoveAllLocks removes all locks forcefully. // RemoveAllLocks removes all locks forcefully.
func RemoveAllLocks(repo Repository) error { func RemoveAllLocks(ctx context.Context, repo Repository) error {
return eachLock(repo, func(id ID, lock *Lock, err error) error { return eachLock(ctx, repo, func(id ID, lock *Lock, err error) error {
return repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: id.String()}) return repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: id.String()})
}) })
} }

View File

@ -1,6 +1,7 @@
package mock package mock
import ( import (
"context"
"io" "io"
"restic" "restic"
@ -10,13 +11,13 @@ import (
// Backend implements a mock backend. // Backend implements a mock backend.
type Backend struct { type Backend struct {
CloseFn func() error CloseFn func() error
SaveFn func(h restic.Handle, rd io.Reader) error SaveFn func(ctx context.Context, h restic.Handle, rd io.Reader) error
LoadFn func(h restic.Handle, length int, offset int64) (io.ReadCloser, error) LoadFn func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error)
StatFn func(h restic.Handle) (restic.FileInfo, error) StatFn func(ctx context.Context, h restic.Handle) (restic.FileInfo, error)
ListFn func(restic.FileType, <-chan struct{}) <-chan string ListFn func(ctx context.Context, t restic.FileType) <-chan string
RemoveFn func(h restic.Handle) error RemoveFn func(ctx context.Context, h restic.Handle) error
TestFn func(h restic.Handle) (bool, error) TestFn func(ctx context.Context, h restic.Handle) (bool, error)
DeleteFn func() error DeleteFn func(ctx context.Context) error
LocationFn func() string LocationFn func() string
} }
@ -39,68 +40,68 @@ func (m *Backend) Location() string {
} }
// Save data in the backend. // Save data in the backend.
func (m *Backend) Save(h restic.Handle, rd io.Reader) error { func (m *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error {
if m.SaveFn == nil { if m.SaveFn == nil {
return errors.New("not implemented") return errors.New("not implemented")
} }
return m.SaveFn(h, rd) return m.SaveFn(ctx, h, rd)
} }
// Load loads data from the backend. // Load loads data from the backend.
func (m *Backend) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { func (m *Backend) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
if m.LoadFn == nil { if m.LoadFn == nil {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
return m.LoadFn(h, length, offset) return m.LoadFn(ctx, h, length, offset)
} }
// Stat an object in the backend. // Stat an object in the backend.
func (m *Backend) Stat(h restic.Handle) (restic.FileInfo, error) { func (m *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) {
if m.StatFn == nil { if m.StatFn == nil {
return restic.FileInfo{}, errors.New("not implemented") return restic.FileInfo{}, errors.New("not implemented")
} }
return m.StatFn(h) return m.StatFn(ctx, h)
} }
// List items of type t. // List items of type t.
func (m *Backend) List(t restic.FileType, done <-chan struct{}) <-chan string { func (m *Backend) List(ctx context.Context, t restic.FileType) <-chan string {
if m.ListFn == nil { if m.ListFn == nil {
ch := make(chan string) ch := make(chan string)
close(ch) close(ch)
return ch return ch
} }
return m.ListFn(t, done) return m.ListFn(ctx, t)
} }
// Remove data from the backend. // Remove data from the backend.
func (m *Backend) Remove(h restic.Handle) error { func (m *Backend) Remove(ctx context.Context, h restic.Handle) error {
if m.RemoveFn == nil { if m.RemoveFn == nil {
return errors.New("not implemented") return errors.New("not implemented")
} }
return m.RemoveFn(h) return m.RemoveFn(ctx, h)
} }
// Test for the existence of a specific item. // Test for the existence of a specific item.
func (m *Backend) Test(h restic.Handle) (bool, error) { func (m *Backend) Test(ctx context.Context, h restic.Handle) (bool, error) {
if m.TestFn == nil { if m.TestFn == nil {
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
return m.TestFn(h) return m.TestFn(ctx, h)
} }
// Delete all data. // Delete all data.
func (m *Backend) Delete() error { func (m *Backend) Delete(ctx context.Context) error {
if m.DeleteFn == nil { if m.DeleteFn == nil {
return errors.New("not implemented") return errors.New("not implemented")
} }
return m.DeleteFn() return m.DeleteFn(ctx)
} }
// Make sure that Backend implements the backend interface. // Make sure that Backend implements the backend interface.

View File

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -116,7 +117,7 @@ func (node Node) GetExtendedAttribute(a string) []byte {
} }
// CreateAt creates the node at the given path and restores all the meta data. // CreateAt creates the node at the given path and restores all the meta data.
func (node *Node) CreateAt(path string, repo Repository, idx *HardlinkIndex) error { func (node *Node) CreateAt(ctx context.Context, path string, repo Repository, idx *HardlinkIndex) error {
debug.Log("create node %v at %v", node.Name, path) debug.Log("create node %v at %v", node.Name, path)
switch node.Type { switch node.Type {
@ -125,7 +126,7 @@ func (node *Node) CreateAt(path string, repo Repository, idx *HardlinkIndex) err
return err return err
} }
case "file": case "file":
if err := node.createFileAt(path, repo, idx); err != nil { if err := node.createFileAt(ctx, path, repo, idx); err != nil {
return err return err
} }
case "symlink": case "symlink":
@ -228,7 +229,7 @@ func (node Node) createDirAt(path string) error {
return nil return nil
} }
func (node Node) createFileAt(path string, repo Repository, idx *HardlinkIndex) error { func (node Node) createFileAt(ctx context.Context, path string, repo Repository, idx *HardlinkIndex) error {
if node.Links > 1 && idx.Has(node.Inode, node.DeviceID) { if node.Links > 1 && idx.Has(node.Inode, node.DeviceID) {
if err := fs.Remove(path); !os.IsNotExist(err) { if err := fs.Remove(path); !os.IsNotExist(err) {
return errors.Wrap(err, "RemoveCreateHardlink") return errors.Wrap(err, "RemoveCreateHardlink")
@ -259,7 +260,7 @@ func (node Node) createFileAt(path string, repo Repository, idx *HardlinkIndex)
buf = NewBlobBuffer(int(size)) buf = NewBlobBuffer(int(size))
} }
n, err := repo.LoadBlob(DataBlob, id, buf) n, err := repo.LoadBlob(ctx, DataBlob, id, buf)
if err != nil { if err != nil {
return err return err
} }

View File

@ -2,6 +2,7 @@ package pack_test
import ( import (
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
@ -126,7 +127,7 @@ func TestUnpackReadSeeker(t *testing.T) {
id := restic.Hash(packData) id := restic.Hash(packData)
handle := restic.Handle{Type: restic.DataFile, Name: id.String()} handle := restic.Handle{Type: restic.DataFile, Name: id.String()}
OK(t, b.Save(handle, bytes.NewReader(packData))) OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData)))
verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize) verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize)
} }
@ -139,6 +140,6 @@ func TestShortPack(t *testing.T) {
id := restic.Hash(packData) id := restic.Hash(packData)
handle := restic.Handle{Type: restic.DataFile, Name: id.String()} handle := restic.Handle{Type: restic.DataFile, Name: id.String()}
OK(t, b.Save(handle, bytes.NewReader(packData))) OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData)))
verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize) verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize)
} }

View File

@ -1,6 +1,7 @@
package pipe package pipe
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -78,7 +79,7 @@ func readDirNames(dirname string) ([]string, error) {
// dirs). If false is returned, files are ignored and dirs are not even walked. // dirs). If false is returned, files are ignored and dirs are not even walked.
type SelectFunc func(item string, fi os.FileInfo) bool type SelectFunc func(item string, fi os.FileInfo) bool
func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs chan<- Job, res chan<- Result) (excluded bool) { func walk(ctx context.Context, basedir, dir string, selectFunc SelectFunc, jobs chan<- Job, res chan<- Result) (excluded bool) {
debug.Log("start on %q, basedir %q", dir, basedir) debug.Log("start on %q, basedir %q", dir, basedir)
relpath, err := filepath.Rel(basedir, dir) relpath, err := filepath.Rel(basedir, dir)
@ -92,7 +93,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs
debug.Log("error for %v: %v, res %p", dir, err, res) debug.Log("error for %v: %v, res %p", dir, err, res)
select { select {
case jobs <- Dir{basedir: basedir, path: relpath, info: info, error: err, result: res}: case jobs <- Dir{basedir: basedir, path: relpath, info: info, error: err, result: res}:
case <-done: case <-ctx.Done():
} }
return return
} }
@ -107,7 +108,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs
debug.Log("sending file job for %v, res %p", dir, res) debug.Log("sending file job for %v, res %p", dir, res)
select { select {
case jobs <- Entry{info: info, basedir: basedir, path: relpath, result: res}: case jobs <- Entry{info: info, basedir: basedir, path: relpath, result: res}:
case <-done: case <-ctx.Done():
} }
return return
} }
@ -117,7 +118,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs
if err != nil { if err != nil {
debug.Log("Readdirnames(%v) returned error: %v, res %p", dir, err, res) debug.Log("Readdirnames(%v) returned error: %v, res %p", dir, err, res)
select { select {
case <-done: case <-ctx.Done():
case jobs <- Dir{basedir: basedir, path: relpath, info: info, error: err, result: res}: case jobs <- Dir{basedir: basedir, path: relpath, info: info, error: err, result: res}:
} }
return return
@ -146,7 +147,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs
debug.Log("sending file job for %v, err %v, res %p", subpath, err, res) debug.Log("sending file job for %v, err %v, res %p", subpath, err, res)
select { select {
case jobs <- Entry{info: fi, error: statErr, basedir: basedir, path: filepath.Join(relpath, name), result: ch}: case jobs <- Entry{info: fi, error: statErr, basedir: basedir, path: filepath.Join(relpath, name), result: ch}:
case <-done: case <-ctx.Done():
return return
} }
continue continue
@ -156,13 +157,13 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs
// between walk and open // between walk and open
debug.RunHook("pipe.walk2", filepath.Join(relpath, name)) debug.RunHook("pipe.walk2", filepath.Join(relpath, name))
walk(basedir, subpath, selectFunc, done, jobs, ch) walk(ctx, basedir, subpath, selectFunc, jobs, ch)
} }
debug.Log("sending dirjob for %q, basedir %q, res %p", dir, basedir, res) debug.Log("sending dirjob for %q, basedir %q, res %p", dir, basedir, res)
select { select {
case jobs <- Dir{basedir: basedir, path: relpath, info: info, Entries: entries, result: res}: case jobs <- Dir{basedir: basedir, path: relpath, info: info, Entries: entries, result: res}:
case <-done: case <-ctx.Done():
} }
return return
@ -191,7 +192,7 @@ func cleanupPath(path string) ([]string, error) {
// Walk sends a Job for each file and directory it finds below the paths. When // Walk sends a Job for each file and directory it finds below the paths. When
// the channel done is closed, processing stops. // the channel done is closed, processing stops.
func Walk(walkPaths []string, selectFunc SelectFunc, done chan struct{}, jobs chan<- Job, res chan<- Result) { func Walk(ctx context.Context, walkPaths []string, selectFunc SelectFunc, jobs chan<- Job, res chan<- Result) {
var paths []string var paths []string
for _, p := range walkPaths { for _, p := range walkPaths {
@ -215,7 +216,7 @@ func Walk(walkPaths []string, selectFunc SelectFunc, done chan struct{}, jobs ch
for _, path := range paths { for _, path := range paths {
debug.Log("start walker for %v", path) debug.Log("start walker for %v", path)
ch := make(chan Result, 1) ch := make(chan Result, 1)
excluded := walk(filepath.Dir(path), path, selectFunc, done, jobs, ch) excluded := walk(ctx, filepath.Dir(path), path, selectFunc, jobs, ch)
if excluded { if excluded {
debug.Log("walker for %v done, it was excluded by the filter", path) debug.Log("walker for %v done, it was excluded by the filter", path)
@ -228,7 +229,7 @@ func Walk(walkPaths []string, selectFunc SelectFunc, done chan struct{}, jobs ch
debug.Log("sending root node, res %p", res) debug.Log("sending root node, res %p", res)
select { select {
case <-done: case <-ctx.Done():
return return
case jobs <- Dir{Entries: entries, result: res}: case jobs <- Dir{Entries: entries, result: res}:
} }

View File

@ -12,7 +12,7 @@ type backendReaderAt struct {
} }
func (brd backendReaderAt) ReadAt(p []byte, offset int64) (n int, err error) { func (brd backendReaderAt) ReadAt(p []byte, offset int64) (n int, err error) {
return ReadAt(brd.be, brd.h, offset, p) return ReadAt(context.TODO(), brd.be, brd.h, offset, p)
} }
// ReaderAt returns an io.ReaderAt for a file in the backend. // ReaderAt returns an io.ReaderAt for a file in the backend.
@ -21,9 +21,9 @@ func ReaderAt(be Backend, h Handle) io.ReaderAt {
} }
// ReadAt reads from the backend handle h at the given position. // ReadAt reads from the backend handle h at the given position.
func ReadAt(be Backend, h Handle, offset int64, p []byte) (n int, err error) { func ReadAt(ctx context.Context, be Backend, h Handle, offset int64, p []byte) (n int, err error) {
debug.Log("ReadAt(%v) at %v, len %v", h, offset, len(p)) debug.Log("ReadAt(%v) at %v, len %v", h, offset, len(p))
rd, err := be.Load(context.TODO(), h, len(p), offset) rd, err := be.Load(ctx, h, len(p), offset)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -17,30 +17,30 @@ type Repository interface {
SetIndex(Index) SetIndex(Index)
Index() Index Index() Index
SaveFullIndex() error SaveFullIndex(context.Context) error
SaveIndex() error SaveIndex(context.Context) error
LoadIndex() error LoadIndex(context.Context) error
Config() Config Config() Config
LookupBlobSize(ID, BlobType) (uint, error) LookupBlobSize(ID, BlobType) (uint, error)
List(FileType, <-chan struct{}) <-chan ID List(context.Context, FileType) <-chan ID
ListPack(ID) ([]Blob, int64, error) ListPack(context.Context, ID) ([]Blob, int64, error)
Flush() error Flush() error
SaveUnpacked(FileType, []byte) (ID, error) SaveUnpacked(context.Context, FileType, []byte) (ID, error)
SaveJSONUnpacked(FileType, interface{}) (ID, error) SaveJSONUnpacked(context.Context, FileType, interface{}) (ID, error)
LoadJSONUnpacked(FileType, ID, interface{}) error LoadJSONUnpacked(context.Context, FileType, ID, interface{}) error
LoadAndDecrypt(FileType, ID) ([]byte, error) LoadAndDecrypt(context.Context, FileType, ID) ([]byte, error)
LoadBlob(BlobType, ID, []byte) (int, error) LoadBlob(context.Context, BlobType, ID, []byte) (int, error)
SaveBlob(BlobType, []byte, ID) (ID, error) SaveBlob(context.Context, BlobType, []byte, ID) (ID, error)
LoadTree(ID) (*Tree, error) LoadTree(context.Context, ID) (*Tree, error)
SaveTree(t *Tree) (ID, error) SaveTree(context.Context, *Tree) (ID, error)
} }
// Deleter removes all data stored in a backend/repo. // Deleter removes all data stored in a backend/repo.

View File

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
"restic" "restic"
@ -519,10 +520,10 @@ func DecodeOldIndex(buf []byte) (idx *Index, err error) {
} }
// LoadIndexWithDecoder loads the index and decodes it with fn. // LoadIndexWithDecoder loads the index and decodes it with fn.
func LoadIndexWithDecoder(repo restic.Repository, id restic.ID, fn func([]byte) (*Index, error)) (idx *Index, err error) { func LoadIndexWithDecoder(ctx context.Context, repo restic.Repository, id restic.ID, fn func([]byte) (*Index, error)) (idx *Index, err error) {
debug.Log("Loading index %v", id.Str()) debug.Log("Loading index %v", id.Str())
buf, err := repo.LoadAndDecrypt(restic.IndexFile, id) buf, err := repo.LoadAndDecrypt(ctx, restic.IndexFile, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -2,6 +2,7 @@ package repository
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -58,12 +59,12 @@ var (
// createMasterKey creates a new master key in the given backend and encrypts // createMasterKey creates a new master key in the given backend and encrypts
// it with the password. // it with the password.
func createMasterKey(s *Repository, password string) (*Key, error) { func createMasterKey(s *Repository, password string) (*Key, error) {
return AddKey(s, password, nil) return AddKey(context.TODO(), s, password, nil)
} }
// OpenKey tries do decrypt the key specified by name with the given password. // OpenKey tries do decrypt the key specified by name with the given password.
func OpenKey(s *Repository, name string, password string) (*Key, error) { func OpenKey(ctx context.Context, s *Repository, name string, password string) (*Key, error) {
k, err := LoadKey(s, name) k, err := LoadKey(ctx, s, name)
if err != nil { if err != nil {
debug.Log("LoadKey(%v) returned error %v", name[:12], err) debug.Log("LoadKey(%v) returned error %v", name[:12], err)
return nil, err return nil, err
@ -113,19 +114,17 @@ func OpenKey(s *Repository, name string, password string) (*Key, error) {
// given password. If none could be found, ErrNoKeyFound is returned. When // given password. If none could be found, ErrNoKeyFound is returned. When
// maxKeys is reached, ErrMaxKeysReached is returned. When setting maxKeys to // maxKeys is reached, ErrMaxKeysReached is returned. When setting maxKeys to
// zero, all keys in the repo are checked. // zero, all keys in the repo are checked.
func SearchKey(s *Repository, password string, maxKeys int) (*Key, error) { func SearchKey(ctx context.Context, s *Repository, password string, maxKeys int) (*Key, error) {
checked := 0 checked := 0
// try at most maxKeysForSearch keys in repo // try at most maxKeysForSearch keys in repo
done := make(chan struct{}) for name := range s.Backend().List(ctx, restic.KeyFile) {
defer close(done)
for name := range s.Backend().List(restic.KeyFile, done) {
if maxKeys > 0 && checked > maxKeys { if maxKeys > 0 && checked > maxKeys {
return nil, ErrMaxKeysReached return nil, ErrMaxKeysReached
} }
debug.Log("trying key %v", name[:12]) debug.Log("trying key %v", name[:12])
key, err := OpenKey(s, name, password) key, err := OpenKey(ctx, s, name, password)
if err != nil { if err != nil {
debug.Log("key %v returned error %v", name[:12], err) debug.Log("key %v returned error %v", name[:12], err)
@ -145,9 +144,9 @@ func SearchKey(s *Repository, password string, maxKeys int) (*Key, error) {
} }
// LoadKey loads a key from the backend. // LoadKey loads a key from the backend.
func LoadKey(s *Repository, name string) (k *Key, err error) { func LoadKey(ctx context.Context, s *Repository, name string) (k *Key, err error) {
h := restic.Handle{Type: restic.KeyFile, Name: name} h := restic.Handle{Type: restic.KeyFile, Name: name}
data, err := backend.LoadAll(s.be, h) data, err := backend.LoadAll(ctx, s.be, h)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -162,7 +161,7 @@ func LoadKey(s *Repository, name string) (k *Key, err error) {
} }
// AddKey adds a new key to an already existing repository. // AddKey adds a new key to an already existing repository.
func AddKey(s *Repository, password string, template *crypto.Key) (*Key, error) { func AddKey(ctx context.Context, s *Repository, password string, template *crypto.Key) (*Key, error) {
// make sure we have valid KDF parameters // make sure we have valid KDF parameters
if KDFParams == nil { if KDFParams == nil {
p, err := crypto.Calibrate(KDFTimeout, KDFMemory) p, err := crypto.Calibrate(KDFTimeout, KDFMemory)
@ -233,7 +232,7 @@ func AddKey(s *Repository, password string, template *crypto.Key) (*Key, error)
Name: restic.Hash(buf).String(), Name: restic.Hash(buf).String(),
} }
err = s.be.Save(h, bytes.NewReader(buf)) err = s.be.Save(ctx, h, bytes.NewReader(buf))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"io" "io"
"os" "os"
@ -18,7 +19,7 @@ import (
// Saver implements saving data in a backend. // Saver implements saving data in a backend.
type Saver interface { type Saver interface {
Save(restic.Handle, io.Reader) error Save(context.Context, restic.Handle, io.Reader) error
} }
// Packer holds a pack.Packer together with a hash writer. // Packer holds a pack.Packer together with a hash writer.
@ -118,7 +119,7 @@ func (r *Repository) savePacker(p *Packer) error {
id := restic.IDFromHash(p.hw.Sum(nil)) id := restic.IDFromHash(p.hw.Sum(nil))
h := restic.Handle{Type: restic.DataFile, Name: id.String()} h := restic.Handle{Type: restic.DataFile, Name: id.String()}
err = r.be.Save(h, p.tmpfile) err = r.be.Save(context.TODO(), h, p.tmpfile)
if err != nil { if err != nil {
debug.Log("Save(%v) error: %v", h, err) debug.Log("Save(%v) error: %v", h, err)
return err return err

View File

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"restic" "restic"
"sync" "sync"
@ -18,24 +19,19 @@ func closeIfOpen(ch chan struct{}) {
} }
// ParallelWorkFunc gets one file ID to work on. If an error is returned, // ParallelWorkFunc gets one file ID to work on. If an error is returned,
// processing stops. If done is closed, the function should return. // processing stops. When the contect is cancelled the function should return.
type ParallelWorkFunc func(id string, done <-chan struct{}) error type ParallelWorkFunc func(ctx context.Context, id string) error
// ParallelIDWorkFunc gets one restic.ID to work on. If an error is returned, // ParallelIDWorkFunc gets one restic.ID to work on. If an error is returned,
// processing stops. If done is closed, the function should return. // processing stops. When the context is cancelled the function should return.
type ParallelIDWorkFunc func(id restic.ID, done <-chan struct{}) error type ParallelIDWorkFunc func(ctx context.Context, id restic.ID) error
// FilesInParallel runs n workers of f in parallel, on the IDs that // FilesInParallel runs n workers of f in parallel, on the IDs that
// repo.List(t) yield. If f returns an error, the process is aborted and the // repo.List(t) yield. If f returns an error, the process is aborted and the
// first error is returned. // first error is returned.
func FilesInParallel(repo restic.Lister, t restic.FileType, n uint, f ParallelWorkFunc) error { func FilesInParallel(ctx context.Context, repo restic.Lister, t restic.FileType, n uint, f ParallelWorkFunc) error {
done := make(chan struct{})
defer closeIfOpen(done)
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
ch := repo.List(ctx, t)
ch := repo.List(t, done)
errors := make(chan error, n) errors := make(chan error, n)
for i := 0; uint(i) < n; i++ { for i := 0; uint(i) < n; i++ {
@ -50,13 +46,12 @@ func FilesInParallel(repo restic.Lister, t restic.FileType, n uint, f ParallelWo
return return
} }
err := f(id, done) err := f(ctx, id)
if err != nil { if err != nil {
closeIfOpen(done)
errors <- err errors <- err
return return
} }
case <-done: case <-ctx.Done():
return return
} }
} }
@ -79,13 +74,13 @@ func FilesInParallel(repo restic.Lister, t restic.FileType, n uint, f ParallelWo
// function that takes a string. Filenames that do not parse as a restic.ID // function that takes a string. Filenames that do not parse as a restic.ID
// are ignored. // are ignored.
func ParallelWorkFuncParseID(f ParallelIDWorkFunc) ParallelWorkFunc { func ParallelWorkFuncParseID(f ParallelIDWorkFunc) ParallelWorkFunc {
return func(s string, done <-chan struct{}) error { return func(ctx context.Context, s string) error {
id, err := restic.ParseID(s) id, err := restic.ParseID(s)
if err != nil { if err != nil {
debug.Log("invalid ID %q: %v", id, err) debug.Log("invalid ID %q: %v", id, err)
return err return err
} }
return f(id, done) return f(ctx, id)
} }
} }

View File

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"io" "io"
"restic" "restic"
@ -17,7 +18,7 @@ import (
// these packs. Each pack is loaded and the blobs listed in keepBlobs is saved // these packs. Each pack is loaded and the blobs listed in keepBlobs is saved
// into a new pack. Afterwards, the packs are removed. This operation requires // into a new pack. Afterwards, the packs are removed. This operation requires
// an exclusive lock on the repo. // an exclusive lock on the repo.
func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet, p *restic.Progress) (err error) { func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet, p *restic.Progress) (err error) {
debug.Log("repacking %d packs while keeping %d blobs", len(packs), len(keepBlobs)) debug.Log("repacking %d packs while keeping %d blobs", len(packs), len(keepBlobs))
for packID := range packs { for packID := range packs {
@ -29,7 +30,7 @@ func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet
return errors.Wrap(err, "TempFile") return errors.Wrap(err, "TempFile")
} }
beRd, err := repo.Backend().Load(h, 0, 0) beRd, err := repo.Backend().Load(ctx, h, 0, 0)
if err != nil { if err != nil {
return err return err
} }
@ -100,7 +101,7 @@ func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet
h, tempfile.Name(), id) h, tempfile.Name(), id)
} }
_, err = repo.SaveBlob(entry.Type, buf, entry.ID) _, err = repo.SaveBlob(ctx, entry.Type, buf, entry.ID)
if err != nil { if err != nil {
return err return err
} }
@ -128,7 +129,7 @@ func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet
for packID := range packs { for packID := range packs {
h := restic.Handle{Type: restic.DataFile, Name: packID.String()} h := restic.Handle{Type: restic.DataFile, Name: packID.String()}
err := repo.Backend().Remove(h) err := repo.Backend().Remove(ctx, h)
if err != nil { if err != nil {
debug.Log("error removing pack %v: %v", packID.Str(), err) debug.Log("error removing pack %v: %v", packID.Str(), err)
return err return err

View File

@ -2,6 +2,7 @@ package repository
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -50,11 +51,11 @@ func (r *Repository) PrefixLength(t restic.FileType) (int, error) {
// LoadAndDecrypt loads and decrypts data identified by t and id from the // LoadAndDecrypt loads and decrypts data identified by t and id from the
// backend. // backend.
func (r *Repository) LoadAndDecrypt(t restic.FileType, id restic.ID) ([]byte, error) { func (r *Repository) LoadAndDecrypt(ctx context.Context, t restic.FileType, id restic.ID) ([]byte, error) {
debug.Log("load %v with id %v", t, id.Str()) debug.Log("load %v with id %v", t, id.Str())
h := restic.Handle{Type: t, Name: id.String()} h := restic.Handle{Type: t, Name: id.String()}
buf, err := backend.LoadAll(r.be, h) buf, err := backend.LoadAll(ctx, r.be, h)
if err != nil { if err != nil {
debug.Log("error loading %v: %v", h, err) debug.Log("error loading %v: %v", h, err)
return nil, err return nil, err
@ -76,7 +77,7 @@ func (r *Repository) LoadAndDecrypt(t restic.FileType, id restic.ID) ([]byte, er
// loadBlob tries to load and decrypt content identified by t and id from a // loadBlob tries to load and decrypt content identified by t and id from a
// pack from the backend, the result is stored in plaintextBuf, which must be // pack from the backend, the result is stored in plaintextBuf, which must be
// large enough to hold the complete blob. // large enough to hold the complete blob.
func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []byte) (int, error) { func (r *Repository) loadBlob(ctx context.Context, id restic.ID, t restic.BlobType, plaintextBuf []byte) (int, error) {
debug.Log("load %v with id %v (buf len %v, cap %d)", t, id.Str(), len(plaintextBuf), cap(plaintextBuf)) debug.Log("load %v with id %v (buf len %v, cap %d)", t, id.Str(), len(plaintextBuf), cap(plaintextBuf))
// lookup packs // lookup packs
@ -103,7 +104,7 @@ func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []by
plaintextBuf = plaintextBuf[:blob.Length] plaintextBuf = plaintextBuf[:blob.Length]
n, err := restic.ReadAt(r.be, h, int64(blob.Offset), plaintextBuf) n, err := restic.ReadAt(ctx, r.be, h, int64(blob.Offset), plaintextBuf)
if err != nil { if err != nil {
debug.Log("error loading blob %v: %v", blob, err) debug.Log("error loading blob %v: %v", blob, err)
lastError = err lastError = err
@ -143,8 +144,8 @@ func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []by
// LoadJSONUnpacked decrypts the data and afterwards calls json.Unmarshal on // LoadJSONUnpacked decrypts the data and afterwards calls json.Unmarshal on
// the item. // the item.
func (r *Repository) LoadJSONUnpacked(t restic.FileType, id restic.ID, item interface{}) (err error) { func (r *Repository) LoadJSONUnpacked(ctx context.Context, t restic.FileType, id restic.ID, item interface{}) (err error) {
buf, err := r.LoadAndDecrypt(t, id) buf, err := r.LoadAndDecrypt(ctx, t, id)
if err != nil { if err != nil {
return err return err
} }
@ -159,7 +160,7 @@ func (r *Repository) LookupBlobSize(id restic.ID, tpe restic.BlobType) (uint, er
// SaveAndEncrypt encrypts data and stores it to the backend as type t. If data // SaveAndEncrypt encrypts data and stores it to the backend as type t. If data
// is small enough, it will be packed together with other small blobs. // is small enough, it will be packed together with other small blobs.
func (r *Repository) SaveAndEncrypt(t restic.BlobType, data []byte, id *restic.ID) (restic.ID, error) { func (r *Repository) SaveAndEncrypt(ctx context.Context, t restic.BlobType, data []byte, id *restic.ID) (restic.ID, error) {
if id == nil { if id == nil {
// compute plaintext hash // compute plaintext hash
hashedID := restic.Hash(data) hashedID := restic.Hash(data)
@ -204,19 +205,19 @@ func (r *Repository) SaveAndEncrypt(t restic.BlobType, data []byte, id *restic.I
// SaveJSONUnpacked serialises item as JSON and encrypts and saves it in the // SaveJSONUnpacked serialises item as JSON and encrypts and saves it in the
// backend as type t, without a pack. It returns the storage hash. // backend as type t, without a pack. It returns the storage hash.
func (r *Repository) SaveJSONUnpacked(t restic.FileType, item interface{}) (restic.ID, error) { func (r *Repository) SaveJSONUnpacked(ctx context.Context, t restic.FileType, item interface{}) (restic.ID, error) {
debug.Log("save new blob %v", t) debug.Log("save new blob %v", t)
plaintext, err := json.Marshal(item) plaintext, err := json.Marshal(item)
if err != nil { if err != nil {
return restic.ID{}, errors.Wrap(err, "json.Marshal") return restic.ID{}, errors.Wrap(err, "json.Marshal")
} }
return r.SaveUnpacked(t, plaintext) return r.SaveUnpacked(ctx, t, plaintext)
} }
// SaveUnpacked encrypts data and stores it in the backend. Returned is the // SaveUnpacked encrypts data and stores it in the backend. Returned is the
// storage hash. // storage hash.
func (r *Repository) SaveUnpacked(t restic.FileType, p []byte) (id restic.ID, err error) { func (r *Repository) SaveUnpacked(ctx context.Context, t restic.FileType, p []byte) (id restic.ID, err error) {
ciphertext := restic.NewBlobBuffer(len(p)) ciphertext := restic.NewBlobBuffer(len(p))
ciphertext, err = r.Encrypt(ciphertext, p) ciphertext, err = r.Encrypt(ciphertext, p)
if err != nil { if err != nil {
@ -226,7 +227,7 @@ func (r *Repository) SaveUnpacked(t restic.FileType, p []byte) (id restic.ID, er
id = restic.Hash(ciphertext) id = restic.Hash(ciphertext)
h := restic.Handle{Type: t, Name: id.String()} h := restic.Handle{Type: t, Name: id.String()}
err = r.be.Save(h, bytes.NewReader(ciphertext)) err = r.be.Save(ctx, h, bytes.NewReader(ciphertext))
if err != nil { if err != nil {
debug.Log("error saving blob %v: %v", h, err) debug.Log("error saving blob %v: %v", h, err)
return restic.ID{}, err return restic.ID{}, err
@ -269,7 +270,7 @@ func (r *Repository) SetIndex(i restic.Index) {
} }
// SaveIndex saves an index in the repository. // SaveIndex saves an index in the repository.
func SaveIndex(repo restic.Repository, index *Index) (restic.ID, error) { func SaveIndex(ctx context.Context, repo restic.Repository, index *Index) (restic.ID, error) {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
err := index.Finalize(buf) err := index.Finalize(buf)
@ -277,15 +278,15 @@ func SaveIndex(repo restic.Repository, index *Index) (restic.ID, error) {
return restic.ID{}, err return restic.ID{}, err
} }
return repo.SaveUnpacked(restic.IndexFile, buf.Bytes()) return repo.SaveUnpacked(ctx, restic.IndexFile, buf.Bytes())
} }
// saveIndex saves all indexes in the backend. // saveIndex saves all indexes in the backend.
func (r *Repository) saveIndex(indexes ...*Index) error { func (r *Repository) saveIndex(ctx context.Context, indexes ...*Index) error {
for i, idx := range indexes { for i, idx := range indexes {
debug.Log("Saving index %d", i) debug.Log("Saving index %d", i)
sid, err := SaveIndex(r, idx) sid, err := SaveIndex(ctx, r, idx)
if err != nil { if err != nil {
return err return err
} }
@ -297,34 +298,34 @@ func (r *Repository) saveIndex(indexes ...*Index) error {
} }
// SaveIndex saves all new indexes in the backend. // SaveIndex saves all new indexes in the backend.
func (r *Repository) SaveIndex() error { func (r *Repository) SaveIndex(ctx context.Context) error {
return r.saveIndex(r.idx.NotFinalIndexes()...) return r.saveIndex(ctx, r.idx.NotFinalIndexes()...)
} }
// SaveFullIndex saves all full indexes in the backend. // SaveFullIndex saves all full indexes in the backend.
func (r *Repository) SaveFullIndex() error { func (r *Repository) SaveFullIndex(ctx context.Context) error {
return r.saveIndex(r.idx.FullIndexes()...) return r.saveIndex(ctx, r.idx.FullIndexes()...)
} }
const loadIndexParallelism = 20 const loadIndexParallelism = 20
// LoadIndex loads all index files from the backend in parallel and stores them // LoadIndex loads all index files from the backend in parallel and stores them
// in the master index. The first error that occurred is returned. // in the master index. The first error that occurred is returned.
func (r *Repository) LoadIndex() error { func (r *Repository) LoadIndex(ctx context.Context) error {
debug.Log("Loading index") debug.Log("Loading index")
errCh := make(chan error, 1) errCh := make(chan error, 1)
indexes := make(chan *Index) indexes := make(chan *Index)
worker := func(id restic.ID, done <-chan struct{}) error { worker := func(ctx context.Context, id restic.ID) error {
idx, err := LoadIndex(r, id) idx, err := LoadIndex(ctx, r, id)
if err != nil { if err != nil {
return err return err
} }
select { select {
case indexes <- idx: case indexes <- idx:
case <-done: case <-ctx.Done():
} }
return nil return nil
@ -332,7 +333,7 @@ func (r *Repository) LoadIndex() error {
go func() { go func() {
defer close(indexes) defer close(indexes)
errCh <- FilesInParallel(r.be, restic.IndexFile, loadIndexParallelism, errCh <- FilesInParallel(ctx, r.be, restic.IndexFile, loadIndexParallelism,
ParallelWorkFuncParseID(worker)) ParallelWorkFuncParseID(worker))
}() }()
@ -348,15 +349,15 @@ func (r *Repository) LoadIndex() error {
} }
// LoadIndex loads the index id from backend and returns it. // LoadIndex loads the index id from backend and returns it.
func LoadIndex(repo restic.Repository, id restic.ID) (*Index, error) { func LoadIndex(ctx context.Context, repo restic.Repository, id restic.ID) (*Index, error) {
idx, err := LoadIndexWithDecoder(repo, id, DecodeIndex) idx, err := LoadIndexWithDecoder(ctx, repo, id, DecodeIndex)
if err == nil { if err == nil {
return idx, nil return idx, nil
} }
if errors.Cause(err) == ErrOldIndexFormat { if errors.Cause(err) == ErrOldIndexFormat {
fmt.Fprintf(os.Stderr, "index %v has old format\n", id.Str()) fmt.Fprintf(os.Stderr, "index %v has old format\n", id.Str())
return LoadIndexWithDecoder(repo, id, DecodeOldIndex) return LoadIndexWithDecoder(ctx, repo, id, DecodeOldIndex)
} }
return nil, err return nil, err
@ -364,8 +365,8 @@ func LoadIndex(repo restic.Repository, id restic.ID) (*Index, error) {
// SearchKey finds a key with the supplied password, afterwards the config is // SearchKey finds a key with the supplied password, afterwards the config is
// read and parsed. It tries at most maxKeys key files in the repo. // read and parsed. It tries at most maxKeys key files in the repo.
func (r *Repository) SearchKey(password string, maxKeys int) error { func (r *Repository) SearchKey(ctx context.Context, password string, maxKeys int) error {
key, err := SearchKey(r, password, maxKeys) key, err := SearchKey(ctx, r, password, maxKeys)
if err != nil { if err != nil {
return err return err
} }
@ -373,14 +374,14 @@ func (r *Repository) SearchKey(password string, maxKeys int) error {
r.key = key.master r.key = key.master
r.packerManager.key = key.master r.packerManager.key = key.master
r.keyName = key.Name() r.keyName = key.Name()
r.cfg, err = restic.LoadConfig(r) r.cfg, err = restic.LoadConfig(ctx, r)
return err return err
} }
// Init creates a new master key with the supplied password, initializes and // Init creates a new master key with the supplied password, initializes and
// saves the repository config. // saves the repository config.
func (r *Repository) Init(password string) error { func (r *Repository) Init(ctx context.Context, password string) error {
has, err := r.be.Test(restic.Handle{Type: restic.ConfigFile}) has, err := r.be.Test(ctx, restic.Handle{Type: restic.ConfigFile})
if err != nil { if err != nil {
return err return err
} }
@ -393,12 +394,12 @@ func (r *Repository) Init(password string) error {
return err return err
} }
return r.init(password, cfg) return r.init(ctx, password, cfg)
} }
// init creates a new master key with the supplied password and uses it to save // init creates a new master key with the supplied password and uses it to save
// the config into the repo. // the config into the repo.
func (r *Repository) init(password string, cfg restic.Config) error { func (r *Repository) init(ctx context.Context, password string, cfg restic.Config) error {
key, err := createMasterKey(r, password) key, err := createMasterKey(r, password)
if err != nil { if err != nil {
return err return err
@ -408,7 +409,7 @@ func (r *Repository) init(password string, cfg restic.Config) error {
r.packerManager.key = key.master r.packerManager.key = key.master
r.keyName = key.Name() r.keyName = key.Name()
r.cfg = cfg r.cfg = cfg
_, err = r.SaveJSONUnpacked(restic.ConfigFile, cfg) _, err = r.SaveJSONUnpacked(ctx, restic.ConfigFile, cfg)
return err return err
} }
@ -443,15 +444,15 @@ func (r *Repository) KeyName() string {
} }
// List returns a channel that yields all IDs of type t in the backend. // List returns a channel that yields all IDs of type t in the backend.
func (r *Repository) List(t restic.FileType, done <-chan struct{}) <-chan restic.ID { func (r *Repository) List(ctx context.Context, t restic.FileType) <-chan restic.ID {
out := make(chan restic.ID) out := make(chan restic.ID)
go func() { go func() {
defer close(out) defer close(out)
for strID := range r.be.List(t, done) { for strID := range r.be.List(ctx, t) {
if id, err := restic.ParseID(strID); err == nil { if id, err := restic.ParseID(strID); err == nil {
select { select {
case out <- id: case out <- id:
case <-done: case <-ctx.Done():
return return
} }
} }
@ -462,10 +463,10 @@ func (r *Repository) List(t restic.FileType, done <-chan struct{}) <-chan restic
// ListPack returns the list of blobs saved in the pack id and the length of // ListPack returns the list of blobs saved in the pack id and the length of
// the file as stored in the backend. // the file as stored in the backend.
func (r *Repository) ListPack(id restic.ID) ([]restic.Blob, int64, error) { func (r *Repository) ListPack(ctx context.Context, id restic.ID) ([]restic.Blob, int64, error) {
h := restic.Handle{Type: restic.DataFile, Name: id.String()} h := restic.Handle{Type: restic.DataFile, Name: id.String()}
blobInfo, err := r.Backend().Stat(h) blobInfo, err := r.Backend().Stat(ctx, h)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -480,9 +481,9 @@ func (r *Repository) ListPack(id restic.ID) ([]restic.Blob, int64, error) {
// Delete calls backend.Delete() if implemented, and returns an error // Delete calls backend.Delete() if implemented, and returns an error
// otherwise. // otherwise.
func (r *Repository) Delete() error { func (r *Repository) Delete(ctx context.Context) error {
if b, ok := r.be.(restic.Deleter); ok { if b, ok := r.be.(restic.Deleter); ok {
return b.Delete() return b.Delete(ctx)
} }
return errors.New("Delete() called for backend that does not implement this method") return errors.New("Delete() called for backend that does not implement this method")
@ -496,7 +497,7 @@ func (r *Repository) Close() error {
// LoadBlob loads a blob of type t from the repository to the buffer. buf must // LoadBlob loads a blob of type t from the repository to the buffer. buf must
// be large enough to hold the encrypted blob, since it is used as scratch // be large enough to hold the encrypted blob, since it is used as scratch
// space. // space.
func (r *Repository) LoadBlob(t restic.BlobType, id restic.ID, buf []byte) (int, error) { func (r *Repository) LoadBlob(ctx context.Context, t restic.BlobType, id restic.ID, buf []byte) (int, error) {
debug.Log("load blob %v into buf (len %v, cap %v)", id.Str(), len(buf), cap(buf)) debug.Log("load blob %v into buf (len %v, cap %v)", id.Str(), len(buf), cap(buf))
size, err := r.idx.LookupSize(id, t) size, err := r.idx.LookupSize(id, t)
if err != nil { if err != nil {
@ -507,7 +508,7 @@ func (r *Repository) LoadBlob(t restic.BlobType, id restic.ID, buf []byte) (int,
return 0, errors.Errorf("buffer is too small for data blob (%d < %d)", cap(buf), restic.CiphertextLength(int(size))) return 0, errors.Errorf("buffer is too small for data blob (%d < %d)", cap(buf), restic.CiphertextLength(int(size)))
} }
n, err := r.loadBlob(id, t, buf) n, err := r.loadBlob(ctx, id, t, buf)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -520,16 +521,16 @@ func (r *Repository) LoadBlob(t restic.BlobType, id restic.ID, buf []byte) (int,
// SaveBlob saves a blob of type t into the repository. If id is the null id, it // SaveBlob saves a blob of type t into the repository. If id is the null id, it
// will be computed and returned. // will be computed and returned.
func (r *Repository) SaveBlob(t restic.BlobType, buf []byte, id restic.ID) (restic.ID, error) { func (r *Repository) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID) (restic.ID, error) {
var i *restic.ID var i *restic.ID
if !id.IsNull() { if !id.IsNull() {
i = &id i = &id
} }
return r.SaveAndEncrypt(t, buf, i) return r.SaveAndEncrypt(ctx, t, buf, i)
} }
// LoadTree loads a tree from the repository. // LoadTree loads a tree from the repository.
func (r *Repository) LoadTree(id restic.ID) (*restic.Tree, error) { func (r *Repository) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree, error) {
debug.Log("load tree %v", id.Str()) debug.Log("load tree %v", id.Str())
size, err := r.idx.LookupSize(id, restic.TreeBlob) size, err := r.idx.LookupSize(id, restic.TreeBlob)
@ -540,7 +541,7 @@ func (r *Repository) LoadTree(id restic.ID) (*restic.Tree, error) {
debug.Log("size is %d, create buffer", size) debug.Log("size is %d, create buffer", size)
buf := restic.NewBlobBuffer(int(size)) buf := restic.NewBlobBuffer(int(size))
n, err := r.loadBlob(id, restic.TreeBlob, buf) n, err := r.loadBlob(ctx, id, restic.TreeBlob, buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -558,7 +559,7 @@ func (r *Repository) LoadTree(id restic.ID) (*restic.Tree, error) {
// SaveTree stores a tree into the repository and returns the ID. The ID is // SaveTree stores a tree into the repository and returns the ID. The ID is
// checked against the index. The tree is only stored when the index does not // checked against the index. The tree is only stored when the index does not
// contain the ID. // contain the ID.
func (r *Repository) SaveTree(t *restic.Tree) (restic.ID, error) { func (r *Repository) SaveTree(ctx context.Context, t *restic.Tree) (restic.ID, error) {
buf, err := json.Marshal(t) buf, err := json.Marshal(t)
if err != nil { if err != nil {
return restic.ID{}, errors.Wrap(err, "MarshalJSON") return restic.ID{}, errors.Wrap(err, "MarshalJSON")
@ -573,6 +574,6 @@ func (r *Repository) SaveTree(t *restic.Tree) (restic.ID, error) {
return id, nil return id, nil
} }
_, err = r.SaveBlob(restic.TreeBlob, buf, id) _, err = r.SaveBlob(ctx, restic.TreeBlob, buf, id)
return id, err return id, err
} }

View File

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"os" "os"
"restic" "restic"
"restic/backend/local" "restic/backend/local"
@ -50,7 +51,7 @@ func TestRepositoryWithBackend(t testing.TB, be restic.Backend) (r restic.Reposi
repo := New(be) repo := New(be)
cfg := restic.TestCreateConfig(t, testChunkerPol) cfg := restic.TestCreateConfig(t, testChunkerPol)
err := repo.init(test.TestPassword, cfg) err := repo.init(context.TODO(), test.TestPassword, cfg)
if err != nil { if err != nil {
t.Fatalf("TestRepository(): initialize repo failed: %v", err) t.Fatalf("TestRepository(): initialize repo failed: %v", err)
} }
@ -94,7 +95,7 @@ func TestOpenLocal(t testing.TB, dir string) (r restic.Repository) {
} }
repo := New(be) repo := New(be)
err = repo.SearchKey(test.TestPassword, 10) err = repo.SearchKey(context.TODO(), test.TestPassword, 10)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"os" "os"
"path/filepath" "path/filepath"
@ -30,7 +31,7 @@ func NewRestorer(repo Repository, id ID) (*Restorer, error) {
var err error var err error
r.sn, err = LoadSnapshot(repo, id) r.sn, err = LoadSnapshot(context.TODO(), repo, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -38,8 +39,8 @@ func NewRestorer(repo Repository, id ID) (*Restorer, error) {
return r, nil return r, nil
} }
func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkIndex) error { func (res *Restorer) restoreTo(ctx context.Context, dst string, dir string, treeID ID, idx *HardlinkIndex) error {
tree, err := res.repo.LoadTree(treeID) tree, err := res.repo.LoadTree(ctx, treeID)
if err != nil { if err != nil {
return res.Error(dir, nil, err) return res.Error(dir, nil, err)
} }
@ -50,7 +51,7 @@ func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkI
debug.Log("SelectForRestore returned %v", selectedForRestore) debug.Log("SelectForRestore returned %v", selectedForRestore)
if selectedForRestore { if selectedForRestore {
err := res.restoreNodeTo(node, dir, dst, idx) err := res.restoreNodeTo(ctx, node, dir, dst, idx)
if err != nil { if err != nil {
return err return err
} }
@ -62,7 +63,7 @@ func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkI
} }
subp := filepath.Join(dir, node.Name) subp := filepath.Join(dir, node.Name)
err = res.restoreTo(dst, subp, *node.Subtree, idx) err = res.restoreTo(ctx, dst, subp, *node.Subtree, idx)
if err != nil { if err != nil {
err = res.Error(subp, node, err) err = res.Error(subp, node, err)
if err != nil { if err != nil {
@ -83,11 +84,11 @@ func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkI
return nil return nil
} }
func (res *Restorer) restoreNodeTo(node *Node, dir string, dst string, idx *HardlinkIndex) error { func (res *Restorer) restoreNodeTo(ctx context.Context, node *Node, dir string, dst string, idx *HardlinkIndex) error {
debug.Log("node %v, dir %v, dst %v", node.Name, dir, dst) debug.Log("node %v, dir %v, dst %v", node.Name, dir, dst)
dstPath := filepath.Join(dst, dir, node.Name) dstPath := filepath.Join(dst, dir, node.Name)
err := node.CreateAt(dstPath, res.repo, idx) err := node.CreateAt(ctx, dstPath, res.repo, idx)
if err != nil { if err != nil {
debug.Log("node.CreateAt(%s) error %v", dstPath, err) debug.Log("node.CreateAt(%s) error %v", dstPath, err)
} }
@ -99,7 +100,7 @@ func (res *Restorer) restoreNodeTo(node *Node, dir string, dst string, idx *Hard
// Create parent directories and retry // Create parent directories and retry
err = fs.MkdirAll(filepath.Dir(dstPath), 0700) err = fs.MkdirAll(filepath.Dir(dstPath), 0700)
if err == nil || os.IsExist(errors.Cause(err)) { if err == nil || os.IsExist(errors.Cause(err)) {
err = node.CreateAt(dstPath, res.repo, idx) err = node.CreateAt(ctx, dstPath, res.repo, idx)
} }
} }
@ -118,9 +119,9 @@ func (res *Restorer) restoreNodeTo(node *Node, dir string, dst string, idx *Hard
// RestoreTo creates the directories and files in the snapshot below dst. // RestoreTo creates the directories and files in the snapshot below dst.
// Before an item is created, res.Filter is called. // Before an item is created, res.Filter is called.
func (res *Restorer) RestoreTo(dst string) error { func (res *Restorer) RestoreTo(ctx context.Context, dst string) error {
idx := NewHardlinkIndex() idx := NewHardlinkIndex()
return res.restoreTo(dst, string(filepath.Separator), *res.sn.Tree, idx) return res.restoreTo(ctx, dst, string(filepath.Separator), *res.sn.Tree, idx)
} }
// Snapshot returns the snapshot this restorer is configured to use. // Snapshot returns the snapshot this restorer is configured to use.

View File

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"fmt" "fmt"
"os/user" "os/user"
"path/filepath" "path/filepath"
@ -51,9 +52,9 @@ func NewSnapshot(paths []string, tags []string, hostname string) (*Snapshot, err
} }
// LoadSnapshot loads the snapshot with the id and returns it. // LoadSnapshot loads the snapshot with the id and returns it.
func LoadSnapshot(repo Repository, id ID) (*Snapshot, error) { func LoadSnapshot(ctx context.Context, repo Repository, id ID) (*Snapshot, error) {
sn := &Snapshot{id: &id} sn := &Snapshot{id: &id}
err := repo.LoadJSONUnpacked(SnapshotFile, id, sn) err := repo.LoadJSONUnpacked(ctx, SnapshotFile, id, sn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -62,12 +63,9 @@ func LoadSnapshot(repo Repository, id ID) (*Snapshot, error) {
} }
// LoadAllSnapshots returns a list of all snapshots in the repo. // LoadAllSnapshots returns a list of all snapshots in the repo.
func LoadAllSnapshots(repo Repository) (snapshots []*Snapshot, err error) { func LoadAllSnapshots(ctx context.Context, repo Repository) (snapshots []*Snapshot, err error) {
done := make(chan struct{}) for id := range repo.List(ctx, SnapshotFile) {
defer close(done) sn, err := LoadSnapshot(ctx, repo, id)
for id := range repo.List(SnapshotFile, done) {
sn, err := LoadSnapshot(repo, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -178,15 +176,15 @@ func (sn *Snapshot) SamePaths(paths []string) bool {
var ErrNoSnapshotFound = errors.New("no snapshot found") var ErrNoSnapshotFound = errors.New("no snapshot found")
// FindLatestSnapshot finds latest snapshot with optional target/directory, tags and hostname filters. // FindLatestSnapshot finds latest snapshot with optional target/directory, tags and hostname filters.
func FindLatestSnapshot(repo Repository, targets []string, tags []string, hostname string) (ID, error) { func FindLatestSnapshot(ctx context.Context, repo Repository, targets []string, tags []string, hostname string) (ID, error) {
var ( var (
latest time.Time latest time.Time
latestID ID latestID ID
found bool found bool
) )
for snapshotID := range repo.List(SnapshotFile, make(chan struct{})) { for snapshotID := range repo.List(ctx, SnapshotFile) {
snapshot, err := LoadSnapshot(repo, snapshotID) snapshot, err := LoadSnapshot(ctx, repo, snapshotID)
if err != nil { if err != nil {
return ID{}, errors.Errorf("Error listing snapshot: %v", err) return ID{}, errors.Errorf("Error listing snapshot: %v", err)
} }

View File

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -29,7 +30,7 @@ type fakeFileSystem struct {
// saveFile reads from rd and saves the blobs in the repository. The list of // saveFile reads from rd and saves the blobs in the repository. The list of
// IDs is returned. // IDs is returned.
func (fs *fakeFileSystem) saveFile(rd io.Reader) (blobs IDs) { func (fs *fakeFileSystem) saveFile(ctx context.Context, rd io.Reader) (blobs IDs) {
if fs.buf == nil { if fs.buf == nil {
fs.buf = make([]byte, chunker.MaxSize) fs.buf = make([]byte, chunker.MaxSize)
} }
@ -53,7 +54,7 @@ func (fs *fakeFileSystem) saveFile(rd io.Reader) (blobs IDs) {
id := Hash(chunk.Data) id := Hash(chunk.Data)
if !fs.blobIsKnown(id, DataBlob) { if !fs.blobIsKnown(id, DataBlob) {
_, err := fs.repo.SaveBlob(DataBlob, chunk.Data, id) _, err := fs.repo.SaveBlob(ctx, DataBlob, chunk.Data, id)
if err != nil { if err != nil {
fs.t.Fatalf("error saving chunk: %v", err) fs.t.Fatalf("error saving chunk: %v", err)
} }
@ -103,7 +104,7 @@ func (fs *fakeFileSystem) blobIsKnown(id ID, t BlobType) bool {
} }
// saveTree saves a tree of fake files in the repo and returns the ID. // saveTree saves a tree of fake files in the repo and returns the ID.
func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID { func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) ID {
rnd := rand.NewSource(seed) rnd := rand.NewSource(seed)
numNodes := int(rnd.Int63() % maxNodes) numNodes := int(rnd.Int63() % maxNodes)
@ -113,7 +114,7 @@ func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID {
// randomly select the type of the node, either tree (p = 1/4) or file (p = 3/4). // randomly select the type of the node, either tree (p = 1/4) or file (p = 3/4).
if depth > 1 && rnd.Int63()%4 == 0 { if depth > 1 && rnd.Int63()%4 == 0 {
treeSeed := rnd.Int63() % maxSeed treeSeed := rnd.Int63() % maxSeed
id := fs.saveTree(treeSeed, depth-1) id := fs.saveTree(ctx, treeSeed, depth-1)
node := &Node{ node := &Node{
Name: fmt.Sprintf("dir-%v", treeSeed), Name: fmt.Sprintf("dir-%v", treeSeed),
@ -136,7 +137,7 @@ func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID {
Size: uint64(fileSize), Size: uint64(fileSize),
} }
node.Content = fs.saveFile(fakeFile(fs.t, fileSeed, fileSize)) node.Content = fs.saveFile(ctx, fakeFile(fs.t, fileSeed, fileSize))
tree.Nodes = append(tree.Nodes, node) tree.Nodes = append(tree.Nodes, node)
} }
@ -145,7 +146,7 @@ func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID {
return id return id
} }
_, err := fs.repo.SaveBlob(TreeBlob, buf, id) _, err := fs.repo.SaveBlob(ctx, TreeBlob, buf, id)
if err != nil { if err != nil {
fs.t.Fatal(err) fs.t.Fatal(err)
} }
@ -176,10 +177,10 @@ func TestCreateSnapshot(t testing.TB, repo Repository, at time.Time, depth int,
duplication: duplication, duplication: duplication,
} }
treeID := fs.saveTree(seed, depth) treeID := fs.saveTree(context.TODO(), seed, depth)
snapshot.Tree = &treeID snapshot.Tree = &treeID
id, err := repo.SaveJSONUnpacked(SnapshotFile, snapshot) id, err := repo.SaveJSONUnpacked(context.TODO(), SnapshotFile, snapshot)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -193,7 +194,7 @@ func TestCreateSnapshot(t testing.TB, repo Repository, at time.Time, depth int,
t.Fatal(err) t.Fatal(err)
} }
err = repo.SaveIndex() err = repo.SaveIndex(context.TODO())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -1,6 +1,7 @@
package walk package walk
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -34,7 +35,7 @@ func NewTreeWalker(ch chan<- loadTreeJob, out chan<- TreeJob) *TreeWalker {
// Walk starts walking the tree given by id. When the channel done is closed, // Walk starts walking the tree given by id. When the channel done is closed,
// processing stops. // processing stops.
func (tw *TreeWalker) Walk(path string, id restic.ID, done chan struct{}) { func (tw *TreeWalker) Walk(ctx context.Context, path string, id restic.ID) {
debug.Log("starting on tree %v for %v", id.Str(), path) debug.Log("starting on tree %v for %v", id.Str(), path)
defer debug.Log("done walking tree %v for %v", id.Str(), path) defer debug.Log("done walking tree %v for %v", id.Str(), path)
@ -48,22 +49,22 @@ func (tw *TreeWalker) Walk(path string, id restic.ID, done chan struct{}) {
if res.err != nil { if res.err != nil {
select { select {
case tw.out <- TreeJob{Path: path, Error: res.err}: case tw.out <- TreeJob{Path: path, Error: res.err}:
case <-done: case <-ctx.Done():
return return
} }
return return
} }
tw.walk(path, res.tree, done) tw.walk(ctx, path, res.tree)
select { select {
case tw.out <- TreeJob{Path: path, Tree: res.tree}: case tw.out <- TreeJob{Path: path, Tree: res.tree}:
case <-done: case <-ctx.Done():
return return
} }
} }
func (tw *TreeWalker) walk(path string, tree *restic.Tree, done chan struct{}) { func (tw *TreeWalker) walk(ctx context.Context, path string, tree *restic.Tree) {
debug.Log("start on %q", path) debug.Log("start on %q", path)
defer debug.Log("done for %q", path) defer debug.Log("done for %q", path)
@ -94,7 +95,7 @@ func (tw *TreeWalker) walk(path string, tree *restic.Tree, done chan struct{}) {
res := <-results[i] res := <-results[i]
if res.err == nil { if res.err == nil {
tw.walk(p, res.tree, done) tw.walk(ctx, p, res.tree)
} else { } else {
fmt.Fprintf(os.Stderr, "error loading tree: %v\n", res.err) fmt.Fprintf(os.Stderr, "error loading tree: %v\n", res.err)
} }
@ -106,7 +107,7 @@ func (tw *TreeWalker) walk(path string, tree *restic.Tree, done chan struct{}) {
select { select {
case tw.out <- job: case tw.out <- job:
case <-done: case <-ctx.Done():
return return
} }
} }
@ -124,14 +125,14 @@ type loadTreeJob struct {
type treeLoader func(restic.ID) (*restic.Tree, error) type treeLoader func(restic.ID) (*restic.Tree, error)
func loadTreeWorker(wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader, done <-chan struct{}) { func loadTreeWorker(ctx context.Context, wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader) {
debug.Log("start") debug.Log("start")
defer debug.Log("exit") defer debug.Log("exit")
defer wg.Done() defer wg.Done()
for { for {
select { select {
case <-done: case <-ctx.Done():
debug.Log("done channel closed") debug.Log("done channel closed")
return return
case job, ok := <-in: case job, ok := <-in:
@ -148,7 +149,7 @@ func loadTreeWorker(wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader,
select { select {
case job.res <- loadTreeResult{tree, err}: case job.res <- loadTreeResult{tree, err}:
debug.Log("job result sent") debug.Log("job result sent")
case <-done: case <-ctx.Done():
debug.Log("done channel closed before result could be sent") debug.Log("done channel closed before result could be sent")
return return
} }
@ -158,7 +159,7 @@ func loadTreeWorker(wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader,
// TreeLoader loads tree objects. // TreeLoader loads tree objects.
type TreeLoader interface { type TreeLoader interface {
LoadTree(restic.ID) (*restic.Tree, error) LoadTree(context.Context, restic.ID) (*restic.Tree, error)
} }
const loadTreeWorkers = 10 const loadTreeWorkers = 10
@ -166,11 +167,11 @@ const loadTreeWorkers = 10
// Tree walks the tree specified by id recursively and sends a job for each // Tree walks the tree specified by id recursively and sends a job for each
// file and directory it finds. When the channel done is closed, processing // file and directory it finds. When the channel done is closed, processing
// stops. // stops.
func Tree(repo TreeLoader, id restic.ID, done chan struct{}, jobCh chan<- TreeJob) { func Tree(ctx context.Context, repo TreeLoader, id restic.ID, jobCh chan<- TreeJob) {
debug.Log("start on %v, start workers", id.Str()) debug.Log("start on %v, start workers", id.Str())
load := func(id restic.ID) (*restic.Tree, error) { load := func(id restic.ID) (*restic.Tree, error) {
tree, err := repo.LoadTree(id) tree, err := repo.LoadTree(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -182,11 +183,11 @@ func Tree(repo TreeLoader, id restic.ID, done chan struct{}, jobCh chan<- TreeJo
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < loadTreeWorkers; i++ { for i := 0; i < loadTreeWorkers; i++ {
wg.Add(1) wg.Add(1)
go loadTreeWorker(&wg, ch, load, done) go loadTreeWorker(ctx, &wg, ch, load)
} }
tw := NewTreeWalker(ch, jobCh) tw := NewTreeWalker(ch, jobCh)
tw.Walk("", id, done) tw.Walk(ctx, "", id)
close(jobCh) close(jobCh)
close(ch) close(ch)

View File

@ -1,5 +1,7 @@
package worker package worker
import "context"
// Job is one unit of work. It is given to a Func, and the returned result and // Job is one unit of work. It is given to a Func, and the returned result and
// error are stored in Result and Error. // error are stored in Result and Error.
type Job struct { type Job struct {
@ -9,12 +11,12 @@ type Job struct {
} }
// Func does the actual work within a Pool. // Func does the actual work within a Pool.
type Func func(job Job, done <-chan struct{}) (result interface{}, err error) type Func func(ctx context.Context, job Job) (result interface{}, err error)
// Pool implements a worker pool. // Pool implements a worker pool.
type Pool struct { type Pool struct {
f Func f Func
done chan struct{} ctx context.Context
jobCh <-chan Job jobCh <-chan Job
resCh chan<- Job resCh chan<- Job
@ -25,10 +27,9 @@ type Pool struct {
// New returns a new worker pool with n goroutines, each running the function // New returns a new worker pool with n goroutines, each running the function
// f. The workers are started immediately. // f. The workers are started immediately.
func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool { func New(ctx context.Context, n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool {
p := &Pool{ p := &Pool{
f: f, f: f,
done: make(chan struct{}),
workersExit: make(chan struct{}), workersExit: make(chan struct{}),
allWorkersDone: make(chan struct{}), allWorkersDone: make(chan struct{}),
numWorkers: n, numWorkers: n,
@ -75,7 +76,7 @@ func (p *Pool) runWorker(numWorker int) {
for { for {
select { select {
case <-p.done: case <-p.ctx.Done():
return return
case job, ok = <-inCh: case job, ok = <-inCh:
@ -83,7 +84,7 @@ func (p *Pool) runWorker(numWorker int) {
return return
} }
job.Result, job.Error = p.f(job, p.done) job.Result, job.Error = p.f(p.ctx, job)
inCh = nil inCh = nil
outCh = p.resCh outCh = p.resCh

View File

@ -1,6 +1,7 @@
package worker_test package worker_test
import ( import (
"context"
"testing" "testing"
"restic/errors" "restic/errors"
@ -12,7 +13,7 @@ const concurrency = 10
var errTooLarge = errors.New("too large") var errTooLarge = errors.New("too large")
func square(job worker.Job, done <-chan struct{}) (interface{}, error) { func square(ctx context.Context, job worker.Job) (interface{}, error) {
n := job.Data.(int) n := job.Data.(int)
if n > 2000 { if n > 2000 {
return nil, errTooLarge return nil, errTooLarge
@ -20,15 +21,15 @@ func square(job worker.Job, done <-chan struct{}) (interface{}, error) {
return n * n, nil return n * n, nil
} }
func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) { func newBufferedPool(ctx context.Context, bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) {
inCh := make(chan worker.Job, bufsize) inCh := make(chan worker.Job, bufsize)
outCh := make(chan worker.Job, bufsize) outCh := make(chan worker.Job, bufsize)
return inCh, outCh, worker.New(n, f, inCh, outCh) return inCh, outCh, worker.New(ctx, n, f, inCh, outCh)
} }
func TestPool(t *testing.T) { func TestPool(t *testing.T) {
inCh, outCh, p := newBufferedPool(200, concurrency, square) inCh, outCh, p := newBufferedPool(context.TODO(), 200, concurrency, square)
for i := 0; i < 150; i++ { for i := 0; i < 150; i++ {
inCh <- worker.Job{Data: i} inCh <- worker.Job{Data: i}
@ -53,7 +54,7 @@ func TestPool(t *testing.T) {
} }
func TestPoolErrors(t *testing.T) { func TestPoolErrors(t *testing.T) {
inCh, outCh, p := newBufferedPool(200, concurrency, square) inCh, outCh, p := newBufferedPool(context.TODO(), 200, concurrency, square)
for i := 0; i < 150; i++ { for i := 0; i < 150; i++ {
inCh <- worker.Job{Data: i + 1900} inCh <- worker.Job{Data: i + 1900}