mirror of https://github.com/octoleo/restic.git
synced 2024-11-11 15:51:02 +00:00

Merge pull request #3106 from MichaelEischer/parallel-tree-walk

Parallelize tree walk in prune and copy and add progress bar to check

This commit is contained in:
commit 72eec8c0c4

changelog/unreleased/pull-3106 (new file, 10 lines)
@@ -0,0 +1,10 @@
+Enhancement: Parallelize scan of snapshot content in copy and prune
+
+The copy and the prune commands used to traverse the directories of
+snapshots one by one to find used data. This snapshot traversal is
+now parallelized, which can speed up this step several times.
+
+In addition, the check command now reports how many snapshots have
+already been processed.
+
+https://github.com/restic/restic/pull/3106
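The pattern behind all of the changes below: instead of walking trees recursively one at a time, a fixed pool of goroutines loads trees concurrently and streams them to a single consumer, coordinated by an errgroup. The following self-contained sketch shows the shape of that pipeline; it is illustrative only, and loadTree and the integer IDs are stand-ins, not restic API.

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// loadTree is a stand-in for an expensive repository load.
func loadTree(id int) (string, error) {
	return fmt.Sprintf("tree-%d", id), nil
}

func main() {
	wg, ctx := errgroup.WithContext(context.Background())
	ids := make(chan int)
	trees := make(chan string)

	// producer: feed IDs to the worker pool
	wg.Go(func() error {
		defer close(ids)
		for i := 0; i < 10; i++ {
			select {
			case ids <- i:
			case <-ctx.Done():
				return ctx.Err()
			}
		}
		return nil
	})

	// a bounded pool of loaders (the PR uses streamTreeParallelism = 5)
	workers, wctx := errgroup.WithContext(ctx)
	for i := 0; i < 5; i++ {
		workers.Go(func() error {
			for id := range ids {
				tree, err := loadTree(id)
				if err != nil {
					return err
				}
				select {
				case trees <- tree:
				case <-wctx.Done():
					return wctx.Err()
				}
			}
			return nil
		})
	}
	// close the output once every worker is done
	wg.Go(func() error {
		defer close(trees)
		return workers.Wait()
	})

	// single consumer processes trees in the order they finish loading
	for tree := range trees {
		fmt.Println("processing", tree)
	}
	if err := wg.Wait(); err != nil {
		fmt.Println("error:", err)
	}
}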
@@ -240,7 +240,11 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error {
 
 	Verbosef("check snapshots, trees and blobs\n")
 	errChan = make(chan error)
-	go chkr.Structure(gopts.ctx, errChan)
+	go func() {
+		bar := newProgressMax(!gopts.Quiet, 0, "snapshots")
+		defer bar.Done()
+		chkr.Structure(gopts.ctx, bar, errChan)
+	}()
 
 	for err := range errChan {
 		errorsFound = true
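Note on the hunk above: the progress bar is deliberately created with a max of 0, because the number of snapshot trees is only known once Structure has listed them. Structure raises the total later through the counter's new SetMax method (see the internal/ui/progress changes below), and newProgressMax falls back to printing a plain count while the total is still zero.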
@@ -6,6 +6,7 @@ import (
 	"github.com/restic/restic/internal/debug"
 	"github.com/restic/restic/internal/restic"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/spf13/cobra"
 )
@@ -103,12 +104,8 @@ func runCopy(opts CopyOptions, gopts GlobalOptions, args []string) error {
 		dstSnapshotByOriginal[*sn.ID()] = append(dstSnapshotByOriginal[*sn.ID()], sn)
 	}
 
-	cloner := &treeCloner{
-		srcRepo:      srcRepo,
-		dstRepo:      dstRepo,
-		visitedTrees: restic.NewIDSet(),
-		buf:          nil,
-	}
+	// remember already processed trees across all snapshots
+	visitedTrees := restic.NewIDSet()
 
 	for sn := range FindFilteredSnapshots(ctx, srcRepo, opts.Hosts, opts.Tags, opts.Paths, args) {
 		Verbosef("\nsnapshot %s of %v at %s)\n", sn.ID().Str(), sn.Paths, sn.Time)
@@ -133,7 +130,7 @@ func runCopy(opts CopyOptions, gopts GlobalOptions, args []string) error {
 		}
 		Verbosef("  copy started, this may take a while...\n")
 
-		if err := cloner.copyTree(ctx, *sn.Tree); err != nil {
+		if err := copyTree(ctx, srcRepo, dstRepo, visitedTrees, *sn.Tree); err != nil {
 			return err
 		}
 		debug.Log("tree copied")
@@ -177,64 +174,64 @@ func similarSnapshots(sna *restic.Snapshot, snb *restic.Snapshot) bool {
 	return true
 }
 
-type treeCloner struct {
-	srcRepo      restic.Repository
-	dstRepo      restic.Repository
-	visitedTrees restic.IDSet
-	buf          []byte
-}
-
-func (t *treeCloner) copyTree(ctx context.Context, treeID restic.ID) error {
-	// We have already processed this tree
-	if t.visitedTrees.Has(treeID) {
-		return nil
-	}
-
-	tree, err := t.srcRepo.LoadTree(ctx, treeID)
-	if err != nil {
-		return fmt.Errorf("LoadTree(%v) returned error %v", treeID.Str(), err)
-	}
-	t.visitedTrees.Insert(treeID)
-
-	// Do we already have this tree blob?
-	if !t.dstRepo.Index().Has(restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}) {
-		newTreeID, err := t.dstRepo.SaveTree(ctx, tree)
-		if err != nil {
-			return fmt.Errorf("SaveTree(%v) returned error %v", treeID.Str(), err)
-		}
-		// Assurance only.
-		if newTreeID != treeID {
-			return fmt.Errorf("SaveTree(%v) returned unexpected id %s", treeID.Str(), newTreeID.Str())
-		}
-	}
-
-	// TODO: parellize this stuff, likely only needed inside a tree.
-	for _, entry := range tree.Nodes {
-		// If it is a directory, recurse
-		if entry.Type == "dir" && entry.Subtree != nil {
-			if err := t.copyTree(ctx, *entry.Subtree); err != nil {
-				return err
-			}
-		}
-		// Copy the blobs for this file.
-		for _, blobID := range entry.Content {
-			// Do we already have this data blob?
-			if t.dstRepo.Index().Has(restic.BlobHandle{ID: blobID, Type: restic.DataBlob}) {
-				continue
-			}
-			debug.Log("Copying blob %s\n", blobID.Str())
-			t.buf, err = t.srcRepo.LoadBlob(ctx, restic.DataBlob, blobID, t.buf)
-			if err != nil {
-				return fmt.Errorf("LoadBlob(%v) returned error %v", blobID, err)
-			}
-
-			_, _, err = t.dstRepo.SaveBlob(ctx, restic.DataBlob, t.buf, blobID, false)
-			if err != nil {
-				return fmt.Errorf("SaveBlob(%v) returned error %v", blobID, err)
-			}
-		}
-	}
-
-	return nil
-}
+func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Repository,
+	visitedTrees restic.IDSet, rootTreeID restic.ID) error {
+
+	wg, ctx := errgroup.WithContext(ctx)
+	treeStream := restic.StreamTrees(ctx, wg, srcRepo, restic.IDs{rootTreeID}, func(treeID restic.ID) bool {
+		visited := visitedTrees.Has(treeID)
+		visitedTrees.Insert(treeID)
+		return visited
+	}, nil)
+
+	wg.Go(func() error {
+		// reused buffer
+		var buf []byte
+		for tree := range treeStream {
+			if tree.Error != nil {
+				return fmt.Errorf("LoadTree(%v) returned error %v", tree.ID.Str(), tree.Error)
+			}
+
+			// Do we already have this tree blob?
+			if !dstRepo.Index().Has(restic.BlobHandle{ID: tree.ID, Type: restic.TreeBlob}) {
+				newTreeID, err := dstRepo.SaveTree(ctx, tree.Tree)
+				if err != nil {
+					return fmt.Errorf("SaveTree(%v) returned error %v", tree.ID.Str(), err)
+				}
+				// Assurance only.
+				if newTreeID != tree.ID {
+					return fmt.Errorf("SaveTree(%v) returned unexpected id %s", tree.ID.Str(), newTreeID.Str())
+				}
+			}
+
+			// TODO: parallelize blob down/upload
+			for _, entry := range tree.Nodes {
+				// Recursion into directories is handled by StreamTrees
+				// Copy the blobs for this file.
+				for _, blobID := range entry.Content {
+					// Do we already have this data blob?
+					if dstRepo.Index().Has(restic.BlobHandle{ID: blobID, Type: restic.DataBlob}) {
+						continue
+					}
+					debug.Log("Copying blob %s\n", blobID.Str())
+					var err error
+					buf, err = srcRepo.LoadBlob(ctx, restic.DataBlob, blobID, buf)
+					if err != nil {
+						return fmt.Errorf("LoadBlob(%v) returned error %v", blobID, err)
+					}
+
+					_, _, err = dstRepo.SaveBlob(ctx, restic.DataBlob, buf, blobID, false)
+					if err != nil {
+						return fmt.Errorf("SaveBlob(%v) returned error %v", blobID, err)
+					}
+				}
+			}
+		}
+		return nil
+	})
+	return wg.Wait()
+}
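With StreamTrees handling recursion, copyTree's only traversal decision is the skip callback, which doubles as cross-snapshot deduplication: it returns true for trees that were already copied for an earlier snapshot. A tiny self-contained illustration of that see-then-mark idiom (plain strings stand in for restic IDs):

package main

import "fmt"

func main() {
	visited := map[string]bool{}
	// skip reports whether id was seen before, and marks it as seen:
	// the same shape as the callback passed to StreamTrees above.
	skip := func(id string) bool {
		seen := visited[id]
		visited[id] = true
		return seen
	}
	for _, id := range []string{"a", "b", "a", "c", "b"} {
		fmt.Printf("%s: skip=%v\n", id, skip(id))
	}
}

No lock is needed in copyTree's callback because, per the StreamTrees documentation added below, skip is always called from a single goroutine.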
@@ -574,10 +574,8 @@ func getUsedBlobs(gopts GlobalOptions, repo restic.Repository, ignoreSnapshots r
 
 	bar := newProgressMax(!gopts.Quiet, uint64(len(snapshotTrees)), "snapshots")
 	defer bar.Done()
-	for _, tree := range snapshotTrees {
-		debug.Log("process tree %v", tree)
 
-		err = restic.FindUsedBlobs(ctx, repo, tree, usedBlobs)
+	err = restic.FindUsedBlobs(ctx, repo, snapshotTrees, usedBlobs, bar)
 	if err != nil {
 		if repo.Backend().IsNotExist(err) {
 			return nil, errors.Fatal("unable to load a tree from the repo: " + err.Error())
@@ -585,9 +583,5 @@ func getUsedBlobs(gopts GlobalOptions, repo restic.Repository, ignoreSnapshots r
 
 		return nil, err
 	}
-
-		debug.Log("processed tree %v", tree)
-		bar.Add(1)
-	}
 	return usedBlobs, nil
 }
@@ -166,7 +166,7 @@ func statsWalkSnapshot(ctx context.Context, snapshot *restic.Snapshot, repo rest
 	if statsOptions.countMode == countModeRawData {
 		// count just the sizes of unique blobs; we don't need to walk the tree
 		// ourselves in this case, since a nifty function does it for us
-		return restic.FindUsedBlobs(ctx, repo, *snapshot.Tree, stats.blobs)
+		return restic.FindUsedBlobs(ctx, repo, restic.IDs{*snapshot.Tree}, stats.blobs, nil)
 	}
 
 	err := walker.Walk(ctx, repo, *snapshot.Tree, restic.NewIDSet(), statsWalkTree(repo, stats))
@@ -33,11 +33,14 @@ func newProgressMax(show bool, max uint64, description string) *progress.Counter
 	}
 	interval := calculateProgressInterval()
 
-	return progress.New(interval, func(v uint64, d time.Duration, final bool) {
-		status := fmt.Sprintf("[%s] %s %d / %d %s",
-			formatDuration(d),
-			formatPercent(v, max),
-			v, max, description)
+	return progress.New(interval, max, func(v uint64, max uint64, d time.Duration, final bool) {
+		var status string
+		if max == 0 {
+			status = fmt.Sprintf("[%s] %d %s", formatDuration(d), v, description)
+		} else {
+			status = fmt.Sprintf("[%s] %s %d / %d %s",
+				formatDuration(d), formatPercent(v, max), v, max, description)
+		}
 
 		if w := stdoutTerminalWidth(); w > 0 {
 			status = shortenStatus(w, status)
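The zero-max branch added here is what the check command relies on: until Structure calls SetMax, the counter reports a total of 0, so the status line shows only the elapsed time and a running count instead of a percentage and an x/y total.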
@@ -308,200 +308,27 @@ func (e TreeError) Error() string {
 	return fmt.Sprintf("tree %v: %v", e.ID.Str(), e.Errors)
 }
 
-type treeJob struct {
-	restic.ID
-	error
-	*restic.Tree
-}
-
-// loadTreeWorker loads trees from repo and sends them to out.
-func loadTreeWorker(ctx context.Context, repo restic.Repository,
-	in <-chan restic.ID, out chan<- treeJob,
-	wg *sync.WaitGroup) {
-
-	defer func() {
-		debug.Log("exiting")
-		wg.Done()
-	}()
-
-	var (
-		inCh  = in
-		outCh = out
-		job   treeJob
-	)
-
-	outCh = nil
-	for {
-		select {
-		case <-ctx.Done():
-			return
-
-		case treeID, ok := <-inCh:
-			if !ok {
-				return
-			}
-			debug.Log("load tree %v", treeID)
-
-			tree, err := repo.LoadTree(ctx, treeID)
-			debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err)
-			job = treeJob{ID: treeID, error: err, Tree: tree}
-			outCh = out
-			inCh = nil
-
-		case outCh <- job:
-			debug.Log("sent tree %v", job.ID)
-			outCh = nil
-			inCh = in
-		}
-	}
-}
-
 // checkTreeWorker checks the trees received and sends out errors to errChan.
-func (c *Checker) checkTreeWorker(ctx context.Context, in <-chan treeJob, out chan<- error, wg *sync.WaitGroup) {
-	defer func() {
-		debug.Log("exiting")
-		wg.Done()
-	}()
-
-	var (
-		inCh      = in
-		outCh     = out
-		treeError TreeError
-	)
-
-	outCh = nil
-	for {
-		select {
-		case <-ctx.Done():
-			debug.Log("done channel closed, exiting")
-			return
-
-		case job, ok := <-inCh:
-			if !ok {
-				debug.Log("input channel closed, exiting")
-				return
-			}
-
-			debug.Log("check tree %v (tree %v, err %v)", job.ID, job.Tree, job.error)
+func (c *Checker) checkTreeWorker(ctx context.Context, trees <-chan restic.TreeItem, out chan<- error) {
+	for job := range trees {
+		debug.Log("check tree %v (tree %v, err %v)", job.ID, job.Tree, job.Error)
 
-			var errs []error
-			if job.error != nil {
-				errs = append(errs, job.error)
-			} else {
-				errs = c.checkTree(job.ID, job.Tree)
-			}
+		var errs []error
+		if job.Error != nil {
+			errs = append(errs, job.Error)
+		} else {
+			errs = c.checkTree(job.ID, job.Tree)
+		}
 
-			if len(errs) > 0 {
-				debug.Log("checked tree %v: %v errors", job.ID, len(errs))
-				treeError = TreeError{ID: job.ID, Errors: errs}
-				outCh = out
-				inCh = nil
-			}
-
-		case outCh <- treeError:
-			debug.Log("tree %v: sent %d errors", treeError.ID, len(treeError.Errors))
-			outCh = nil
-			inCh = in
+		if len(errs) == 0 {
+			continue
 		}
-	}
-}
-
-func (c *Checker) filterTrees(ctx context.Context, backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan treeJob, out chan<- treeJob) {
-	defer func() {
-		debug.Log("closing output channels")
-		close(loaderChan)
-		close(out)
-	}()
-
-	var (
-		inCh                    = in
-		outCh                   = out
-		loadCh                  = loaderChan
-		job                     treeJob
-		nextTreeID              restic.ID
-		outstandingLoadTreeJobs = 0
-	)
-
-	outCh = nil
-	loadCh = nil
-
-	for {
-		if loadCh == nil && len(backlog) > 0 {
-			// process last added ids first, that is traverse the tree in depth-first order
-			ln := len(backlog) - 1
-			nextTreeID, backlog = backlog[ln], backlog[:ln]
-
-			// use a separate flag for processed trees to ensure that check still processes trees
-			// even when a file references a tree blob
-			c.blobRefs.Lock()
-			h := restic.BlobHandle{ID: nextTreeID, Type: restic.TreeBlob}
-			blobReferenced := c.blobRefs.M.Has(h)
-			// noop if already referenced
-			c.blobRefs.M.Insert(h)
-			c.blobRefs.Unlock()
-			if blobReferenced {
-				continue
-			}
-
-			loadCh = loaderChan
-		}
-
-		if loadCh == nil && outCh == nil && outstandingLoadTreeJobs == 0 {
-			debug.Log("backlog is empty, all channels nil, exiting")
-			return
-		}
-
+		treeError := TreeError{ID: job.ID, Errors: errs}
 		select {
 		case <-ctx.Done():
 			return
-
-		case loadCh <- nextTreeID:
-			outstandingLoadTreeJobs++
-			loadCh = nil
-
-		case j, ok := <-inCh:
-			if !ok {
-				debug.Log("input channel closed")
-				inCh = nil
-				in = nil
-				continue
-			}
-
-			outstandingLoadTreeJobs--
-
-			debug.Log("input job tree %v", j.ID)
-
-			if j.error != nil {
-				debug.Log("received job with error: %v (tree %v, ID %v)", j.error, j.Tree, j.ID)
-			} else if j.Tree == nil {
-				debug.Log("received job with nil tree pointer: %v (ID %v)", j.error, j.ID)
-				// send a new job with the new error instead of the old one
-				j = treeJob{ID: j.ID, error: errors.New("tree is nil and error is nil")}
-			} else {
-				subtrees := j.Tree.Subtrees()
-				debug.Log("subtrees for tree %v: %v", j.ID, subtrees)
-				// iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection
-				for i := len(subtrees) - 1; i >= 0; i-- {
-					id := subtrees[i]
-					if id.IsNull() {
-						// We do not need to raise this error here, it is
-						// checked when the tree is checked. Just make sure
-						// that we do not add any null IDs to the backlog.
-						debug.Log("tree %v has nil subtree", j.ID)
-						continue
-					}
-					backlog = append(backlog, id)
-				}
-			}
-
-			job = j
-			outCh = out
-			inCh = nil
-
-		case outCh <- job:
-			debug.Log("tree sent to check: %v", job.ID)
-			outCh = nil
-			inCh = in
+		case out <- treeError:
+			debug.Log("tree %v: sent %d errors", treeError.ID, len(treeError.Errors))
 		}
 	}
 }
@@ -527,10 +354,9 @@ func loadSnapshotTreeIDs(ctx context.Context, repo restic.Repository) (ids resti
 // Structure checks that for all snapshots all referenced data blobs and
 // subtrees are available in the index. errChan is closed after all trees have
 // been traversed.
-func (c *Checker) Structure(ctx context.Context, errChan chan<- error) {
-	defer close(errChan)
-
+func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan chan<- error) {
 	trees, errs := loadSnapshotTreeIDs(ctx, c.repo)
+	p.SetMax(uint64(len(trees)))
 	debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs))
 
 	for _, err := range errs {
@@ -541,19 +367,26 @@ func (c *Checker) Structure(ctx context.Context, errChan chan<- error) {
 		}
 	}
 
-	treeIDChan := make(chan restic.ID)
-	treeJobChan1 := make(chan treeJob)
-	treeJobChan2 := make(chan treeJob)
+	wg, ctx := errgroup.WithContext(ctx)
+	treeStream := restic.StreamTrees(ctx, wg, c.repo, trees, func(treeID restic.ID) bool {
+		// blobRefs may be accessed in parallel by checkTree
+		c.blobRefs.Lock()
+		h := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}
+		blobReferenced := c.blobRefs.M.Has(h)
+		// noop if already referenced
+		c.blobRefs.M.Insert(h)
+		c.blobRefs.Unlock()
+		return blobReferenced
+	}, p)
 
-	var wg sync.WaitGroup
+	defer close(errChan)
 	for i := 0; i < defaultParallelism; i++ {
-		wg.Add(2)
-		go loadTreeWorker(ctx, c.repo, treeIDChan, treeJobChan1, &wg)
-		go c.checkTreeWorker(ctx, treeJobChan2, errChan, &wg)
+		wg.Go(func() error {
+			c.checkTreeWorker(ctx, treeStream, errChan)
+			return nil
+		})
 	}
 
-	c.filterTrees(ctx, trees, treeIDChan, treeJobChan1, treeJobChan2)
-
 	wg.Wait()
 }
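Two details worth noting in this hunk: the skip callback returns blobReferenced, so a tree blob that is already marked (for instance because a file entry references the same tree) is skipped rather than traversed twice, while still being inserted so later occurrences are recognized; and the blobRefs lock is required because the callback runs in the StreamTrees filter goroutine concurrently with the checkTreeWorker goroutines that also mark blobs.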
@@ -43,7 +43,9 @@ func checkPacks(chkr *checker.Checker) []error {
 }
 
 func checkStruct(chkr *checker.Checker) []error {
-	return collectErrors(context.TODO(), chkr.Structure)
+	return collectErrors(context.TODO(), func(ctx context.Context, errChan chan<- error) {
+		chkr.Structure(ctx, nil, errChan)
+	})
 }
 
 func checkData(chkr *checker.Checker) []error {
@@ -30,7 +30,7 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) {
 
 	// structure
 	errChan = make(chan error)
-	go chkr.Structure(context.TODO(), errChan)
+	go chkr.Structure(context.TODO(), nil, errChan)
 
 	for err := range errChan {
 		t.Error(err)
@@ -368,7 +368,7 @@ func TestIndexSave(t *testing.T) {
 	defer cancel()
 
 	errCh := make(chan error)
-	go checker.Structure(ctx, errCh)
+	go checker.Structure(ctx, nil, errCh)
 	i := 0
 	for err := range errCh {
 		t.Errorf("checker returned error: %v", err)
@@ -1,6 +1,12 @@
 package restic
 
-import "context"
+import (
+	"context"
+	"sync"
+
+	"github.com/restic/restic/internal/ui/progress"
+	"golang.org/x/sync/errgroup"
+)
 
 // TreeLoader loads a tree from a repository.
 type TreeLoader interface {
@@ -9,31 +15,39 @@ type TreeLoader interface {
 
 // FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data
 // blobs) to the set blobs. Already seen tree blobs will not be visited again.
-func FindUsedBlobs(ctx context.Context, repo TreeLoader, treeID ID, blobs BlobSet) error {
-	h := BlobHandle{ID: treeID, Type: TreeBlob}
-	if blobs.Has(h) {
-		return nil
-	}
-	blobs.Insert(h)
+func FindUsedBlobs(ctx context.Context, repo TreeLoader, treeIDs IDs, blobs BlobSet, p *progress.Counter) error {
+	var lock sync.Mutex
 
-	tree, err := repo.LoadTree(ctx, treeID)
-	if err != nil {
-		return err
-	}
+	wg, ctx := errgroup.WithContext(ctx)
+	treeStream := StreamTrees(ctx, wg, repo, treeIDs, func(treeID ID) bool {
+		// locking is necessary as the goroutine below concurrently adds data blobs
+		lock.Lock()
+		h := BlobHandle{ID: treeID, Type: TreeBlob}
+		blobReferenced := blobs.Has(h)
+		// noop if already referenced
+		blobs.Insert(h)
+		lock.Unlock()
+		return blobReferenced
+	}, p)
 
-	for _, node := range tree.Nodes {
-		switch node.Type {
-		case "file":
-			for _, blob := range node.Content {
-				blobs.Insert(BlobHandle{ID: blob, Type: DataBlob})
-			}
-		case "dir":
-			err := FindUsedBlobs(ctx, repo, *node.Subtree, blobs)
-			if err != nil {
-				return err
+	wg.Go(func() error {
+		for tree := range treeStream {
+			if tree.Error != nil {
+				return tree.Error
+			}
+
+			lock.Lock()
+			for _, node := range tree.Nodes {
+				switch node.Type {
+				case "file":
+					for _, blob := range node.Content {
+						blobs.Insert(BlobHandle{ID: blob, Type: DataBlob})
+					}
+				}
 			}
+			lock.Unlock()
 		}
-	}
-
-	return nil
+		return nil
+	})
+	return wg.Wait()
 }
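FindUsedBlobs now has two goroutines mutating the same BlobSet: the skip callback, invoked from the StreamTrees filter goroutine, marks tree blobs, while the consumer goroutine inserts data blobs; the sync.Mutex serializes those accesses. A reduced, runnable sketch of that locking scheme (a plain map stands in for the BlobSet):

package main

import (
	"fmt"
	"sync"
)

func main() {
	var (
		mu  sync.Mutex
		set = map[int]bool{}
	)
	insert := func(v int) {
		mu.Lock()
		set[v] = true
		mu.Unlock()
	}

	done := make(chan struct{})
	go func() { // stands in for the skip callback marking tree blobs
		defer close(done)
		for i := 0; i < 1000; i += 2 {
			insert(i)
		}
	}()
	for i := 1; i < 1000; i += 2 { // stands in for the consumer adding data blobs
		insert(i)
	}
	<-done

	fmt.Println("unique blobs:", len(set)) // 1000
}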
@@ -15,6 +15,8 @@ import (
 	"github.com/restic/restic/internal/errors"
 	"github.com/restic/restic/internal/repository"
 	"github.com/restic/restic/internal/restic"
+	"github.com/restic/restic/internal/test"
+	"github.com/restic/restic/internal/ui/progress"
 )
 
 func loadIDSet(t testing.TB, filename string) restic.BlobSet {
@@ -92,9 +94,12 @@ func TestFindUsedBlobs(t *testing.T) {
 		snapshots = append(snapshots, sn)
 	}
 
+	p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
+	defer p.Done()
+
 	for i, sn := range snapshots {
 		usedBlobs := restic.NewBlobSet()
-		err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, usedBlobs)
+		err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*sn.Tree}, usedBlobs, p)
 		if err != nil {
 			t.Errorf("FindUsedBlobs returned error: %v", err)
 			continue
@@ -105,6 +110,8 @@ func TestFindUsedBlobs(t *testing.T) {
 			continue
 		}
 
+		test.Equals(t, p.Get(), uint64(i+1))
+
 		goldenFilename := filepath.Join("testdata", fmt.Sprintf("used_blobs_snapshot%d", i))
 		want := loadIDSet(t, goldenFilename)
 
@@ -119,6 +126,40 @@ func TestFindUsedBlobs(t *testing.T) {
 	}
 }
 
+func TestMultiFindUsedBlobs(t *testing.T) {
+	repo, cleanup := repository.TestRepository(t)
+	defer cleanup()
+
+	var snapshotTrees restic.IDs
+	for i := 0; i < findTestSnapshots; i++ {
+		sn := restic.TestCreateSnapshot(t, repo, findTestTime.Add(time.Duration(i)*time.Second), findTestDepth, 0)
+		t.Logf("snapshot %v saved, tree %v", sn.ID().Str(), sn.Tree.Str())
+		snapshotTrees = append(snapshotTrees, *sn.Tree)
+	}
+
+	want := restic.NewBlobSet()
+	for i := range snapshotTrees {
+		goldenFilename := filepath.Join("testdata", fmt.Sprintf("used_blobs_snapshot%d", i))
+		want.Merge(loadIDSet(t, goldenFilename))
+	}
+
+	p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
+	defer p.Done()
+
+	// run twice to check progress bar handling of duplicate tree roots
+	usedBlobs := restic.NewBlobSet()
+	for i := 1; i < 3; i++ {
+		err := restic.FindUsedBlobs(context.TODO(), repo, snapshotTrees, usedBlobs, p)
+		test.OK(t, err)
+		test.Equals(t, p.Get(), uint64(i*len(snapshotTrees)))
+
+		if !want.Equals(usedBlobs) {
+			t.Errorf("wrong list of blobs returned:\n  missing blobs: %v\n  extra blobs: %v",
+				want.Sub(usedBlobs), usedBlobs.Sub(want))
+		}
+	}
+}
+
 type ForbiddenRepo struct{}
 
 func (r ForbiddenRepo) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree, error) {

@@ -133,12 +174,12 @@ func TestFindUsedBlobsSkipsSeenBlobs(t *testing.T) {
 	t.Logf("snapshot %v saved, tree %v", snapshot.ID().Str(), snapshot.Tree.Str())
 
 	usedBlobs := restic.NewBlobSet()
-	err := restic.FindUsedBlobs(context.TODO(), repo, *snapshot.Tree, usedBlobs)
+	err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*snapshot.Tree}, usedBlobs, nil)
 	if err != nil {
 		t.Fatalf("FindUsedBlobs returned error: %v", err)
 	}
 
-	err = restic.FindUsedBlobs(context.TODO(), ForbiddenRepo{}, *snapshot.Tree, usedBlobs)
+	err = restic.FindUsedBlobs(context.TODO(), ForbiddenRepo{}, restic.IDs{*snapshot.Tree}, usedBlobs, nil)
 	if err != nil {
 		t.Fatalf("FindUsedBlobs returned error: %v", err)
 	}

@@ -154,7 +195,7 @@ func BenchmarkFindUsedBlobs(b *testing.B) {
 
 	for i := 0; i < b.N; i++ {
 		blobs := restic.NewBlobSet()
-		err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, blobs)
+		err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*sn.Tree}, blobs, nil)
 		if err != nil {
 			b.Error(err)
 		}

internal/restic/tree_stream.go (new file, 183 lines)
@@ -0,0 +1,183 @@
+package restic
+
+import (
+	"context"
+	"errors"
+	"sync"
+
+	"github.com/restic/restic/internal/debug"
+	"github.com/restic/restic/internal/ui/progress"
+	"golang.org/x/sync/errgroup"
+)
+
+const streamTreeParallelism = 5
+
+// TreeItem is used to return either an error or the tree for a tree id
+type TreeItem struct {
+	ID
+	Error error
+	*Tree
+}
+
+type trackedTreeItem struct {
+	TreeItem
+	rootIdx int
+}
+
+type trackedID struct {
+	ID
+	rootIdx int
+}
+
+// loadTreeWorker loads trees from repo and sends them to out.
+func loadTreeWorker(ctx context.Context, repo TreeLoader,
+	in <-chan trackedID, out chan<- trackedTreeItem) {
+
+	for treeID := range in {
+		tree, err := repo.LoadTree(ctx, treeID.ID)
+		debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err)
+		job := trackedTreeItem{TreeItem: TreeItem{ID: treeID.ID, Error: err, Tree: tree}, rootIdx: treeID.rootIdx}
+
+		select {
+		case <-ctx.Done():
+			return
+		case out <- job:
+		}
+	}
+}
+
+func filterTrees(ctx context.Context, trees IDs, loaderChan chan<- trackedID,
+	in <-chan trackedTreeItem, out chan<- TreeItem, skip func(tree ID) bool, p *progress.Counter) {
+
+	var (
+		inCh                    = in
+		outCh                   chan<- TreeItem
+		loadCh                  chan<- trackedID
+		job                     TreeItem
+		nextTreeID              trackedID
+		outstandingLoadTreeJobs = 0
+	)
+	rootCounter := make([]int, len(trees))
+	backlog := make([]trackedID, 0, len(trees))
+	for idx, id := range trees {
+		backlog = append(backlog, trackedID{ID: id, rootIdx: idx})
+		rootCounter[idx] = 1
+	}
+
+	for {
+		if loadCh == nil && len(backlog) > 0 {
+			// process last added ids first, that is traverse the tree in depth-first order
+			ln := len(backlog) - 1
+			nextTreeID, backlog = backlog[ln], backlog[:ln]
+
+			if skip(nextTreeID.ID) {
+				rootCounter[nextTreeID.rootIdx]--
+				if p != nil && rootCounter[nextTreeID.rootIdx] == 0 {
+					p.Add(1)
+				}
+				continue
+			}
+
+			loadCh = loaderChan
+		}
+
+		if loadCh == nil && outCh == nil && outstandingLoadTreeJobs == 0 {
+			debug.Log("backlog is empty, all channels nil, exiting")
+			return
+		}
+
+		select {
+		case <-ctx.Done():
+			return
+
+		case loadCh <- nextTreeID:
+			outstandingLoadTreeJobs++
+			loadCh = nil
+
+		case j, ok := <-inCh:
+			if !ok {
+				debug.Log("input channel closed")
+				inCh = nil
+				in = nil
+				continue
+			}
+
+			outstandingLoadTreeJobs--
+			rootCounter[j.rootIdx]--
+
+			debug.Log("input job tree %v", j.ID)
+
+			if j.Error != nil {
+				debug.Log("received job with error: %v (tree %v, ID %v)", j.Error, j.Tree, j.ID)
+			} else if j.Tree == nil {
+				debug.Log("received job with nil tree pointer: %v (ID %v)", j.Error, j.ID)
+				// send a new job with the new error instead of the old one
+				j = trackedTreeItem{TreeItem: TreeItem{ID: j.ID, Error: errors.New("tree is nil and error is nil")}, rootIdx: j.rootIdx}
+			} else {
+				subtrees := j.Tree.Subtrees()
+				debug.Log("subtrees for tree %v: %v", j.ID, subtrees)
+				// iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection
+				for i := len(subtrees) - 1; i >= 0; i-- {
+					id := subtrees[i]
+					if id.IsNull() {
+						// We do not need to raise this error here, it is
+						// checked when the tree is checked. Just make sure
+						// that we do not add any null IDs to the backlog.
+						debug.Log("tree %v has nil subtree", j.ID)
+						continue
+					}
+					backlog = append(backlog, trackedID{ID: id, rootIdx: j.rootIdx})
+					rootCounter[j.rootIdx]++
+				}
+			}
+			if p != nil && rootCounter[j.rootIdx] == 0 {
+				p.Add(1)
+			}
+
+			job = j.TreeItem
+			outCh = out
+			inCh = nil
+
+		case outCh <- job:
+			debug.Log("tree sent to process: %v", job.ID)
+			outCh = nil
+			inCh = in
+		}
+	}
+}
+
+// StreamTrees iteratively loads the given trees and their subtrees. The skip method
+// is guaranteed to always be called from the same goroutine. To shut down the started
+// goroutines, either read all items from the channel or cancel the context. Then `Wait()`
+// on the errgroup until all goroutines were stopped.
+func StreamTrees(ctx context.Context, wg *errgroup.Group, repo TreeLoader, trees IDs, skip func(tree ID) bool, p *progress.Counter) <-chan TreeItem {
+	loaderChan := make(chan trackedID)
+	loadedTreeChan := make(chan trackedTreeItem)
+	treeStream := make(chan TreeItem)
+
+	var loadTreeWg sync.WaitGroup
+
+	for i := 0; i < streamTreeParallelism; i++ {
+		loadTreeWg.Add(1)
+		wg.Go(func() error {
+			defer loadTreeWg.Done()
+			loadTreeWorker(ctx, repo, loaderChan, loadedTreeChan)
+			return nil
+		})
+	}
+
+	// close once all loadTreeWorkers have completed
+	wg.Go(func() error {
+		loadTreeWg.Wait()
+		close(loadedTreeChan)
+		return nil
+	})
+
+	wg.Go(func() error {
+		defer close(loaderChan)
+		defer close(treeStream)
+		filterTrees(ctx, trees, loaderChan, loadedTreeChan, treeStream, skip, p)
+		return nil
+	})
+	return treeStream
+}
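For orientation, this is how a caller is meant to use the new helper, following the doc comment on StreamTrees: hand it an errgroup, drain the returned channel, then wait on the group. The sketch compiles only inside the restic module (the package is internal), and walkAll is a hypothetical name, not part of the PR:

package example

import (
	"context"

	"github.com/restic/restic/internal/restic"
	"golang.org/x/sync/errgroup"
)

// walkAll visits every tree reachable from the given roots (shared
// subtrees repeatedly, since this skip callback never deduplicates).
func walkAll(ctx context.Context, repo restic.TreeLoader, roots restic.IDs) error {
	wg, wgCtx := errgroup.WithContext(ctx)
	treeStream := restic.StreamTrees(wgCtx, wg, repo, roots, func(id restic.ID) bool {
		return false // never skip; a real caller deduplicates here
	}, nil) // nil: no progress reporting

	wg.Go(func() error {
		for tree := range treeStream {
			if tree.Error != nil {
				return tree.Error
			}
			_ = tree.Tree // process the loaded tree
		}
		return nil
	})
	return wg.Wait()
}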
@@ -12,7 +12,7 @@ import (
 //
 // The final argument is true if Counter.Done has been called,
 // which means that the current call will be the last.
-type Func func(value uint64, runtime time.Duration, final bool)
+type Func func(value uint64, total uint64, runtime time.Duration, final bool)
 
 // A Counter tracks a running count and controls a goroutine that passes its
 // value periodically to a Func.
@@ -27,16 +27,19 @@ type Counter struct {
 
 	valueMutex sync.Mutex
 	value      uint64
+	max        uint64
 }
 
 // New starts a new Counter.
-func New(interval time.Duration, report Func) *Counter {
+func New(interval time.Duration, total uint64, report Func) *Counter {
 	c := &Counter{
 		report:  report,
 		start:   time.Now(),
 		stopped: make(chan struct{}),
 		stop:    make(chan struct{}),
+		max:     total,
 	}
 
 	if interval > 0 {
 		c.tick = time.NewTicker(interval)
 	}
@@ -56,6 +59,16 @@ func (c *Counter) Add(v uint64) {
 	c.valueMutex.Unlock()
 }
 
+// SetMax sets the maximum expected counter value. This method is concurrency-safe.
+func (c *Counter) SetMax(max uint64) {
+	if c == nil {
+		return
+	}
+	c.valueMutex.Lock()
+	c.max = max
+	c.valueMutex.Unlock()
+}
+
 // Done tells a Counter to stop and waits for it to report its final value.
 func (c *Counter) Done() {
 	if c == nil {
@@ -69,7 +82,8 @@ func (c *Counter) Done() {
 	*c = Counter{} // Prevent reuse.
 }
 
-func (c *Counter) get() uint64 {
+// Get the current Counter value. This method is concurrency-safe.
+func (c *Counter) Get() uint64 {
 	c.valueMutex.Lock()
 	v := c.value
 	c.valueMutex.Unlock()
@@ -77,11 +91,19 @@ func (c *Counter) get() uint64 {
 	return v
 }
 
+func (c *Counter) getMax() uint64 {
+	c.valueMutex.Lock()
+	max := c.max
+	c.valueMutex.Unlock()
+
+	return max
+}
+
 func (c *Counter) run() {
 	defer close(c.stopped)
 	defer func() {
 		// Must be a func so that time.Since isn't called at defer time.
-		c.report(c.get(), time.Since(c.start), true)
+		c.report(c.Get(), c.getMax(), time.Since(c.start), true)
 	}()
 
 	var tick <-chan time.Time

@@ -101,6 +123,6 @@ func (c *Counter) run() {
 			return
 		}
 
-		c.report(c.get(), now.Sub(c.start), false)
+		c.report(c.Get(), c.getMax(), now.Sub(c.start), false)
 	}
 }
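Taken together, the Counter changes mean a report callback now receives the expected total next to the current value, the total may be given up front to New or supplied later via SetMax, and the exported Get replaces the private get. A small usage sketch against this updated API (countItems is a hypothetical helper, and the code only compiles inside the restic module):

package example

import (
	"fmt"
	"time"

	"github.com/restic/restic/internal/ui/progress"
)

func countItems(items []string) {
	report := func(value uint64, total uint64, d time.Duration, final bool) {
		fmt.Printf("[%v] %d / %d (final=%v)\n", d, value, total, final)
	}
	// total starts at 0 (unknown), mirroring the check command above
	c := progress.New(time.Second, 0, report)
	defer c.Done() // triggers the final report

	c.SetMax(uint64(len(items))) // the real total, learned later
	for range items {
		c.Add(1)
	}
}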
@@ -10,23 +10,32 @@ import (
 
 func TestCounter(t *testing.T) {
 	const N = 100
+	const startTotal = uint64(12345)
 
 	var (
 		finalSeen  = false
 		increasing = true
 		last       uint64
+		lastTotal  = startTotal
 		ncalls     int
+		nmaxChange int
 	)
 
-	report := func(value uint64, d time.Duration, final bool) {
+	report := func(value uint64, total uint64, d time.Duration, final bool) {
+		if final {
 			finalSeen = true
+		}
 		if value < last {
 			increasing = false
 		}
 		last = value
+		if total != lastTotal {
+			nmaxChange++
+		}
+		lastTotal = total
 		ncalls++
 	}
-	c := progress.New(10*time.Millisecond, report)
+	c := progress.New(10*time.Millisecond, startTotal, report)
 
 	done := make(chan struct{})
 	go func() {
@@ -35,6 +44,7 @@ func TestCounter(t *testing.T) {
 			time.Sleep(time.Millisecond)
 			c.Add(1)
 		}
+		c.SetMax(42)
 	}()
 
 	<-done
@@ -43,6 +53,8 @@ func TestCounter(t *testing.T) {
 	test.Assert(t, finalSeen, "final call did not happen")
 	test.Assert(t, increasing, "values not increasing")
 	test.Equals(t, uint64(N), last)
+	test.Equals(t, uint64(42), lastTotal)
+	test.Equals(t, int(1), nmaxChange)
 
 	t.Log("number of calls:", ncalls)
 }
@@ -58,14 +70,14 @@ func TestCounterNoTick(t *testing.T) {
 	finalSeen := false
 	otherSeen := false
 
-	report := func(value uint64, d time.Duration, final bool) {
+	report := func(value, total uint64, d time.Duration, final bool) {
 		if final {
 			finalSeen = true
 		} else {
 			otherSeen = true
 		}
 	}
-	c := progress.New(0, report)
+	c := progress.New(0, 1, report)
 	time.Sleep(time.Millisecond)
 	c.Done()
 