diff --git a/cmd/restic/cmd_copy.go b/cmd/restic/cmd_copy.go index b11c5e1e8..cb8296d4b 100644 --- a/cmd/restic/cmd_copy.go +++ b/cmd/restic/cmd_copy.go @@ -183,7 +183,7 @@ func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Rep visited := visitedTrees.Has(treeID) visitedTrees.Insert(treeID) return visited - }) + }, nil) wg.Go(func() error { // reused buffer diff --git a/internal/checker/checker.go b/internal/checker/checker.go index 4a1c5cb4b..e41c1f1b5 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -376,7 +376,7 @@ func (c *Checker) Structure(ctx context.Context, errChan chan<- error) { c.blobRefs.M.Insert(h) c.blobRefs.Unlock() return blobReferenced - }) + }, nil) defer close(errChan) for i := 0; i < defaultParallelism; i++ { diff --git a/internal/restic/find.go b/internal/restic/find.go index b797cac6b..4c72766c6 100644 --- a/internal/restic/find.go +++ b/internal/restic/find.go @@ -27,7 +27,7 @@ func FindUsedBlobs(ctx context.Context, repo TreeLoader, treeID ID, blobs BlobSe blobs.Insert(h) lock.Unlock() return blobReferenced - }) + }, nil) wg.Go(func() error { for tree := range treeStream { diff --git a/internal/restic/tree_stream.go b/internal/restic/tree_stream.go index b71f4aa18..0c2a96810 100644 --- a/internal/restic/tree_stream.go +++ b/internal/restic/tree_stream.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/restic/restic/internal/debug" + "github.com/restic/restic/internal/ui/progress" "golang.org/x/sync/errgroup" ) @@ -18,14 +19,24 @@ type TreeItem struct { *Tree } +type trackedTreeItem struct { + TreeItem + rootIdx int +} + +type trackedID struct { + ID + rootIdx int +} + // loadTreeWorker loads trees from repo and sends them to out. func loadTreeWorker(ctx context.Context, repo TreeLoader, - in <-chan ID, out chan<- TreeItem) { + in <-chan trackedID, out chan<- trackedTreeItem) { for treeID := range in { - tree, err := repo.LoadTree(ctx, treeID) + tree, err := repo.LoadTree(ctx, treeID.ID) debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err) - job := TreeItem{ID: treeID, Error: err, Tree: tree} + job := trackedTreeItem{TreeItem: TreeItem{ID: treeID.ID, Error: err, Tree: tree}, rootIdx: treeID.rootIdx} select { case <-ctx.Done(): @@ -35,17 +46,23 @@ func loadTreeWorker(ctx context.Context, repo TreeLoader, } } -func filterTrees(ctx context.Context, backlog IDs, loaderChan chan<- ID, - in <-chan TreeItem, out chan<- TreeItem, skip func(tree ID) bool) { +func filterTrees(ctx context.Context, trees IDs, loaderChan chan<- trackedID, + in <-chan trackedTreeItem, out chan<- TreeItem, skip func(tree ID) bool, p *progress.Counter) { var ( inCh = in outCh chan<- TreeItem - loadCh chan<- ID + loadCh chan<- trackedID job TreeItem - nextTreeID ID + nextTreeID trackedID outstandingLoadTreeJobs = 0 ) + rootCounter := make([]int, len(trees)) + backlog := make([]trackedID, 0, len(trees)) + for idx, id := range trees { + backlog = append(backlog, trackedID{ID: id, rootIdx: idx}) + rootCounter[idx] = 1 + } for { if loadCh == nil && len(backlog) > 0 { @@ -53,7 +70,11 @@ func filterTrees(ctx context.Context, backlog IDs, loaderChan chan<- ID, ln := len(backlog) - 1 nextTreeID, backlog = backlog[ln], backlog[:ln] - if skip(nextTreeID) { + if skip(nextTreeID.ID) { + rootCounter[nextTreeID.rootIdx]-- + if p != nil && rootCounter[nextTreeID.rootIdx] == 0 { + p.Add(1) + } continue } @@ -82,6 +103,7 @@ func filterTrees(ctx context.Context, backlog IDs, loaderChan chan<- ID, } outstandingLoadTreeJobs-- + rootCounter[j.rootIdx]-- debug.Log("input job tree %v", j.ID) @@ -90,7 +112,7 @@ func filterTrees(ctx context.Context, backlog IDs, loaderChan chan<- ID, } else if j.Tree == nil { debug.Log("received job with nil tree pointer: %v (ID %v)", j.Error, j.ID) // send a new job with the new error instead of the old one - j = TreeItem{ID: j.ID, Error: errors.New("tree is nil and error is nil")} + j = trackedTreeItem{TreeItem: TreeItem{ID: j.ID, Error: errors.New("tree is nil and error is nil")}, rootIdx: j.rootIdx} } else { subtrees := j.Tree.Subtrees() debug.Log("subtrees for tree %v: %v", j.ID, subtrees) @@ -104,11 +126,15 @@ func filterTrees(ctx context.Context, backlog IDs, loaderChan chan<- ID, debug.Log("tree %v has nil subtree", j.ID) continue } - backlog = append(backlog, id) + backlog = append(backlog, trackedID{ID: id, rootIdx: j.rootIdx}) + rootCounter[j.rootIdx]++ } } + if p != nil && rootCounter[j.rootIdx] == 0 { + p.Add(1) + } - job = j + job = j.TreeItem outCh = out inCh = nil @@ -122,9 +148,9 @@ func filterTrees(ctx context.Context, backlog IDs, loaderChan chan<- ID, // StreamTrees iteratively loads the given trees and their subtrees. The skip method // is guaranteed to always be called from the same goroutine. -func StreamTrees(ctx context.Context, wg *errgroup.Group, repo TreeLoader, trees IDs, skip func(tree ID) bool) <-chan TreeItem { - loaderChan := make(chan ID) - loadedTreeChan := make(chan TreeItem) +func StreamTrees(ctx context.Context, wg *errgroup.Group, repo TreeLoader, trees IDs, skip func(tree ID) bool, p *progress.Counter) <-chan TreeItem { + loaderChan := make(chan trackedID) + loadedTreeChan := make(chan trackedTreeItem) treeStream := make(chan TreeItem) var loadTreeWg sync.WaitGroup @@ -148,7 +174,7 @@ func StreamTrees(ctx context.Context, wg *errgroup.Group, repo TreeLoader, trees wg.Go(func() error { defer close(loaderChan) defer close(treeStream) - filterTrees(ctx, trees, loaderChan, loadedTreeChan, treeStream, skip) + filterTrees(ctx, trees, loaderChan, loadedTreeChan, treeStream, skip, p) return nil }) return treeStream