2
2
mirror of https://github.com/octoleo/restic.git synced 2024-11-10 15:21:03 +00:00

archiver: remove tomb usage

This commit is contained in:
Michael Eischer 2022-05-27 19:08:50 +02:00
parent 0cb6b3d80a
commit 408ac1a0c2
8 changed files with 123 additions and 102 deletions

View File

@ -13,7 +13,7 @@ import (
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/fs"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
tomb "gopkg.in/tomb.v2" "golang.org/x/sync/errgroup"
) )
// SelectByNameFunc returns true for all items that should be included (files and // SelectByNameFunc returns true for all items that should be included (files and
@ -762,17 +762,23 @@ func (arch *Archiver) loadParentTree(ctx context.Context, snapshotID restic.ID)
} }
// runWorkers starts the worker pools, which are stopped when the context is cancelled. // runWorkers starts the worker pools, which are stopped when the context is cancelled.
func (arch *Archiver) runWorkers(ctx context.Context, t *tomb.Tomb) { func (arch *Archiver) runWorkers(ctx context.Context, wg *errgroup.Group) {
arch.blobSaver = NewBlobSaver(ctx, t, arch.Repo, arch.Options.SaveBlobConcurrency) arch.blobSaver = NewBlobSaver(ctx, wg, arch.Repo, arch.Options.SaveBlobConcurrency)
arch.fileSaver = NewFileSaver(ctx, t, arch.fileSaver = NewFileSaver(ctx, wg,
arch.blobSaver.Save, arch.blobSaver.Save,
arch.Repo.Config().ChunkerPolynomial, arch.Repo.Config().ChunkerPolynomial,
arch.Options.FileReadConcurrency, arch.Options.SaveBlobConcurrency) arch.Options.FileReadConcurrency, arch.Options.SaveBlobConcurrency)
arch.fileSaver.CompleteBlob = arch.CompleteBlob arch.fileSaver.CompleteBlob = arch.CompleteBlob
arch.fileSaver.NodeFromFileInfo = arch.nodeFromFileInfo arch.fileSaver.NodeFromFileInfo = arch.nodeFromFileInfo
arch.treeSaver = NewTreeSaver(ctx, t, arch.Options.SaveTreeConcurrency, arch.saveTree, arch.Error) arch.treeSaver = NewTreeSaver(ctx, wg, arch.Options.SaveTreeConcurrency, arch.saveTree, arch.Error)
}
func (arch *Archiver) stopWorkers() {
arch.blobSaver.TriggerShutdown()
arch.fileSaver.TriggerShutdown()
arch.treeSaver.TriggerShutdown()
} }
// Snapshot saves several targets and returns a snapshot. // Snapshot saves several targets and returns a snapshot.
@ -787,17 +793,16 @@ func (arch *Archiver) Snapshot(ctx context.Context, targets []string, opts Snaps
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
var t tomb.Tomb wg, wgCtx := errgroup.WithContext(ctx)
wctx := t.Context(ctx)
start := time.Now() start := time.Now()
var rootTreeID restic.ID var rootTreeID restic.ID
var stats ItemStats var stats ItemStats
t.Go(func() error { wg.Go(func() error {
arch.runWorkers(wctx, &t) arch.runWorkers(wgCtx, wg)
debug.Log("starting snapshot") debug.Log("starting snapshot")
tree, err := arch.SaveTree(wctx, "/", atree, arch.loadParentTree(wctx, opts.ParentSnapshot)) tree, err := arch.SaveTree(wgCtx, "/", atree, arch.loadParentTree(wgCtx, opts.ParentSnapshot))
if err != nil { if err != nil {
return err return err
} }
@ -806,13 +811,12 @@ func (arch *Archiver) Snapshot(ctx context.Context, targets []string, opts Snaps
return errors.New("snapshot is empty") return errors.New("snapshot is empty")
} }
rootTreeID, stats, err = arch.saveTree(wctx, tree) rootTreeID, stats, err = arch.saveTree(wgCtx, tree)
// trigger shutdown but don't set an error arch.stopWorkers()
t.Kill(nil)
return err return err
}) })
err = t.Wait() err = wg.Wait()
debug.Log("err is %v", err) debug.Log("err is %v", err)
if err != nil { if err != nil {

View File

@ -23,7 +23,7 @@ import (
"github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/repository"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
restictest "github.com/restic/restic/internal/test" restictest "github.com/restic/restic/internal/test"
tomb "gopkg.in/tomb.v2" "golang.org/x/sync/errgroup"
) )
func prepareTempdirRepoSrc(t testing.TB, src TestDir) (tempdir string, repo restic.Repository, cleanup func()) { func prepareTempdirRepoSrc(t testing.TB, src TestDir) (tempdir string, repo restic.Repository, cleanup func()) {
@ -41,11 +41,10 @@ func prepareTempdirRepoSrc(t testing.TB, src TestDir) (tempdir string, repo rest
} }
func saveFile(t testing.TB, repo restic.Repository, filename string, filesystem fs.FS) (*restic.Node, ItemStats) { func saveFile(t testing.TB, repo restic.Repository, filename string, filesystem fs.FS) (*restic.Node, ItemStats) {
var tmb tomb.Tomb wg, ctx := errgroup.WithContext(context.TODO())
ctx := tmb.Context(context.Background())
arch := New(repo, filesystem, Options{}) arch := New(repo, filesystem, Options{})
arch.runWorkers(ctx, &tmb) arch.runWorkers(ctx, wg)
arch.Error = func(item string, fi os.FileInfo, err error) error { arch.Error = func(item string, fi os.FileInfo, err error) error {
t.Errorf("archiver error for %v: %v", item, err) t.Errorf("archiver error for %v: %v", item, err)
@ -87,14 +86,13 @@ func saveFile(t testing.TB, repo restic.Repository, filename string, filesystem
t.Fatal(res.Err()) t.Fatal(res.Err())
} }
tmb.Kill(nil) arch.stopWorkers()
err = tmb.Wait() err = repo.Flush(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = repo.Flush(context.Background()) if err := wg.Wait(); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -214,14 +212,14 @@ func TestArchiverSave(t *testing.T) {
tempdir, repo, cleanup := prepareTempdirRepoSrc(t, TestDir{"file": testfile}) tempdir, repo, cleanup := prepareTempdirRepoSrc(t, TestDir{"file": testfile})
defer cleanup() defer cleanup()
var tmb tomb.Tomb wg, ctx := errgroup.WithContext(ctx)
arch := New(repo, fs.Track{FS: fs.Local{}}, Options{}) arch := New(repo, fs.Track{FS: fs.Local{}}, Options{})
arch.Error = func(item string, fi os.FileInfo, err error) error { arch.Error = func(item string, fi os.FileInfo, err error) error {
t.Errorf("archiver error for %v: %v", item, err) t.Errorf("archiver error for %v: %v", item, err)
return err return err
} }
arch.runWorkers(tmb.Context(ctx), &tmb) arch.runWorkers(ctx, wg)
node, excluded, err := arch.Save(ctx, "/", filepath.Join(tempdir, "file"), nil) node, excluded, err := arch.Save(ctx, "/", filepath.Join(tempdir, "file"), nil)
if err != nil { if err != nil {
@ -243,6 +241,7 @@ func TestArchiverSave(t *testing.T) {
stats := node.stats stats := node.stats
arch.stopWorkers()
err = repo.Flush(ctx) err = repo.Flush(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -281,6 +280,8 @@ func TestArchiverSaveReaderFS(t *testing.T) {
repo, cleanup := repository.TestRepository(t) repo, cleanup := repository.TestRepository(t)
defer cleanup() defer cleanup()
wg, ctx := errgroup.WithContext(ctx)
ts := time.Now() ts := time.Now()
filename := "xx" filename := "xx"
readerFs := &fs.Reader{ readerFs := &fs.Reader{
@ -290,14 +291,12 @@ func TestArchiverSaveReaderFS(t *testing.T) {
ReadCloser: ioutil.NopCloser(strings.NewReader(test.Data)), ReadCloser: ioutil.NopCloser(strings.NewReader(test.Data)),
} }
var tmb tomb.Tomb
arch := New(repo, readerFs, Options{}) arch := New(repo, readerFs, Options{})
arch.Error = func(item string, fi os.FileInfo, err error) error { arch.Error = func(item string, fi os.FileInfo, err error) error {
t.Errorf("archiver error for %v: %v", item, err) t.Errorf("archiver error for %v: %v", item, err)
return err return err
} }
arch.runWorkers(tmb.Context(ctx), &tmb) arch.runWorkers(ctx, wg)
node, excluded, err := arch.Save(ctx, "/", filename, nil) node, excluded, err := arch.Save(ctx, "/", filename, nil)
t.Logf("Save returned %v %v", node, err) t.Logf("Save returned %v %v", node, err)
@ -320,6 +319,7 @@ func TestArchiverSaveReaderFS(t *testing.T) {
stats := node.stats stats := node.stats
arch.stopWorkers()
err = repo.Flush(ctx) err = repo.Flush(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -826,14 +826,13 @@ func TestArchiverSaveDir(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
var tmb tomb.Tomb
ctx := tmb.Context(context.Background())
tempdir, repo, cleanup := prepareTempdirRepoSrc(t, test.src) tempdir, repo, cleanup := prepareTempdirRepoSrc(t, test.src)
defer cleanup() defer cleanup()
wg, ctx := errgroup.WithContext(context.Background())
arch := New(repo, fs.Track{FS: fs.Local{}}, Options{}) arch := New(repo, fs.Track{FS: fs.Local{}}, Options{})
arch.runWorkers(ctx, &tmb) arch.runWorkers(ctx, wg)
chdir := tempdir chdir := tempdir
if test.chdir != "" { if test.chdir != "" {
@ -856,12 +855,6 @@ func TestArchiverSaveDir(t *testing.T) {
ft.Wait(ctx) ft.Wait(ctx)
node, stats := ft.Node(), ft.Stats() node, stats := ft.Node(), ft.Stats()
tmb.Kill(nil)
err = tmb.Wait()
if err != nil {
t.Fatal(err)
}
t.Logf("stats: %v", stats) t.Logf("stats: %v", stats)
if stats.DataSize != 0 { if stats.DataSize != 0 {
t.Errorf("wrong stats returned in DataSize, want 0, got %d", stats.DataSize) t.Errorf("wrong stats returned in DataSize, want 0, got %d", stats.DataSize)
@ -876,24 +869,29 @@ func TestArchiverSaveDir(t *testing.T) {
t.Errorf("wrong stats returned in TreeBlobs, want > 0, got %d", stats.TreeBlobs) t.Errorf("wrong stats returned in TreeBlobs, want > 0, got %d", stats.TreeBlobs)
} }
ctx = context.Background()
node.Name = targetNodeName node.Name = targetNodeName
tree := &restic.Tree{Nodes: []*restic.Node{node}} tree := &restic.Tree{Nodes: []*restic.Node{node}}
treeID, err := repo.SaveTree(ctx, tree) treeID, err := repo.SaveTree(ctx, tree)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
arch.stopWorkers()
err = repo.Flush(ctx) err = repo.Flush(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = wg.Wait()
if err != nil {
t.Fatal(err)
}
want := test.want want := test.want
if want == nil { if want == nil {
want = test.src want = test.src
} }
TestEnsureTree(ctx, t, "/", repo, treeID, want) TestEnsureTree(context.TODO(), t, "/", repo, treeID, want)
}) })
} }
} }
@ -915,11 +913,10 @@ func TestArchiverSaveDirIncremental(t *testing.T) {
// save the empty directory several times in a row, then have a look if the // save the empty directory several times in a row, then have a look if the
// archiver did save the same tree several times // archiver did save the same tree several times
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
var tmb tomb.Tomb wg, ctx := errgroup.WithContext(context.TODO())
ctx := tmb.Context(context.Background())
arch := New(repo, fs.Track{FS: fs.Local{}}, Options{}) arch := New(repo, fs.Track{FS: fs.Local{}}, Options{})
arch.runWorkers(ctx, &tmb) arch.runWorkers(ctx, wg)
fi, err := fs.Lstat(tempdir) fi, err := fs.Lstat(tempdir)
if err != nil { if err != nil {
@ -934,8 +931,6 @@ func TestArchiverSaveDirIncremental(t *testing.T) {
ft.Wait(ctx) ft.Wait(ctx)
node, stats := ft.Node(), ft.Stats() node, stats := ft.Node(), ft.Stats()
tmb.Kill(nil)
err = tmb.Wait()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -972,7 +967,12 @@ func TestArchiverSaveDirIncremental(t *testing.T) {
t.Logf("node subtree %v", node.Subtree) t.Logf("node subtree %v", node.Subtree)
err = repo.Flush(context.Background()) arch.stopWorkers()
err = repo.Flush(ctx)
if err != nil {
t.Fatal(err)
}
err = wg.Wait()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1081,9 +1081,6 @@ func TestArchiverSaveTree(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
var tmb tomb.Tomb
ctx := tmb.Context(context.Background())
tempdir, repo, cleanup := prepareTempdirRepoSrc(t, test.src) tempdir, repo, cleanup := prepareTempdirRepoSrc(t, test.src)
defer cleanup() defer cleanup()
@ -1099,7 +1096,9 @@ func TestArchiverSaveTree(t *testing.T) {
stat.Add(s) stat.Add(s)
} }
arch.runWorkers(ctx, &tmb) wg, ctx := errgroup.WithContext(context.TODO())
arch.runWorkers(ctx, wg)
back := restictest.Chdir(t, tempdir) back := restictest.Chdir(t, tempdir)
defer back() defer back()
@ -1123,14 +1122,12 @@ func TestArchiverSaveTree(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
tmb.Kill(nil) arch.stopWorkers()
err = tmb.Wait() err = repo.Flush(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = wg.Wait()
ctx = context.Background()
err = repo.Flush(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1139,7 +1136,7 @@ func TestArchiverSaveTree(t *testing.T) {
if want == nil { if want == nil {
want = test.src want = test.src
} }
TestEnsureTree(ctx, t, "/", repo, treeID, want) TestEnsureTree(context.TODO(), t, "/", repo, treeID, want)
bothZeroOrNeither(t, uint64(test.stat.DataBlobs), uint64(stat.DataBlobs)) bothZeroOrNeither(t, uint64(test.stat.DataBlobs), uint64(stat.DataBlobs))
bothZeroOrNeither(t, uint64(test.stat.TreeBlobs), uint64(stat.TreeBlobs)) bothZeroOrNeither(t, uint64(test.stat.TreeBlobs), uint64(stat.TreeBlobs))
bothZeroOrNeither(t, test.stat.DataSize, stat.DataSize) bothZeroOrNeither(t, test.stat.DataSize, stat.DataSize)
@ -2240,14 +2237,14 @@ func TestRacyFileSwap(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
var tmb tomb.Tomb wg, ctx := errgroup.WithContext(ctx)
arch := New(repo, fs.Track{FS: statfs}, Options{}) arch := New(repo, fs.Track{FS: statfs}, Options{})
arch.Error = func(item string, fi os.FileInfo, err error) error { arch.Error = func(item string, fi os.FileInfo, err error) error {
t.Logf("archiver error as expected for %v: %v", item, err) t.Logf("archiver error as expected for %v: %v", item, err)
return err return err
} }
arch.runWorkers(tmb.Context(ctx), &tmb) arch.runWorkers(ctx, wg)
// fs.Track will panic if the file was not closed // fs.Track will panic if the file was not closed
_, excluded, err := arch.Save(ctx, "/", tempfile, nil) _, excluded, err := arch.Save(ctx, "/", tempfile, nil)

View File

@ -5,7 +5,7 @@ import (
"github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
tomb "gopkg.in/tomb.v2" "golang.org/x/sync/errgroup"
) )
// Saver allows saving a blob. // Saver allows saving a blob.
@ -22,7 +22,7 @@ type BlobSaver struct {
// NewBlobSaver returns a new blob. A worker pool is started, it is stopped // NewBlobSaver returns a new blob. A worker pool is started, it is stopped
// when ctx is cancelled. // when ctx is cancelled.
func NewBlobSaver(ctx context.Context, t *tomb.Tomb, repo Saver, workers uint) *BlobSaver { func NewBlobSaver(ctx context.Context, wg *errgroup.Group, repo Saver, workers uint) *BlobSaver {
ch := make(chan saveBlobJob) ch := make(chan saveBlobJob)
s := &BlobSaver{ s := &BlobSaver{
repo: repo, repo: repo,
@ -30,14 +30,18 @@ func NewBlobSaver(ctx context.Context, t *tomb.Tomb, repo Saver, workers uint) *
} }
for i := uint(0); i < workers; i++ { for i := uint(0); i < workers; i++ {
t.Go(func() error { wg.Go(func() error {
return s.worker(t.Context(ctx), ch) return s.worker(ctx, ch)
}) })
} }
return s return s
} }
func (s *BlobSaver) TriggerShutdown() {
close(s.ch)
}
// Save stores a blob in the repo. It checks the index and the known blobs // Save stores a blob in the repo. It checks the index and the known blobs
// before saving anything. It takes ownership of the buffer passed in. // before saving anything. It takes ownership of the buffer passed in.
func (s *BlobSaver) Save(ctx context.Context, t restic.BlobType, buf *Buffer) FutureBlob { func (s *BlobSaver) Save(ctx context.Context, t restic.BlobType, buf *Buffer) FutureBlob {
@ -114,10 +118,14 @@ func (s *BlobSaver) saveBlob(ctx context.Context, t restic.BlobType, buf []byte)
func (s *BlobSaver) worker(ctx context.Context, jobs <-chan saveBlobJob) error { func (s *BlobSaver) worker(ctx context.Context, jobs <-chan saveBlobJob) error {
for { for {
var job saveBlobJob var job saveBlobJob
var ok bool
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case job = <-jobs: case job, ok = <-jobs:
if !ok {
return nil
}
} }
res, err := s.saveBlob(ctx, job.BlobType, job.buf.Data) res, err := s.saveBlob(ctx, job.BlobType, job.buf.Data)

View File

@ -10,7 +10,7 @@ import (
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/repository"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
tomb "gopkg.in/tomb.v2" "golang.org/x/sync/errgroup"
) )
var errTest = errors.New("test error") var errTest = errors.New("test error")
@ -38,12 +38,12 @@ func TestBlobSaver(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
tmb, ctx := tomb.WithContext(ctx) wg, ctx := errgroup.WithContext(ctx)
saver := &saveFail{ saver := &saveFail{
idx: repository.NewMasterIndex(), idx: repository.NewMasterIndex(),
} }
b := NewBlobSaver(ctx, tmb, saver, uint(runtime.NumCPU())) b := NewBlobSaver(ctx, wg, saver, uint(runtime.NumCPU()))
var results []FutureBlob var results []FutureBlob
@ -60,9 +60,9 @@ func TestBlobSaver(t *testing.T) {
} }
} }
tmb.Kill(nil) b.TriggerShutdown()
err := tmb.Wait() err := wg.Wait()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -84,22 +84,22 @@ func TestBlobSaverError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
tmb, ctx := tomb.WithContext(ctx) wg, ctx := errgroup.WithContext(ctx)
saver := &saveFail{ saver := &saveFail{
idx: repository.NewMasterIndex(), idx: repository.NewMasterIndex(),
failAt: int32(test.failAt), failAt: int32(test.failAt),
} }
b := NewBlobSaver(ctx, tmb, saver, uint(runtime.NumCPU())) b := NewBlobSaver(ctx, wg, saver, uint(runtime.NumCPU()))
for i := 0; i < test.blobs; i++ { for i := 0; i < test.blobs; i++ {
buf := &Buffer{Data: []byte(fmt.Sprintf("foo%d", i))} buf := &Buffer{Data: []byte(fmt.Sprintf("foo%d", i))}
b.Save(ctx, restic.DataBlob, buf) b.Save(ctx, restic.DataBlob, buf)
} }
tmb.Kill(nil) b.TriggerShutdown()
err := tmb.Wait() err := wg.Wait()
if err == nil { if err == nil {
t.Errorf("expected error not found") t.Errorf("expected error not found")
} }

View File

@ -10,7 +10,7 @@ import (
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/fs"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
tomb "gopkg.in/tomb.v2" "golang.org/x/sync/errgroup"
) )
// FutureFile is returned by Save and will return the data once it // FutureFile is returned by Save and will return the data once it
@ -67,7 +67,7 @@ type FileSaver struct {
// NewFileSaver returns a new file saver. A worker pool with fileWorkers is // NewFileSaver returns a new file saver. A worker pool with fileWorkers is
// started, it is stopped when ctx is cancelled. // started, it is stopped when ctx is cancelled.
func NewFileSaver(ctx context.Context, t *tomb.Tomb, save SaveBlobFn, pol chunker.Pol, fileWorkers, blobWorkers uint) *FileSaver { func NewFileSaver(ctx context.Context, wg *errgroup.Group, save SaveBlobFn, pol chunker.Pol, fileWorkers, blobWorkers uint) *FileSaver {
ch := make(chan saveFileJob) ch := make(chan saveFileJob)
debug.Log("new file saver with %v file workers and %v blob workers", fileWorkers, blobWorkers) debug.Log("new file saver with %v file workers and %v blob workers", fileWorkers, blobWorkers)
@ -84,8 +84,8 @@ func NewFileSaver(ctx context.Context, t *tomb.Tomb, save SaveBlobFn, pol chunke
} }
for i := uint(0); i < fileWorkers; i++ { for i := uint(0); i < fileWorkers; i++ {
t.Go(func() error { wg.Go(func() error {
s.worker(t.Context(ctx), ch) s.worker(ctx, ch)
return nil return nil
}) })
} }
@ -93,6 +93,10 @@ func NewFileSaver(ctx context.Context, t *tomb.Tomb, save SaveBlobFn, pol chunke
return s return s
} }
func (s *FileSaver) TriggerShutdown() {
close(s.ch)
}
// CompleteFunc is called when the file has been saved. // CompleteFunc is called when the file has been saved.
type CompleteFunc func(*restic.Node, ItemStats) type CompleteFunc func(*restic.Node, ItemStats)
@ -115,7 +119,6 @@ func (s *FileSaver) Save(ctx context.Context, snPath string, file fs.File, fi os
debug.Log("not sending job, context is cancelled: %v", ctx.Err()) debug.Log("not sending job, context is cancelled: %v", ctx.Err())
_ = file.Close() _ = file.Close()
close(ch) close(ch)
return FutureFile{ch: ch}
} }
return FutureFile{ch: ch} return FutureFile{ch: ch}
@ -226,12 +229,15 @@ func (s *FileSaver) worker(ctx context.Context, jobs <-chan saveFileJob) {
for { for {
var job saveFileJob var job saveFileJob
var ok bool
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case job = <-jobs: case job, ok = <-jobs:
if !ok {
return
}
} }
res := s.saveFile(ctx, chnker, job.snPath, job.file, job.fi, job.start) res := s.saveFile(ctx, chnker, job.snPath, job.file, job.fi, job.start)
if job.complete != nil { if job.complete != nil {
job.complete(res.node, res.stats) job.complete(res.node, res.stats)

View File

@ -12,7 +12,7 @@ import (
"github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/fs"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/test" "github.com/restic/restic/internal/test"
tomb "gopkg.in/tomb.v2" "golang.org/x/sync/errgroup"
) )
func createTestFiles(t testing.TB, num int) (files []string, cleanup func()) { func createTestFiles(t testing.TB, num int) (files []string, cleanup func()) {
@ -30,8 +30,8 @@ func createTestFiles(t testing.TB, num int) (files []string, cleanup func()) {
return files, cleanup return files, cleanup
} }
func startFileSaver(ctx context.Context, t testing.TB) (*FileSaver, context.Context, *tomb.Tomb) { func startFileSaver(ctx context.Context, t testing.TB) (*FileSaver, context.Context, *errgroup.Group) {
tmb, ctx := tomb.WithContext(ctx) wg, ctx := errgroup.WithContext(ctx)
saveBlob := func(ctx context.Context, tpe restic.BlobType, buf *Buffer) FutureBlob { saveBlob := func(ctx context.Context, tpe restic.BlobType, buf *Buffer) FutureBlob {
ch := make(chan saveBlobResponse) ch := make(chan saveBlobResponse)
@ -45,10 +45,10 @@ func startFileSaver(ctx context.Context, t testing.TB) (*FileSaver, context.Cont
t.Fatal(err) t.Fatal(err)
} }
s := NewFileSaver(ctx, tmb, saveBlob, pol, workers, workers) s := NewFileSaver(ctx, wg, saveBlob, pol, workers, workers)
s.NodeFromFileInfo = restic.NodeFromFileInfo s.NodeFromFileInfo = restic.NodeFromFileInfo
return s, ctx, tmb return s, ctx, wg
} }
func TestFileSaver(t *testing.T) { func TestFileSaver(t *testing.T) {
@ -62,7 +62,7 @@ func TestFileSaver(t *testing.T) {
completeFn := func(*restic.Node, ItemStats) {} completeFn := func(*restic.Node, ItemStats) {}
testFs := fs.Local{} testFs := fs.Local{}
s, ctx, tmb := startFileSaver(ctx, t) s, ctx, wg := startFileSaver(ctx, t)
var results []FutureFile var results []FutureFile
@ -88,9 +88,9 @@ func TestFileSaver(t *testing.T) {
} }
} }
tmb.Kill(nil) s.TriggerShutdown()
err := tmb.Wait() err := wg.Wait()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -5,7 +5,7 @@ import (
"github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
tomb "gopkg.in/tomb.v2" "golang.org/x/sync/errgroup"
) )
// FutureTree is returned by Save and will return the data once it // FutureTree is returned by Save and will return the data once it
@ -47,7 +47,7 @@ type TreeSaver struct {
// NewTreeSaver returns a new tree saver. A worker pool with treeWorkers is // NewTreeSaver returns a new tree saver. A worker pool with treeWorkers is
// started, it is stopped when ctx is cancelled. // started, it is stopped when ctx is cancelled.
func NewTreeSaver(ctx context.Context, t *tomb.Tomb, treeWorkers uint, saveTree func(context.Context, *restic.Tree) (restic.ID, ItemStats, error), errFn ErrorFunc) *TreeSaver { func NewTreeSaver(ctx context.Context, wg *errgroup.Group, treeWorkers uint, saveTree func(context.Context, *restic.Tree) (restic.ID, ItemStats, error), errFn ErrorFunc) *TreeSaver {
ch := make(chan saveTreeJob) ch := make(chan saveTreeJob)
s := &TreeSaver{ s := &TreeSaver{
@ -57,14 +57,18 @@ func NewTreeSaver(ctx context.Context, t *tomb.Tomb, treeWorkers uint, saveTree
} }
for i := uint(0); i < treeWorkers; i++ { for i := uint(0); i < treeWorkers; i++ {
t.Go(func() error { wg.Go(func() error {
return s.worker(t.Context(ctx), ch) return s.worker(ctx, ch)
}) })
} }
return s return s
} }
func (s *TreeSaver) TriggerShutdown() {
close(s.ch)
}
// Save stores the dir d and returns the data once it has been completed. // Save stores the dir d and returns the data once it has been completed.
func (s *TreeSaver) Save(ctx context.Context, snPath string, node *restic.Node, nodes []FutureNode, complete CompleteFunc) FutureTree { func (s *TreeSaver) Save(ctx context.Context, snPath string, node *restic.Node, nodes []FutureNode, complete CompleteFunc) FutureTree {
ch := make(chan saveTreeResponse, 1) ch := make(chan saveTreeResponse, 1)
@ -80,7 +84,6 @@ func (s *TreeSaver) Save(ctx context.Context, snPath string, node *restic.Node,
case <-ctx.Done(): case <-ctx.Done():
debug.Log("not saving tree, context is cancelled") debug.Log("not saving tree, context is cancelled")
close(ch) close(ch)
return FutureTree{ch: ch}
} }
return FutureTree{ch: ch} return FutureTree{ch: ch}
@ -146,12 +149,15 @@ func (s *TreeSaver) save(ctx context.Context, snPath string, node *restic.Node,
func (s *TreeSaver) worker(ctx context.Context, jobs <-chan saveTreeJob) error { func (s *TreeSaver) worker(ctx context.Context, jobs <-chan saveTreeJob) error {
for { for {
var job saveTreeJob var job saveTreeJob
var ok bool
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case job = <-jobs: case job, ok = <-jobs:
if !ok {
return nil
}
} }
node, stats, err := s.save(ctx, job.snPath, job.node, job.nodes) node, stats, err := s.save(ctx, job.snPath, job.node, job.nodes)
if err != nil { if err != nil {
debug.Log("error saving tree blob: %v", err) debug.Log("error saving tree blob: %v", err)

View File

@ -10,14 +10,14 @@ import (
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
tomb "gopkg.in/tomb.v2" "golang.org/x/sync/errgroup"
) )
func TestTreeSaver(t *testing.T) { func TestTreeSaver(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
tmb, ctx := tomb.WithContext(ctx) wg, ctx := errgroup.WithContext(ctx)
saveFn := func(context.Context, *restic.Tree) (restic.ID, ItemStats, error) { saveFn := func(context.Context, *restic.Tree) (restic.ID, ItemStats, error) {
return restic.NewRandomID(), ItemStats{TreeBlobs: 1, TreeSize: 123}, nil return restic.NewRandomID(), ItemStats{TreeBlobs: 1, TreeSize: 123}, nil
@ -27,7 +27,7 @@ func TestTreeSaver(t *testing.T) {
return nil return nil
} }
b := NewTreeSaver(ctx, tmb, uint(runtime.NumCPU()), saveFn, errFn) b := NewTreeSaver(ctx, wg, uint(runtime.NumCPU()), saveFn, errFn)
var results []FutureTree var results []FutureTree
@ -44,9 +44,9 @@ func TestTreeSaver(t *testing.T) {
tree.Wait(ctx) tree.Wait(ctx)
} }
tmb.Kill(nil) b.TriggerShutdown()
err := tmb.Wait() err := wg.Wait()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -71,7 +71,7 @@ func TestTreeSaverError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
tmb, ctx := tomb.WithContext(ctx) wg, ctx := errgroup.WithContext(ctx)
var num int32 var num int32
saveFn := func(context.Context, *restic.Tree) (restic.ID, ItemStats, error) { saveFn := func(context.Context, *restic.Tree) (restic.ID, ItemStats, error) {
@ -88,7 +88,7 @@ func TestTreeSaverError(t *testing.T) {
return nil return nil
} }
b := NewTreeSaver(ctx, tmb, uint(runtime.NumCPU()), saveFn, errFn) b := NewTreeSaver(ctx, wg, uint(runtime.NumCPU()), saveFn, errFn)
var results []FutureTree var results []FutureTree
@ -105,9 +105,9 @@ func TestTreeSaverError(t *testing.T) {
tree.Wait(ctx) tree.Wait(ctx)
} }
tmb.Kill(nil) b.TriggerShutdown()
err := tmb.Wait() err := wg.Wait()
if err == nil { if err == nil {
t.Errorf("expected error not found") t.Errorf("expected error not found")
} }