2
2
mirror of https://github.com/octoleo/restic.git synced 2024-11-11 15:51:02 +00:00

Merge pull request #3106 from MichaelEischer/parallel-tree-walk

Parallelize tree walk in prune and copy and add progress bar to check
This commit is contained in:
Alexander Neumann 2021-01-28 12:06:42 +01:00 committed by GitHub
commit 72eec8c0c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 437 additions and 322 deletions

View File

@ -0,0 +1,10 @@
Enhancement: Parallelize scan of snapshot content in copy and prune
The copy and the prune commands used to traverse the directories of
snapshots one by one to find used data. This snapshot traversal is
now parallized which can speed up this step several times.
In addition the check command now reports how many snapshots have
already been processed.
https://github.com/restic/restic/pull/3106

View File

@ -240,7 +240,11 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error {
Verbosef("check snapshots, trees and blobs\n") Verbosef("check snapshots, trees and blobs\n")
errChan = make(chan error) errChan = make(chan error)
go chkr.Structure(gopts.ctx, errChan) go func() {
bar := newProgressMax(!gopts.Quiet, 0, "snapshots")
defer bar.Done()
chkr.Structure(gopts.ctx, bar, errChan)
}()
for err := range errChan { for err := range errChan {
errorsFound = true errorsFound = true

View File

@ -6,6 +6,7 @@ import (
"github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
"golang.org/x/sync/errgroup"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -103,12 +104,8 @@ func runCopy(opts CopyOptions, gopts GlobalOptions, args []string) error {
dstSnapshotByOriginal[*sn.ID()] = append(dstSnapshotByOriginal[*sn.ID()], sn) dstSnapshotByOriginal[*sn.ID()] = append(dstSnapshotByOriginal[*sn.ID()], sn)
} }
cloner := &treeCloner{ // remember already processed trees across all snapshots
srcRepo: srcRepo, visitedTrees := restic.NewIDSet()
dstRepo: dstRepo,
visitedTrees: restic.NewIDSet(),
buf: nil,
}
for sn := range FindFilteredSnapshots(ctx, srcRepo, opts.Hosts, opts.Tags, opts.Paths, args) { for sn := range FindFilteredSnapshots(ctx, srcRepo, opts.Hosts, opts.Tags, opts.Paths, args) {
Verbosef("\nsnapshot %s of %v at %s)\n", sn.ID().Str(), sn.Paths, sn.Time) Verbosef("\nsnapshot %s of %v at %s)\n", sn.ID().Str(), sn.Paths, sn.Time)
@ -133,7 +130,7 @@ func runCopy(opts CopyOptions, gopts GlobalOptions, args []string) error {
} }
Verbosef(" copy started, this may take a while...\n") Verbosef(" copy started, this may take a while...\n")
if err := cloner.copyTree(ctx, *sn.Tree); err != nil { if err := copyTree(ctx, srcRepo, dstRepo, visitedTrees, *sn.Tree); err != nil {
return err return err
} }
debug.Log("tree copied") debug.Log("tree copied")
@ -177,64 +174,64 @@ func similarSnapshots(sna *restic.Snapshot, snb *restic.Snapshot) bool {
return true return true
} }
type treeCloner struct { func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Repository,
srcRepo restic.Repository visitedTrees restic.IDSet, rootTreeID restic.ID) error {
dstRepo restic.Repository
visitedTrees restic.IDSet
buf []byte
}
func (t *treeCloner) copyTree(ctx context.Context, treeID restic.ID) error { wg, ctx := errgroup.WithContext(ctx)
// We have already processed this tree
if t.visitedTrees.Has(treeID) {
return nil
}
tree, err := t.srcRepo.LoadTree(ctx, treeID) treeStream := restic.StreamTrees(ctx, wg, srcRepo, restic.IDs{rootTreeID}, func(treeID restic.ID) bool {
if err != nil { visited := visitedTrees.Has(treeID)
return fmt.Errorf("LoadTree(%v) returned error %v", treeID.Str(), err) visitedTrees.Insert(treeID)
return visited
}, nil)
wg.Go(func() error {
// reused buffer
var buf []byte
for tree := range treeStream {
if tree.Error != nil {
return fmt.Errorf("LoadTree(%v) returned error %v", tree.ID.Str(), tree.Error)
} }
t.visitedTrees.Insert(treeID)
// Do we already have this tree blob? // Do we already have this tree blob?
if !t.dstRepo.Index().Has(restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}) { if !dstRepo.Index().Has(restic.BlobHandle{ID: tree.ID, Type: restic.TreeBlob}) {
newTreeID, err := t.dstRepo.SaveTree(ctx, tree) newTreeID, err := dstRepo.SaveTree(ctx, tree.Tree)
if err != nil { if err != nil {
return fmt.Errorf("SaveTree(%v) returned error %v", treeID.Str(), err) return fmt.Errorf("SaveTree(%v) returned error %v", tree.ID.Str(), err)
} }
// Assurance only. // Assurance only.
if newTreeID != treeID { if newTreeID != tree.ID {
return fmt.Errorf("SaveTree(%v) returned unexpected id %s", treeID.Str(), newTreeID.Str()) return fmt.Errorf("SaveTree(%v) returned unexpected id %s", tree.ID.Str(), newTreeID.Str())
} }
} }
// TODO: parellize this stuff, likely only needed inside a tree. // TODO: parallelize blob down/upload
for _, entry := range tree.Nodes { for _, entry := range tree.Nodes {
// If it is a directory, recurse // Recursion into directories is handled by StreamTrees
if entry.Type == "dir" && entry.Subtree != nil {
if err := t.copyTree(ctx, *entry.Subtree); err != nil {
return err
}
}
// Copy the blobs for this file. // Copy the blobs for this file.
for _, blobID := range entry.Content { for _, blobID := range entry.Content {
// Do we already have this data blob? // Do we already have this data blob?
if t.dstRepo.Index().Has(restic.BlobHandle{ID: blobID, Type: restic.DataBlob}) { if dstRepo.Index().Has(restic.BlobHandle{ID: blobID, Type: restic.DataBlob}) {
continue continue
} }
debug.Log("Copying blob %s\n", blobID.Str()) debug.Log("Copying blob %s\n", blobID.Str())
t.buf, err = t.srcRepo.LoadBlob(ctx, restic.DataBlob, blobID, t.buf) var err error
buf, err = srcRepo.LoadBlob(ctx, restic.DataBlob, blobID, buf)
if err != nil { if err != nil {
return fmt.Errorf("LoadBlob(%v) returned error %v", blobID, err) return fmt.Errorf("LoadBlob(%v) returned error %v", blobID, err)
} }
_, _, err = t.dstRepo.SaveBlob(ctx, restic.DataBlob, t.buf, blobID, false) _, _, err = dstRepo.SaveBlob(ctx, restic.DataBlob, buf, blobID, false)
if err != nil { if err != nil {
return fmt.Errorf("SaveBlob(%v) returned error %v", blobID, err) return fmt.Errorf("SaveBlob(%v) returned error %v", blobID, err)
} }
} }
} }
return nil }
return nil
})
return wg.Wait()
} }

View File

@ -574,10 +574,8 @@ func getUsedBlobs(gopts GlobalOptions, repo restic.Repository, ignoreSnapshots r
bar := newProgressMax(!gopts.Quiet, uint64(len(snapshotTrees)), "snapshots") bar := newProgressMax(!gopts.Quiet, uint64(len(snapshotTrees)), "snapshots")
defer bar.Done() defer bar.Done()
for _, tree := range snapshotTrees {
debug.Log("process tree %v", tree)
err = restic.FindUsedBlobs(ctx, repo, tree, usedBlobs) err = restic.FindUsedBlobs(ctx, repo, snapshotTrees, usedBlobs, bar)
if err != nil { if err != nil {
if repo.Backend().IsNotExist(err) { if repo.Backend().IsNotExist(err) {
return nil, errors.Fatal("unable to load a tree from the repo: " + err.Error()) return nil, errors.Fatal("unable to load a tree from the repo: " + err.Error())
@ -585,9 +583,5 @@ func getUsedBlobs(gopts GlobalOptions, repo restic.Repository, ignoreSnapshots r
return nil, err return nil, err
} }
debug.Log("processed tree %v", tree)
bar.Add(1)
}
return usedBlobs, nil return usedBlobs, nil
} }

View File

@ -166,7 +166,7 @@ func statsWalkSnapshot(ctx context.Context, snapshot *restic.Snapshot, repo rest
if statsOptions.countMode == countModeRawData { if statsOptions.countMode == countModeRawData {
// count just the sizes of unique blobs; we don't need to walk the tree // count just the sizes of unique blobs; we don't need to walk the tree
// ourselves in this case, since a nifty function does it for us // ourselves in this case, since a nifty function does it for us
return restic.FindUsedBlobs(ctx, repo, *snapshot.Tree, stats.blobs) return restic.FindUsedBlobs(ctx, repo, restic.IDs{*snapshot.Tree}, stats.blobs, nil)
} }
err := walker.Walk(ctx, repo, *snapshot.Tree, restic.NewIDSet(), statsWalkTree(repo, stats)) err := walker.Walk(ctx, repo, *snapshot.Tree, restic.NewIDSet(), statsWalkTree(repo, stats))

View File

@ -33,11 +33,14 @@ func newProgressMax(show bool, max uint64, description string) *progress.Counter
} }
interval := calculateProgressInterval() interval := calculateProgressInterval()
return progress.New(interval, func(v uint64, d time.Duration, final bool) { return progress.New(interval, max, func(v uint64, max uint64, d time.Duration, final bool) {
status := fmt.Sprintf("[%s] %s %d / %d %s", var status string
formatDuration(d), if max == 0 {
formatPercent(v, max), status = fmt.Sprintf("[%s] %d %s", formatDuration(d), v, description)
v, max, description) } else {
status = fmt.Sprintf("[%s] %s %d / %d %s",
formatDuration(d), formatPercent(v, max), v, max, description)
}
if w := stdoutTerminalWidth(); w > 0 { if w := stdoutTerminalWidth(); w > 0 {
status = shortenStatus(w, status) status = shortenStatus(w, status)

View File

@ -308,200 +308,27 @@ func (e TreeError) Error() string {
return fmt.Sprintf("tree %v: %v", e.ID.Str(), e.Errors) return fmt.Sprintf("tree %v: %v", e.ID.Str(), e.Errors)
} }
type treeJob struct {
restic.ID
error
*restic.Tree
}
// loadTreeWorker loads trees from repo and sends them to out.
func loadTreeWorker(ctx context.Context, repo restic.Repository,
in <-chan restic.ID, out chan<- treeJob,
wg *sync.WaitGroup) {
defer func() {
debug.Log("exiting")
wg.Done()
}()
var (
inCh = in
outCh = out
job treeJob
)
outCh = nil
for {
select {
case <-ctx.Done():
return
case treeID, ok := <-inCh:
if !ok {
return
}
debug.Log("load tree %v", treeID)
tree, err := repo.LoadTree(ctx, treeID)
debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err)
job = treeJob{ID: treeID, error: err, Tree: tree}
outCh = out
inCh = nil
case outCh <- job:
debug.Log("sent tree %v", job.ID)
outCh = nil
inCh = in
}
}
}
// checkTreeWorker checks the trees received and sends out errors to errChan. // checkTreeWorker checks the trees received and sends out errors to errChan.
func (c *Checker) checkTreeWorker(ctx context.Context, in <-chan treeJob, out chan<- error, wg *sync.WaitGroup) { func (c *Checker) checkTreeWorker(ctx context.Context, trees <-chan restic.TreeItem, out chan<- error) {
defer func() { for job := range trees {
debug.Log("exiting") debug.Log("check tree %v (tree %v, err %v)", job.ID, job.Tree, job.Error)
wg.Done()
}()
var (
inCh = in
outCh = out
treeError TreeError
)
outCh = nil
for {
select {
case <-ctx.Done():
debug.Log("done channel closed, exiting")
return
case job, ok := <-inCh:
if !ok {
debug.Log("input channel closed, exiting")
return
}
debug.Log("check tree %v (tree %v, err %v)", job.ID, job.Tree, job.error)
var errs []error var errs []error
if job.error != nil { if job.Error != nil {
errs = append(errs, job.error) errs = append(errs, job.Error)
} else { } else {
errs = c.checkTree(job.ID, job.Tree) errs = c.checkTree(job.ID, job.Tree)
} }
if len(errs) > 0 { if len(errs) == 0 {
debug.Log("checked tree %v: %v errors", job.ID, len(errs))
treeError = TreeError{ID: job.ID, Errors: errs}
outCh = out
inCh = nil
}
case outCh <- treeError:
debug.Log("tree %v: sent %d errors", treeError.ID, len(treeError.Errors))
outCh = nil
inCh = in
}
}
}
func (c *Checker) filterTrees(ctx context.Context, backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan treeJob, out chan<- treeJob) {
defer func() {
debug.Log("closing output channels")
close(loaderChan)
close(out)
}()
var (
inCh = in
outCh = out
loadCh = loaderChan
job treeJob
nextTreeID restic.ID
outstandingLoadTreeJobs = 0
)
outCh = nil
loadCh = nil
for {
if loadCh == nil && len(backlog) > 0 {
// process last added ids first, that is traverse the tree in depth-first order
ln := len(backlog) - 1
nextTreeID, backlog = backlog[ln], backlog[:ln]
// use a separate flag for processed trees to ensure that check still processes trees
// even when a file references a tree blob
c.blobRefs.Lock()
h := restic.BlobHandle{ID: nextTreeID, Type: restic.TreeBlob}
blobReferenced := c.blobRefs.M.Has(h)
// noop if already referenced
c.blobRefs.M.Insert(h)
c.blobRefs.Unlock()
if blobReferenced {
continue continue
} }
treeError := TreeError{ID: job.ID, Errors: errs}
loadCh = loaderChan
}
if loadCh == nil && outCh == nil && outstandingLoadTreeJobs == 0 {
debug.Log("backlog is empty, all channels nil, exiting")
return
}
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case out <- treeError:
case loadCh <- nextTreeID: debug.Log("tree %v: sent %d errors", treeError.ID, len(treeError.Errors))
outstandingLoadTreeJobs++
loadCh = nil
case j, ok := <-inCh:
if !ok {
debug.Log("input channel closed")
inCh = nil
in = nil
continue
}
outstandingLoadTreeJobs--
debug.Log("input job tree %v", j.ID)
if j.error != nil {
debug.Log("received job with error: %v (tree %v, ID %v)", j.error, j.Tree, j.ID)
} else if j.Tree == nil {
debug.Log("received job with nil tree pointer: %v (ID %v)", j.error, j.ID)
// send a new job with the new error instead of the old one
j = treeJob{ID: j.ID, error: errors.New("tree is nil and error is nil")}
} else {
subtrees := j.Tree.Subtrees()
debug.Log("subtrees for tree %v: %v", j.ID, subtrees)
// iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection
for i := len(subtrees) - 1; i >= 0; i-- {
id := subtrees[i]
if id.IsNull() {
// We do not need to raise this error here, it is
// checked when the tree is checked. Just make sure
// that we do not add any null IDs to the backlog.
debug.Log("tree %v has nil subtree", j.ID)
continue
}
backlog = append(backlog, id)
}
}
job = j
outCh = out
inCh = nil
case outCh <- job:
debug.Log("tree sent to check: %v", job.ID)
outCh = nil
inCh = in
} }
} }
} }
@ -527,10 +354,9 @@ func loadSnapshotTreeIDs(ctx context.Context, repo restic.Repository) (ids resti
// Structure checks that for all snapshots all referenced data blobs and // Structure checks that for all snapshots all referenced data blobs and
// subtrees are available in the index. errChan is closed after all trees have // subtrees are available in the index. errChan is closed after all trees have
// been traversed. // been traversed.
func (c *Checker) Structure(ctx context.Context, errChan chan<- error) { func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan chan<- error) {
defer close(errChan)
trees, errs := loadSnapshotTreeIDs(ctx, c.repo) trees, errs := loadSnapshotTreeIDs(ctx, c.repo)
p.SetMax(uint64(len(trees)))
debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs)) debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs))
for _, err := range errs { for _, err := range errs {
@ -541,19 +367,26 @@ func (c *Checker) Structure(ctx context.Context, errChan chan<- error) {
} }
} }
treeIDChan := make(chan restic.ID) wg, ctx := errgroup.WithContext(ctx)
treeJobChan1 := make(chan treeJob) treeStream := restic.StreamTrees(ctx, wg, c.repo, trees, func(treeID restic.ID) bool {
treeJobChan2 := make(chan treeJob) // blobRefs may be accessed in parallel by checkTree
c.blobRefs.Lock()
h := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}
blobReferenced := c.blobRefs.M.Has(h)
// noop if already referenced
c.blobRefs.M.Insert(h)
c.blobRefs.Unlock()
return blobReferenced
}, p)
var wg sync.WaitGroup defer close(errChan)
for i := 0; i < defaultParallelism; i++ { for i := 0; i < defaultParallelism; i++ {
wg.Add(2) wg.Go(func() error {
go loadTreeWorker(ctx, c.repo, treeIDChan, treeJobChan1, &wg) c.checkTreeWorker(ctx, treeStream, errChan)
go c.checkTreeWorker(ctx, treeJobChan2, errChan, &wg) return nil
})
} }
c.filterTrees(ctx, trees, treeIDChan, treeJobChan1, treeJobChan2)
wg.Wait() wg.Wait()
} }

View File

@ -43,7 +43,9 @@ func checkPacks(chkr *checker.Checker) []error {
} }
func checkStruct(chkr *checker.Checker) []error { func checkStruct(chkr *checker.Checker) []error {
return collectErrors(context.TODO(), chkr.Structure) return collectErrors(context.TODO(), func(ctx context.Context, errChan chan<- error) {
chkr.Structure(ctx, nil, errChan)
})
} }
func checkData(chkr *checker.Checker) []error { func checkData(chkr *checker.Checker) []error {

View File

@ -30,7 +30,7 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) {
// structure // structure
errChan = make(chan error) errChan = make(chan error)
go chkr.Structure(context.TODO(), errChan) go chkr.Structure(context.TODO(), nil, errChan)
for err := range errChan { for err := range errChan {
t.Error(err) t.Error(err)

View File

@ -368,7 +368,7 @@ func TestIndexSave(t *testing.T) {
defer cancel() defer cancel()
errCh := make(chan error) errCh := make(chan error)
go checker.Structure(ctx, errCh) go checker.Structure(ctx, nil, errCh)
i := 0 i := 0
for err := range errCh { for err := range errCh {
t.Errorf("checker returned error: %v", err) t.Errorf("checker returned error: %v", err)

View File

@ -1,6 +1,12 @@
package restic package restic
import "context" import (
"context"
"sync"
"github.com/restic/restic/internal/ui/progress"
"golang.org/x/sync/errgroup"
)
// TreeLoader loads a tree from a repository. // TreeLoader loads a tree from a repository.
type TreeLoader interface { type TreeLoader interface {
@ -9,31 +15,39 @@ type TreeLoader interface {
// FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data // FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data
// blobs) to the set blobs. Already seen tree blobs will not be visited again. // blobs) to the set blobs. Already seen tree blobs will not be visited again.
func FindUsedBlobs(ctx context.Context, repo TreeLoader, treeID ID, blobs BlobSet) error { func FindUsedBlobs(ctx context.Context, repo TreeLoader, treeIDs IDs, blobs BlobSet, p *progress.Counter) error {
var lock sync.Mutex
wg, ctx := errgroup.WithContext(ctx)
treeStream := StreamTrees(ctx, wg, repo, treeIDs, func(treeID ID) bool {
// locking is necessary the goroutine below concurrently adds data blobs
lock.Lock()
h := BlobHandle{ID: treeID, Type: TreeBlob} h := BlobHandle{ID: treeID, Type: TreeBlob}
if blobs.Has(h) { blobReferenced := blobs.Has(h)
return nil // noop if already referenced
}
blobs.Insert(h) blobs.Insert(h)
lock.Unlock()
return blobReferenced
}, p)
tree, err := repo.LoadTree(ctx, treeID) wg.Go(func() error {
if err != nil { for tree := range treeStream {
return err if tree.Error != nil {
return tree.Error
} }
lock.Lock()
for _, node := range tree.Nodes { for _, node := range tree.Nodes {
switch node.Type { switch node.Type {
case "file": case "file":
for _, blob := range node.Content { for _, blob := range node.Content {
blobs.Insert(BlobHandle{ID: blob, Type: DataBlob}) blobs.Insert(BlobHandle{ID: blob, Type: DataBlob})
} }
case "dir":
err := FindUsedBlobs(ctx, repo, *node.Subtree, blobs)
if err != nil {
return err
} }
} }
lock.Unlock()
} }
return nil return nil
})
return wg.Wait()
} }

View File

@ -15,6 +15,8 @@ import (
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/repository"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/test"
"github.com/restic/restic/internal/ui/progress"
) )
func loadIDSet(t testing.TB, filename string) restic.BlobSet { func loadIDSet(t testing.TB, filename string) restic.BlobSet {
@ -92,9 +94,12 @@ func TestFindUsedBlobs(t *testing.T) {
snapshots = append(snapshots, sn) snapshots = append(snapshots, sn)
} }
p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
defer p.Done()
for i, sn := range snapshots { for i, sn := range snapshots {
usedBlobs := restic.NewBlobSet() usedBlobs := restic.NewBlobSet()
err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, usedBlobs) err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*sn.Tree}, usedBlobs, p)
if err != nil { if err != nil {
t.Errorf("FindUsedBlobs returned error: %v", err) t.Errorf("FindUsedBlobs returned error: %v", err)
continue continue
@ -105,6 +110,8 @@ func TestFindUsedBlobs(t *testing.T) {
continue continue
} }
test.Equals(t, p.Get(), uint64(i+1))
goldenFilename := filepath.Join("testdata", fmt.Sprintf("used_blobs_snapshot%d", i)) goldenFilename := filepath.Join("testdata", fmt.Sprintf("used_blobs_snapshot%d", i))
want := loadIDSet(t, goldenFilename) want := loadIDSet(t, goldenFilename)
@ -119,6 +126,40 @@ func TestFindUsedBlobs(t *testing.T) {
} }
} }
func TestMultiFindUsedBlobs(t *testing.T) {
repo, cleanup := repository.TestRepository(t)
defer cleanup()
var snapshotTrees restic.IDs
for i := 0; i < findTestSnapshots; i++ {
sn := restic.TestCreateSnapshot(t, repo, findTestTime.Add(time.Duration(i)*time.Second), findTestDepth, 0)
t.Logf("snapshot %v saved, tree %v", sn.ID().Str(), sn.Tree.Str())
snapshotTrees = append(snapshotTrees, *sn.Tree)
}
want := restic.NewBlobSet()
for i := range snapshotTrees {
goldenFilename := filepath.Join("testdata", fmt.Sprintf("used_blobs_snapshot%d", i))
want.Merge(loadIDSet(t, goldenFilename))
}
p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
defer p.Done()
// run twice to check progress bar handling of duplicate tree roots
usedBlobs := restic.NewBlobSet()
for i := 1; i < 3; i++ {
err := restic.FindUsedBlobs(context.TODO(), repo, snapshotTrees, usedBlobs, p)
test.OK(t, err)
test.Equals(t, p.Get(), uint64(i*len(snapshotTrees)))
if !want.Equals(usedBlobs) {
t.Errorf("wrong list of blobs returned:\n missing blobs: %v\n extra blobs: %v",
want.Sub(usedBlobs), usedBlobs.Sub(want))
}
}
}
type ForbiddenRepo struct{} type ForbiddenRepo struct{}
func (r ForbiddenRepo) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree, error) { func (r ForbiddenRepo) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree, error) {
@ -133,12 +174,12 @@ func TestFindUsedBlobsSkipsSeenBlobs(t *testing.T) {
t.Logf("snapshot %v saved, tree %v", snapshot.ID().Str(), snapshot.Tree.Str()) t.Logf("snapshot %v saved, tree %v", snapshot.ID().Str(), snapshot.Tree.Str())
usedBlobs := restic.NewBlobSet() usedBlobs := restic.NewBlobSet()
err := restic.FindUsedBlobs(context.TODO(), repo, *snapshot.Tree, usedBlobs) err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*snapshot.Tree}, usedBlobs, nil)
if err != nil { if err != nil {
t.Fatalf("FindUsedBlobs returned error: %v", err) t.Fatalf("FindUsedBlobs returned error: %v", err)
} }
err = restic.FindUsedBlobs(context.TODO(), ForbiddenRepo{}, *snapshot.Tree, usedBlobs) err = restic.FindUsedBlobs(context.TODO(), ForbiddenRepo{}, restic.IDs{*snapshot.Tree}, usedBlobs, nil)
if err != nil { if err != nil {
t.Fatalf("FindUsedBlobs returned error: %v", err) t.Fatalf("FindUsedBlobs returned error: %v", err)
} }
@ -154,7 +195,7 @@ func BenchmarkFindUsedBlobs(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
blobs := restic.NewBlobSet() blobs := restic.NewBlobSet()
err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, blobs) err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*sn.Tree}, blobs, nil)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }

View File

@ -0,0 +1,183 @@
package restic
import (
"context"
"errors"
"sync"
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/ui/progress"
"golang.org/x/sync/errgroup"
)
const streamTreeParallelism = 5
// TreeItem is used to return either an error or the tree for a tree id
type TreeItem struct {
ID
Error error
*Tree
}
type trackedTreeItem struct {
TreeItem
rootIdx int
}
type trackedID struct {
ID
rootIdx int
}
// loadTreeWorker loads trees from repo and sends them to out.
func loadTreeWorker(ctx context.Context, repo TreeLoader,
in <-chan trackedID, out chan<- trackedTreeItem) {
for treeID := range in {
tree, err := repo.LoadTree(ctx, treeID.ID)
debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err)
job := trackedTreeItem{TreeItem: TreeItem{ID: treeID.ID, Error: err, Tree: tree}, rootIdx: treeID.rootIdx}
select {
case <-ctx.Done():
return
case out <- job:
}
}
}
func filterTrees(ctx context.Context, trees IDs, loaderChan chan<- trackedID,
in <-chan trackedTreeItem, out chan<- TreeItem, skip func(tree ID) bool, p *progress.Counter) {
var (
inCh = in
outCh chan<- TreeItem
loadCh chan<- trackedID
job TreeItem
nextTreeID trackedID
outstandingLoadTreeJobs = 0
)
rootCounter := make([]int, len(trees))
backlog := make([]trackedID, 0, len(trees))
for idx, id := range trees {
backlog = append(backlog, trackedID{ID: id, rootIdx: idx})
rootCounter[idx] = 1
}
for {
if loadCh == nil && len(backlog) > 0 {
// process last added ids first, that is traverse the tree in depth-first order
ln := len(backlog) - 1
nextTreeID, backlog = backlog[ln], backlog[:ln]
if skip(nextTreeID.ID) {
rootCounter[nextTreeID.rootIdx]--
if p != nil && rootCounter[nextTreeID.rootIdx] == 0 {
p.Add(1)
}
continue
}
loadCh = loaderChan
}
if loadCh == nil && outCh == nil && outstandingLoadTreeJobs == 0 {
debug.Log("backlog is empty, all channels nil, exiting")
return
}
select {
case <-ctx.Done():
return
case loadCh <- nextTreeID:
outstandingLoadTreeJobs++
loadCh = nil
case j, ok := <-inCh:
if !ok {
debug.Log("input channel closed")
inCh = nil
in = nil
continue
}
outstandingLoadTreeJobs--
rootCounter[j.rootIdx]--
debug.Log("input job tree %v", j.ID)
if j.Error != nil {
debug.Log("received job with error: %v (tree %v, ID %v)", j.Error, j.Tree, j.ID)
} else if j.Tree == nil {
debug.Log("received job with nil tree pointer: %v (ID %v)", j.Error, j.ID)
// send a new job with the new error instead of the old one
j = trackedTreeItem{TreeItem: TreeItem{ID: j.ID, Error: errors.New("tree is nil and error is nil")}, rootIdx: j.rootIdx}
} else {
subtrees := j.Tree.Subtrees()
debug.Log("subtrees for tree %v: %v", j.ID, subtrees)
// iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection
for i := len(subtrees) - 1; i >= 0; i-- {
id := subtrees[i]
if id.IsNull() {
// We do not need to raise this error here, it is
// checked when the tree is checked. Just make sure
// that we do not add any null IDs to the backlog.
debug.Log("tree %v has nil subtree", j.ID)
continue
}
backlog = append(backlog, trackedID{ID: id, rootIdx: j.rootIdx})
rootCounter[j.rootIdx]++
}
}
if p != nil && rootCounter[j.rootIdx] == 0 {
p.Add(1)
}
job = j.TreeItem
outCh = out
inCh = nil
case outCh <- job:
debug.Log("tree sent to process: %v", job.ID)
outCh = nil
inCh = in
}
}
}
// StreamTrees iteratively loads the given trees and their subtrees. The skip method
// is guaranteed to always be called from the same goroutine. To shutdown the started
// goroutines, either read all items from the channel or cancel the context. Then `Wait()`
// on the errgroup until all goroutines were stopped.
func StreamTrees(ctx context.Context, wg *errgroup.Group, repo TreeLoader, trees IDs, skip func(tree ID) bool, p *progress.Counter) <-chan TreeItem {
loaderChan := make(chan trackedID)
loadedTreeChan := make(chan trackedTreeItem)
treeStream := make(chan TreeItem)
var loadTreeWg sync.WaitGroup
for i := 0; i < streamTreeParallelism; i++ {
loadTreeWg.Add(1)
wg.Go(func() error {
defer loadTreeWg.Done()
loadTreeWorker(ctx, repo, loaderChan, loadedTreeChan)
return nil
})
}
// close once all loadTreeWorkers have completed
wg.Go(func() error {
loadTreeWg.Wait()
close(loadedTreeChan)
return nil
})
wg.Go(func() error {
defer close(loaderChan)
defer close(treeStream)
filterTrees(ctx, trees, loaderChan, loadedTreeChan, treeStream, skip, p)
return nil
})
return treeStream
}

View File

@ -12,7 +12,7 @@ import (
// //
// The final argument is true if Counter.Done has been called, // The final argument is true if Counter.Done has been called,
// which means that the current call will be the last. // which means that the current call will be the last.
type Func func(value uint64, runtime time.Duration, final bool) type Func func(value uint64, total uint64, runtime time.Duration, final bool)
// A Counter tracks a running count and controls a goroutine that passes its // A Counter tracks a running count and controls a goroutine that passes its
// value periodically to a Func. // value periodically to a Func.
@ -27,16 +27,19 @@ type Counter struct {
valueMutex sync.Mutex valueMutex sync.Mutex
value uint64 value uint64
max uint64
} }
// New starts a new Counter. // New starts a new Counter.
func New(interval time.Duration, report Func) *Counter { func New(interval time.Duration, total uint64, report Func) *Counter {
c := &Counter{ c := &Counter{
report: report, report: report,
start: time.Now(), start: time.Now(),
stopped: make(chan struct{}), stopped: make(chan struct{}),
stop: make(chan struct{}), stop: make(chan struct{}),
max: total,
} }
if interval > 0 { if interval > 0 {
c.tick = time.NewTicker(interval) c.tick = time.NewTicker(interval)
} }
@ -56,6 +59,16 @@ func (c *Counter) Add(v uint64) {
c.valueMutex.Unlock() c.valueMutex.Unlock()
} }
// SetMax sets the maximum expected counter value. This method is concurrency-safe.
func (c *Counter) SetMax(max uint64) {
if c == nil {
return
}
c.valueMutex.Lock()
c.max = max
c.valueMutex.Unlock()
}
// Done tells a Counter to stop and waits for it to report its final value. // Done tells a Counter to stop and waits for it to report its final value.
func (c *Counter) Done() { func (c *Counter) Done() {
if c == nil { if c == nil {
@ -69,7 +82,8 @@ func (c *Counter) Done() {
*c = Counter{} // Prevent reuse. *c = Counter{} // Prevent reuse.
} }
func (c *Counter) get() uint64 { // Get the current Counter value. This method is concurrency-safe.
func (c *Counter) Get() uint64 {
c.valueMutex.Lock() c.valueMutex.Lock()
v := c.value v := c.value
c.valueMutex.Unlock() c.valueMutex.Unlock()
@ -77,11 +91,19 @@ func (c *Counter) get() uint64 {
return v return v
} }
func (c *Counter) getMax() uint64 {
c.valueMutex.Lock()
max := c.max
c.valueMutex.Unlock()
return max
}
func (c *Counter) run() { func (c *Counter) run() {
defer close(c.stopped) defer close(c.stopped)
defer func() { defer func() {
// Must be a func so that time.Since isn't called at defer time. // Must be a func so that time.Since isn't called at defer time.
c.report(c.get(), time.Since(c.start), true) c.report(c.Get(), c.getMax(), time.Since(c.start), true)
}() }()
var tick <-chan time.Time var tick <-chan time.Time
@ -101,6 +123,6 @@ func (c *Counter) run() {
return return
} }
c.report(c.get(), now.Sub(c.start), false) c.report(c.Get(), c.getMax(), now.Sub(c.start), false)
} }
} }

View File

@ -10,23 +10,32 @@ import (
func TestCounter(t *testing.T) { func TestCounter(t *testing.T) {
const N = 100 const N = 100
const startTotal = uint64(12345)
var ( var (
finalSeen = false finalSeen = false
increasing = true increasing = true
last uint64 last uint64
lastTotal = startTotal
ncalls int ncalls int
nmaxChange int
) )
report := func(value uint64, d time.Duration, final bool) { report := func(value uint64, total uint64, d time.Duration, final bool) {
if final {
finalSeen = true finalSeen = true
}
if value < last { if value < last {
increasing = false increasing = false
} }
last = value last = value
if total != lastTotal {
nmaxChange++
}
lastTotal = total
ncalls++ ncalls++
} }
c := progress.New(10*time.Millisecond, report) c := progress.New(10*time.Millisecond, startTotal, report)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -35,6 +44,7 @@ func TestCounter(t *testing.T) {
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
c.Add(1) c.Add(1)
} }
c.SetMax(42)
}() }()
<-done <-done
@ -43,6 +53,8 @@ func TestCounter(t *testing.T) {
test.Assert(t, finalSeen, "final call did not happen") test.Assert(t, finalSeen, "final call did not happen")
test.Assert(t, increasing, "values not increasing") test.Assert(t, increasing, "values not increasing")
test.Equals(t, uint64(N), last) test.Equals(t, uint64(N), last)
test.Equals(t, uint64(42), lastTotal)
test.Equals(t, int(1), nmaxChange)
t.Log("number of calls:", ncalls) t.Log("number of calls:", ncalls)
} }
@ -58,14 +70,14 @@ func TestCounterNoTick(t *testing.T) {
finalSeen := false finalSeen := false
otherSeen := false otherSeen := false
report := func(value uint64, d time.Duration, final bool) { report := func(value, total uint64, d time.Duration, final bool) {
if final { if final {
finalSeen = true finalSeen = true
} else { } else {
otherSeen = true otherSeen = true
} }
} }
c := progress.New(0, report) c := progress.New(0, 1, report)
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
c.Done() c.Done()