From 99f7fd74e3f45e518031f7c0302e3583452b0714 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Sat, 3 Mar 2018 14:20:54 +0100 Subject: [PATCH] backend: Improve Save() As mentioned in issue [#1560](https://github.com/restic/restic/pull/1560#issuecomment-364689346) this changes the signature for `backend.Save()`. It now takes a parameter of interface type `RewindReader`, so that the backend implementations or our `RetryBackend` middleware can reset the reader to the beginning and then retry an upload operation. The `RewindReader` interface also provides a `Length()` method, which is used in the backend to get the size of the data to be saved. This removes several ugly hacks we had to do to pull the size back out of the `io.Reader` passed to `Save()` before. In the `s3` and `rest` backend this is actively used. --- .../archiver/archiver_duplication_test.go | 2 +- internal/backend/azure/azure.go | 26 +-- internal/backend/b2/b2.go | 2 +- internal/backend/backend_error.go | 2 +- internal/backend/backend_retry.go | 19 +-- internal/backend/backend_retry_test.go | 41 +---- internal/backend/gs/gs.go | 3 +- internal/backend/local/local.go | 2 +- internal/backend/mem/mem_backend.go | 2 +- internal/backend/rest/rest.go | 8 +- internal/backend/s3/s3.go | 22 +-- internal/backend/sftp/sftp.go | 2 +- internal/backend/swift/swift.go | 6 +- internal/backend/test/benchmarks.go | 10 +- internal/backend/test/tests.go | 43 ++--- internal/backend/utils_test.go | 9 +- internal/cache/backend.go | 23 +-- internal/cache/backend_test.go | 2 +- internal/checker/checker_test.go | 46 +++++- internal/limiter/limiter_backend.go | 19 ++- internal/mock/backend.go | 4 +- internal/pack/pack_test.go | 4 +- internal/repository/key.go | 3 +- internal/repository/packer_manager.go | 15 +- internal/repository/packer_manager_test.go | 20 ++- internal/repository/repository.go | 8 +- internal/restic/backend.go | 4 +- internal/restic/rewind_reader.go | 90 ++++++++++ internal/restic/rewind_reader_test.go | 154 ++++++++++++++++++ 29 files changed, 387 insertions(+), 204 deletions(-) create mode 100644 internal/restic/rewind_reader.go create mode 100644 internal/restic/rewind_reader_test.go diff --git a/internal/archiver/archiver_duplication_test.go b/internal/archiver/archiver_duplication_test.go index 69a2ce21c..2538dfec0 100644 --- a/internal/archiver/archiver_duplication_test.go +++ b/internal/archiver/archiver_duplication_test.go @@ -48,7 +48,7 @@ func forgetfulBackend() restic.Backend { return nil, errors.New("not found") } - be.SaveFn = func(ctx context.Context, h restic.Handle, rd io.Reader) error { + be.SaveFn = func(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { return nil } diff --git a/internal/backend/azure/azure.go b/internal/backend/azure/azure.go index c3c1fb70d..0b3f0b6ea 100644 --- a/internal/backend/azure/azure.go +++ b/internal/backend/azure/azure.go @@ -3,6 +3,7 @@ package azure import ( "context" "io" + "io/ioutil" "net/http" "os" "path" @@ -114,19 +115,8 @@ func (be *Backend) Path() string { return be.prefix } -// preventCloser wraps an io.Reader to run a function instead of the original Close() function. -type preventCloser struct { - io.Reader - f func() -} - -func (wr preventCloser) Close() error { - wr.f() - return nil -} - // Save stores data in the backend at the handle. -func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if err := h.Valid(); err != nil { return err } @@ -137,18 +127,10 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err be.sem.GetToken() - // wrap the reader so that net/http client cannot close the reader, return - // the token instead. - rd = preventCloser{ - Reader: rd, - f: func() { - debug.Log("Close()") - }, - } - debug.Log("InsertObject(%v, %v)", be.container.Name, objName) - err = be.container.GetBlobReference(objName).CreateBlockBlobFromReader(rd, nil) + // wrap the reader so that net/http client cannot close the reader + err := be.container.GetBlobReference(objName).CreateBlockBlobFromReader(ioutil.NopCloser(rd), nil) be.sem.ReleaseToken() debug.Log("%v, err %#v", objName, err) diff --git a/internal/backend/b2/b2.go b/internal/backend/b2/b2.go index 7ad077cdf..4a2651223 100644 --- a/internal/backend/b2/b2.go +++ b/internal/backend/b2/b2.go @@ -185,7 +185,7 @@ func (be *b2Backend) openReader(ctx context.Context, h restic.Handle, length int } // Save stores data in the backend at the handle. -func (be *b2Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { +func (be *b2Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/internal/backend/backend_error.go b/internal/backend/backend_error.go index ee4b68b9b..77a931858 100644 --- a/internal/backend/backend_error.go +++ b/internal/backend/backend_error.go @@ -45,7 +45,7 @@ func (be *ErrorBackend) fail(p float32) bool { } // Save stores the data in the backend under the given handle. -func (be *ErrorBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { +func (be *ErrorBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if be.fail(be.FailSave) { return errors.Errorf("Save(%v) random error induced", h) } diff --git a/internal/backend/backend_retry.go b/internal/backend/backend_retry.go index 00274e43e..9f8834892 100644 --- a/internal/backend/backend_retry.go +++ b/internal/backend/backend_retry.go @@ -8,7 +8,6 @@ import ( "github.com/cenkalti/backoff" "github.com/restic/restic/internal/debug" - "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/restic" ) @@ -47,23 +46,9 @@ func (be *RetryBackend) retry(ctx context.Context, msg string, f func() error) e } // Save stores the data in the backend under the given handle. -func (be *RetryBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { - seeker, ok := rd.(io.Seeker) - if !ok { - return errors.Errorf("reader %T is not a seeker", rd) - } - - pos, err := seeker.Seek(0, io.SeekCurrent) - if err != nil { - return errors.Wrap(err, "Seek") - } - - if pos != 0 { - return errors.Errorf("reader is not at the beginning (pos %v)", pos) - } - +func (be *RetryBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { return be.retry(ctx, fmt.Sprintf("Save(%v)", h), func() error { - _, err := seeker.Seek(0, io.SeekStart) + err := rd.Rewind() if err != nil { return err } diff --git a/internal/backend/backend_retry_test.go b/internal/backend/backend_retry_test.go index 832547c77..277f96756 100644 --- a/internal/backend/backend_retry_test.go +++ b/internal/backend/backend_retry_test.go @@ -13,48 +13,11 @@ import ( "github.com/restic/restic/internal/test" ) -func TestBackendRetrySeeker(t *testing.T) { - be := &mock.Backend{ - SaveFn: func(ctx context.Context, h restic.Handle, rd io.Reader) error { - return nil - }, - } - - retryBackend := RetryBackend{ - Backend: be, - } - - data := test.Random(24, 23*14123) - - type wrapReader struct { - io.Reader - } - - var rd io.Reader - rd = wrapReader{bytes.NewReader(data)} - - err := retryBackend.Save(context.TODO(), restic.Handle{}, rd) - if err == nil { - t.Fatal("did not get expected error for retry backend with non-seeker reader") - } - - rd = bytes.NewReader(data) - _, err = io.CopyN(ioutil.Discard, rd, 5) - if err != nil { - t.Fatal(err) - } - - err = retryBackend.Save(context.TODO(), restic.Handle{}, rd) - if err == nil { - t.Fatal("did not get expected error for partial reader") - } -} - func TestBackendSaveRetry(t *testing.T) { buf := bytes.NewBuffer(nil) errcount := 0 be := &mock.Backend{ - SaveFn: func(ctx context.Context, h restic.Handle, rd io.Reader) error { + SaveFn: func(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if errcount == 0 { errcount++ _, err := io.CopyN(ioutil.Discard, rd, 120) @@ -75,7 +38,7 @@ func TestBackendSaveRetry(t *testing.T) { } data := test.Random(23, 5*1024*1024+11241) - err := retryBackend.Save(context.TODO(), restic.Handle{}, bytes.NewReader(data)) + err := retryBackend.Save(context.TODO(), restic.Handle{}, restic.NewByteReader(data)) if err != nil { t.Fatal(err) } diff --git a/internal/backend/gs/gs.go b/internal/backend/gs/gs.go index bbcf152bc..7cdced796 100644 --- a/internal/backend/gs/gs.go +++ b/internal/backend/gs/gs.go @@ -207,7 +207,7 @@ func (be *Backend) Path() string { } // Save stores data in the backend at the handle. -func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if err := h.Valid(); err != nil { return err } @@ -250,6 +250,7 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err info, err := be.service.Objects.Insert(be.bucketName, &storage.Object{ Name: objName, + Size: uint64(rd.Length()), }).Media(rd, cs).Do() be.sem.ReleaseToken() diff --git a/internal/backend/local/local.go b/internal/backend/local/local.go index 2f82e2a42..d1b7d6788 100644 --- a/internal/backend/local/local.go +++ b/internal/backend/local/local.go @@ -98,7 +98,7 @@ func (b *Local) IsNotExist(err error) bool { } // Save stores data in the backend at the handle. -func (b *Local) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { +func (b *Local) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { debug.Log("Save %v", h) if err := h.Valid(); err != nil { return err diff --git a/internal/backend/mem/mem_backend.go b/internal/backend/mem/mem_backend.go index a64ef774d..a8244be43 100644 --- a/internal/backend/mem/mem_backend.go +++ b/internal/backend/mem/mem_backend.go @@ -59,7 +59,7 @@ func (be *MemoryBackend) IsNotExist(err error) bool { } // Save adds new Data to the backend. -func (be *MemoryBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { +func (be *MemoryBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if err := h.Valid(); err != nil { return err } diff --git a/internal/backend/rest/rest.go b/internal/backend/rest/rest.go index 2f5b35675..58a8b83d8 100644 --- a/internal/backend/rest/rest.go +++ b/internal/backend/rest/rest.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "path" + "strconv" "strings" "golang.org/x/net/context/ctxhttp" @@ -105,7 +106,7 @@ func (b *restBackend) Location() string { } // Save stores data in the backend at the handle. -func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if err := h.Valid(); err != nil { return err } @@ -114,12 +115,11 @@ func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) ( defer cancel() // make sure that client.Post() cannot close the reader by wrapping it - rd = ioutil.NopCloser(rd) - - req, err := http.NewRequest(http.MethodPost, b.Filename(h), rd) + req, err := http.NewRequest(http.MethodPost, b.Filename(h), ioutil.NopCloser(rd)) if err != nil { return errors.Wrap(err, "NewRequest") } + req.Header.Set("Content-Length", strconv.Itoa(rd.Length())) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Accept", contentTypeV2) diff --git a/internal/backend/s3/s3.go b/internal/backend/s3/s3.go index d36679bf7..420729b0e 100644 --- a/internal/backend/s3/s3.go +++ b/internal/backend/s3/s3.go @@ -240,7 +240,7 @@ func lenForFile(f *os.File) (int64, error) { } // Save stores data in the backend at the handle. -func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { debug.Log("Save %v", h) if err := h.Valid(); err != nil { @@ -252,27 +252,11 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err be.sem.GetToken() defer be.sem.ReleaseToken() - var size int64 = -1 - - type lenner interface { - Len() int - } - - // find size for reader - if f, ok := rd.(*os.File); ok { - size, err = lenForFile(f) - if err != nil { - return err - } - } else if l, ok := rd.(lenner); ok { - size = int64(l.Len()) - } - opts := minio.PutObjectOptions{} opts.ContentType = "application/octet-stream" - debug.Log("PutObject(%v, %v, %v)", be.cfg.Bucket, objName, size) - n, err := be.client.PutObjectWithContext(ctx, be.cfg.Bucket, objName, ioutil.NopCloser(rd), size, opts) + debug.Log("PutObject(%v, %v, %v)", be.cfg.Bucket, objName, rd.Length()) + n, err := be.client.PutObjectWithContext(ctx, be.cfg.Bucket, objName, ioutil.NopCloser(rd), int64(rd.Length()), opts) debug.Log("%v -> %v bytes, err %#v: %v", objName, n, err, err) diff --git a/internal/backend/sftp/sftp.go b/internal/backend/sftp/sftp.go index 58fe38c24..8f7855a37 100644 --- a/internal/backend/sftp/sftp.go +++ b/internal/backend/sftp/sftp.go @@ -282,7 +282,7 @@ func Join(parts ...string) string { } // Save stores data in the backend at the handle. -func (r *SFTP) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (r *SFTP) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { debug.Log("Save %v", h) if err := r.clientError(); err != nil { return err diff --git a/internal/backend/swift/swift.go b/internal/backend/swift/swift.go index 115a8d0e3..8a17450d8 100644 --- a/internal/backend/swift/swift.go +++ b/internal/backend/swift/swift.go @@ -156,8 +156,8 @@ func (be *beSwift) openReader(ctx context.Context, h restic.Handle, length int, } // Save stores data in the backend at the handle. -func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { - if err = h.Valid(); err != nil { +func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { + if err := h.Valid(); err != nil { return err } @@ -171,7 +171,7 @@ func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err encoding := "binary/octet-stream" debug.Log("PutObject(%v, %v, %v)", be.container, objName, encoding) - _, err = be.conn.ObjectPut(be.container, objName, rd, true, "", encoding, nil) + _, err := be.conn.ObjectPut(be.container, objName, rd, true, "", encoding, nil) debug.Log("%v, err %#v", objName, err) return errors.Wrap(err, "client.PutObject") diff --git a/internal/backend/test/benchmarks.go b/internal/backend/test/benchmarks.go index 2c3dbff2e..302768f2e 100644 --- a/internal/backend/test/benchmarks.go +++ b/internal/backend/test/benchmarks.go @@ -14,7 +14,8 @@ func saveRandomFile(t testing.TB, be restic.Backend, length int) ([]byte, restic data := test.Random(23, length) id := restic.Hash(data) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - if err := be.Save(context.TODO(), handle, bytes.NewReader(data)); err != nil { + err := be.Save(context.TODO(), handle, restic.NewByteReader(data)) + if err != nil { t.Fatalf("Save() error: %+v", err) } return data, handle @@ -148,16 +149,11 @@ func (s *Suite) BenchmarkSave(t *testing.B) { id := restic.Hash(data) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - rd := bytes.NewReader(data) - + rd := restic.NewByteReader(data) t.SetBytes(int64(length)) t.ResetTimer() for i := 0; i < t.N; i++ { - if _, err := rd.Seek(0, 0); err != nil { - t.Fatal(err) - } - if err := be.Save(context.TODO(), handle, rd); err != nil { t.Fatal(err) } diff --git a/internal/backend/test/tests.go b/internal/backend/test/tests.go index d861c9589..c1127103d 100644 --- a/internal/backend/test/tests.go +++ b/internal/backend/test/tests.go @@ -10,7 +10,6 @@ import ( "os" "reflect" "sort" - "strings" "testing" "time" @@ -85,7 +84,7 @@ func (s *Suite) TestConfig(t *testing.T) { t.Fatalf("did not get expected error for non-existing config") } - err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, strings.NewReader(testString)) + err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, restic.NewByteReader([]byte(testString))) if err != nil { t.Fatalf("Save() error: %+v", err) } @@ -135,7 +134,7 @@ func (s *Suite) TestLoad(t *testing.T) { id := restic.Hash(data) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - err = b.Save(context.TODO(), handle, bytes.NewReader(data)) + err = b.Save(context.TODO(), handle, restic.NewByteReader(data)) if err != nil { t.Fatalf("Save() error: %+v", err) } @@ -250,7 +249,7 @@ func (s *Suite) TestList(t *testing.T) { data := test.Random(rand.Int(), rand.Intn(100)+55) id := restic.Hash(data) h := restic.Handle{Type: restic.DataFile, Name: id.String()} - err := b.Save(context.TODO(), h, bytes.NewReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) if err != nil { t.Fatal(err) } @@ -340,7 +339,7 @@ func (s *Suite) TestListCancel(t *testing.T) { data := []byte(fmt.Sprintf("random test blob %v", i)) id := restic.Hash(data) h := restic.Handle{Type: restic.DataFile, Name: id.String()} - err := b.Save(context.TODO(), h, bytes.NewReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) if err != nil { t.Fatal(err) } @@ -443,7 +442,7 @@ func (s *Suite) TestListCancel(t *testing.T) { } type errorCloser struct { - io.Reader + io.ReadSeeker l int t testing.TB } @@ -453,10 +452,15 @@ func (ec errorCloser) Close() error { return errors.New("forbidden method close was called") } -func (ec errorCloser) Len() int { +func (ec errorCloser) Length() int { return ec.l } +func (ec errorCloser) Rewind() error { + _, err := ec.ReadSeeker.Seek(0, io.SeekStart) + return err +} + // TestSave tests saving data in the backend. func (s *Suite) TestSave(t *testing.T) { seedRand(t) @@ -480,7 +484,7 @@ func (s *Suite) TestSave(t *testing.T) { Type: restic.DataFile, Name: fmt.Sprintf("%s-%d", id, i), } - err := b.Save(context.TODO(), h, bytes.NewReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) test.OK(t, err) buf, err := backend.LoadAll(context.TODO(), b, h) @@ -532,7 +536,7 @@ func (s *Suite) TestSave(t *testing.T) { // wrap the tempfile in an errorCloser, so we can detect if the backend // closes the reader - err = b.Save(context.TODO(), h, errorCloser{t: t, l: length, Reader: tmpfile}) + err = b.Save(context.TODO(), h, errorCloser{t: t, l: length, ReadSeeker: tmpfile}) if err != nil { t.Fatal(err) } @@ -542,25 +546,10 @@ func (s *Suite) TestSave(t *testing.T) { t.Fatalf("error removing item: %+v", err) } - // try again directly with the temp file - if _, err = tmpfile.Seek(588, io.SeekStart); err != nil { - t.Fatal(err) - } - - err = b.Save(context.TODO(), h, tmpfile) - if err != nil { - t.Fatal(err) - } - if err = tmpfile.Close(); err != nil { t.Fatal(err) } - err = b.Remove(context.TODO(), h) - if err != nil { - t.Fatalf("error removing item: %+v", err) - } - if err = os.Remove(tmpfile.Name()); err != nil { t.Fatal(err) } @@ -585,7 +574,7 @@ func (s *Suite) TestSaveFilenames(t *testing.T) { for i, test := range filenameTests { h := restic.Handle{Name: test.name, Type: restic.DataFile} - err := b.Save(context.TODO(), h, strings.NewReader(test.data)) + err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(test.data))) if err != nil { t.Errorf("test %d failed: Save() returned %+v", i, err) continue @@ -622,7 +611,7 @@ var testStrings = []struct { func store(t testing.TB, b restic.Backend, tpe restic.FileType, data []byte) restic.Handle { id := restic.Hash(data) h := restic.Handle{Name: id.String(), Type: tpe} - err := b.Save(context.TODO(), h, bytes.NewReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(data))) test.OK(t, err) return h } @@ -776,7 +765,7 @@ func (s *Suite) TestBackend(t *testing.T) { test.Assert(t, !ok, "removed blob still present") // create blob - err = b.Save(context.TODO(), h, strings.NewReader(ts.data)) + err = b.Save(context.TODO(), h, restic.NewByteReader([]byte(ts.data))) test.OK(t, err) // list items diff --git a/internal/backend/utils_test.go b/internal/backend/utils_test.go index ed7488a57..74929fd0b 100644 --- a/internal/backend/utils_test.go +++ b/internal/backend/utils_test.go @@ -24,7 +24,8 @@ func TestLoadAll(t *testing.T) { data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB) id := restic.Hash(data) - err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) + h := restic.Handle{Name: id.String(), Type: restic.DataFile} + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) rtest.OK(t, err) buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) @@ -49,7 +50,8 @@ func TestLoadSmallBuffer(t *testing.T) { data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB) id := restic.Hash(data) - err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) + h := restic.Handle{Name: id.String(), Type: restic.DataFile} + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) rtest.OK(t, err) buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) @@ -74,7 +76,8 @@ func TestLoadLargeBuffer(t *testing.T) { data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB) id := restic.Hash(data) - err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) + h := restic.Handle{Name: id.String(), Type: restic.DataFile} + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) rtest.OK(t, err) buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) diff --git a/internal/cache/backend.go b/internal/cache/backend.go index 08639e462..b6c163b29 100644 --- a/internal/cache/backend.go +++ b/internal/cache/backend.go @@ -5,7 +5,6 @@ import ( "io" "sync" - "github.com/pkg/errors" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/restic" ) @@ -50,35 +49,29 @@ var autoCacheTypes = map[restic.FileType]struct{}{ } // Save stores a new file in the backend and the cache. -func (b *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (b *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if _, ok := autoCacheTypes[h.Type]; !ok { return b.Backend.Save(ctx, h, rd) } debug.Log("Save(%v): auto-store in the cache", h) - seeker, ok := rd.(io.Seeker) - if !ok { - return errors.New("reader is not a seeker") - } - - pos, err := seeker.Seek(0, io.SeekCurrent) + // make sure the reader is at the start + err := rd.Rewind() if err != nil { - return errors.Wrapf(err, "Seek") - } - - if pos != 0 { - return errors.Errorf("reader is not rewind (pos %d)", pos) + return err } + // first, save in the backend err = b.Backend.Save(ctx, h, rd) if err != nil { return err } - _, err = seeker.Seek(pos, io.SeekStart) + // next, save in the cache + err = rd.Rewind() if err != nil { - return errors.Wrapf(err, "Seek") + return err } err = b.Cache.Save(h, rd) diff --git a/internal/cache/backend_test.go b/internal/cache/backend_test.go index dbcd53326..993319b0d 100644 --- a/internal/cache/backend_test.go +++ b/internal/cache/backend_test.go @@ -28,7 +28,7 @@ func loadAndCompare(t testing.TB, be restic.Backend, h restic.Handle, data []byt } func save(t testing.TB, be restic.Backend, h restic.Handle, data []byte) { - err := be.Save(context.TODO(), h, bytes.NewReader(data)) + err := be.Save(context.TODO(), h, restic.NewByteReader(data)) if err != nil { t.Fatal(err) } diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index f1b0fe938..601407636 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -3,7 +3,9 @@ package checker_test import ( "context" "io" + "io/ioutil" "math/rand" + "os" "path/filepath" "sort" "testing" @@ -195,17 +197,47 @@ func TestModifiedIndex(t *testing.T) { Type: restic.IndexFile, Name: "90f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", } - err := repo.Backend().Load(context.TODO(), h, 0, 0, func(rd io.Reader) error { - // save the index again with a modified name so that the hash doesn't match - // the content any more - h2 := restic.Handle{ - Type: restic.IndexFile, - Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", + + tmpfile, err := ioutil.TempFile("", "restic-test-mod-index-") + if err != nil { + t.Fatal(err) + } + defer func() { + err := tmpfile.Close() + if err != nil { + t.Fatal(err) } - return repo.Backend().Save(context.TODO(), h2, rd) + + err = os.Remove(tmpfile.Name()) + if err != nil { + t.Fatal(err) + } + }() + + // read the file from the backend + err = repo.Backend().Load(context.TODO(), h, 0, 0, func(rd io.Reader) error { + _, err := io.Copy(tmpfile, rd) + return err }) test.OK(t, err) + // save the index again with a modified name so that the hash doesn't match + // the content any more + h2 := restic.Handle{ + Type: restic.IndexFile, + Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", + } + + rd, err := restic.NewFileReader(tmpfile) + if err != nil { + t.Fatal(err) + } + + err = repo.Backend().Save(context.TODO(), h2, rd) + if err != nil { + t.Fatal(err) + } + chkr := checker.New(repo) hints, errs := chkr.LoadIndex(context.TODO()) if len(errs) == 0 { diff --git a/internal/limiter/limiter_backend.go b/internal/limiter/limiter_backend.go index 963a084dd..b2351a8fd 100644 --- a/internal/limiter/limiter_backend.go +++ b/internal/limiter/limiter_backend.go @@ -21,8 +21,23 @@ type rateLimitedBackend struct { limiter Limiter } -func (r rateLimitedBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { - return r.Backend.Save(ctx, h, r.limiter.Upstream(rd)) +func (r rateLimitedBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { + limited := limitedRewindReader{ + RewindReader: rd, + limited: r.limiter.Upstream(rd), + } + + return r.Backend.Save(ctx, h, limited) +} + +type limitedRewindReader struct { + restic.RewindReader + + limited io.Reader +} + +func (l limitedRewindReader) Read(b []byte) (int, error) { + return l.limited.Read(b) } func (r rateLimitedBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64, consumer func(rd io.Reader) error) error { diff --git a/internal/mock/backend.go b/internal/mock/backend.go index 14288c3f6..930fdb3ee 100644 --- a/internal/mock/backend.go +++ b/internal/mock/backend.go @@ -12,7 +12,7 @@ import ( type Backend struct { CloseFn func() error IsNotExistFn func(err error) bool - SaveFn func(ctx context.Context, h restic.Handle, rd io.Reader) error + SaveFn func(ctx context.Context, h restic.Handle, rd restic.RewindReader) error OpenReaderFn func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) StatFn func(ctx context.Context, h restic.Handle) (restic.FileInfo, error) ListFn func(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error @@ -56,7 +56,7 @@ func (m *Backend) IsNotExist(err error) bool { } // Save data in the backend. -func (m *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { +func (m *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if m.SaveFn == nil { return errors.New("not implemented") } diff --git a/internal/pack/pack_test.go b/internal/pack/pack_test.go index a61b28a2b..12e3600bb 100644 --- a/internal/pack/pack_test.go +++ b/internal/pack/pack_test.go @@ -127,7 +127,7 @@ func TestUnpackReadSeeker(t *testing.T) { id := restic.Hash(packData) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - rtest.OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData))) + rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData))) verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize) } @@ -140,6 +140,6 @@ func TestShortPack(t *testing.T) { id := restic.Hash(packData) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - rtest.OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData))) + rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData))) verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize) } diff --git a/internal/repository/key.go b/internal/repository/key.go index 63c35b8e8..44e3d24e0 100644 --- a/internal/repository/key.go +++ b/internal/repository/key.go @@ -1,7 +1,6 @@ package repository import ( - "bytes" "context" "encoding/json" "fmt" @@ -250,7 +249,7 @@ func AddKey(ctx context.Context, s *Repository, password string, template *crypt Name: restic.Hash(buf).String(), } - err = s.be.Save(ctx, h, bytes.NewReader(buf)) + err = s.be.Save(ctx, h, restic.NewByteReader(buf)) if err != nil { return nil, err } diff --git a/internal/repository/packer_manager.go b/internal/repository/packer_manager.go index d4bba2cf3..4884e0885 100644 --- a/internal/repository/packer_manager.go +++ b/internal/repository/packer_manager.go @@ -3,7 +3,6 @@ package repository import ( "context" "crypto/sha256" - "io" "os" "sync" @@ -19,7 +18,7 @@ import ( // Saver implements saving data in a backend. type Saver interface { - Save(context.Context, restic.Handle, io.Reader) error + Save(context.Context, restic.Handle, restic.RewindReader) error } // Packer holds a pack.Packer together with a hash writer. @@ -96,15 +95,15 @@ func (r *Repository) savePacker(ctx context.Context, t restic.BlobType, p *Packe return err } - _, err = p.tmpfile.Seek(0, 0) - if err != nil { - return errors.Wrap(err, "Seek") - } - id := restic.IDFromHash(p.hw.Sum(nil)) h := restic.Handle{Type: restic.DataFile, Name: id.String()} - err = r.be.Save(ctx, h, p.tmpfile) + rd, err := restic.NewFileReader(p.tmpfile) + if err != nil { + return err + } + + err = r.be.Save(ctx, h, rd) if err != nil { debug.Log("Save(%v) error: %v", h, err) return err diff --git a/internal/repository/packer_manager_test.go b/internal/repository/packer_manager_test.go index aface097b..e1f067f7c 100644 --- a/internal/repository/packer_manager_test.go +++ b/internal/repository/packer_manager_test.go @@ -50,11 +50,17 @@ func randomID(rd io.Reader) restic.ID { const maxBlobSize = 1 << 20 -func saveFile(t testing.TB, be Saver, f *os.File, id restic.ID) { +func saveFile(t testing.TB, be Saver, length int, f *os.File, id restic.ID) { h := restic.Handle{Type: restic.DataFile, Name: id.String()} t.Logf("save file %v", h) - if err := be.Save(context.TODO(), h, f); err != nil { + rd, err := restic.NewFileReader(f) + if err != nil { + t.Fatal(err) + } + + err = be.Save(context.TODO(), h, rd) + if err != nil { t.Fatal(err) } @@ -101,12 +107,8 @@ func fillPacks(t testing.TB, rnd *randReader, be Saver, pm *packerManager, buf [ t.Fatal(err) } - if _, err = packer.tmpfile.Seek(0, 0); err != nil { - t.Fatal(err) - } - packID := restic.IDFromHash(packer.hw.Sum(nil)) - saveFile(t, be, packer.tmpfile, packID) + saveFile(t, be, int(packer.Size()), packer.tmpfile, packID) } return bytes @@ -122,7 +124,7 @@ func flushRemainingPacks(t testing.TB, rnd *randReader, be Saver, pm *packerMana bytes += int(n) packID := restic.IDFromHash(packer.hw.Sum(nil)) - saveFile(t, be, packer.tmpfile, packID) + saveFile(t, be, int(packer.Size()), packer.tmpfile, packID) } } @@ -147,7 +149,7 @@ func BenchmarkPackerManager(t *testing.B) { rnd := newRandReader(rand.NewSource(23)) be := &mock.Backend{ - SaveFn: func(context.Context, restic.Handle, io.Reader) error { return nil }, + SaveFn: func(context.Context, restic.Handle, restic.RewindReader) error { return nil }, } blobBuf := make([]byte, maxBlobSize) diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 3fafe0efb..0c1236f42 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -282,7 +282,7 @@ func (r *Repository) SaveUnpacked(ctx context.Context, t restic.FileType, p []by id = restic.Hash(ciphertext) h := restic.Handle{Type: t, Name: id.String()} - err = r.be.Save(ctx, h, bytes.NewReader(ciphertext)) + err = r.be.Save(ctx, h, restic.NewByteReader(ciphertext)) if err != nil { debug.Log("error saving blob %v: %v", h, err) return restic.ID{}, err @@ -456,11 +456,7 @@ func (r *Repository) LoadIndex(ctx context.Context) error { } } - if err := <-errCh; err != nil { - return err - } - - return nil + return <-errCh } // LoadIndex loads the index id from backend and returns it. diff --git a/internal/restic/backend.go b/internal/restic/backend.go index c5dc25087..b2fd46b97 100644 --- a/internal/restic/backend.go +++ b/internal/restic/backend.go @@ -20,8 +20,8 @@ type Backend interface { // Close the backend Close() error - // Save stores the data in the backend under the given handle. - Save(ctx context.Context, h Handle, rd io.Reader) error + // Save stores the data from rd under the given handle. + Save(ctx context.Context, h Handle, rd RewindReader) error // Load runs fn with a reader that yields the contents of the file at h at the // given offset. If length is larger than zero, only a portion of the file diff --git a/internal/restic/rewind_reader.go b/internal/restic/rewind_reader.go new file mode 100644 index 000000000..aa851d99f --- /dev/null +++ b/internal/restic/rewind_reader.go @@ -0,0 +1,90 @@ +package restic + +import ( + "bytes" + "io" + + "github.com/restic/restic/internal/errors" +) + +// RewindReader allows resetting the Reader to the beginning of the data. +type RewindReader interface { + io.Reader + + // Rewind rewinds the reader so the same data can be read again from the + // start. + Rewind() error + + // Length returns the number of bytes that can be read from the Reader + // after calling Rewind. + Length() int +} + +// ByteReader implements a RewindReader for a byte slice. +type ByteReader struct { + *bytes.Reader + Len int +} + +// Rewind restarts the reader from the beginning of the data. +func (b *ByteReader) Rewind() error { + _, err := b.Reader.Seek(0, io.SeekStart) + return err +} + +// Length returns the number of bytes read from the reader after Rewind is +// called. +func (b *ByteReader) Length() int { + return b.Len +} + +// statically ensure that *ByteReader implements RewindReader. +var _ RewindReader = &ByteReader{} + +// NewByteReader prepares a ByteReader that can then be used to read buf. +func NewByteReader(buf []byte) *ByteReader { + return &ByteReader{ + Reader: bytes.NewReader(buf), + Len: len(buf), + } +} + +// statically ensure that *FileReader implements RewindReader. +var _ RewindReader = &FileReader{} + +// FileReader implements a RewindReader for an open file. +type FileReader struct { + io.ReadSeeker + Len int +} + +// Rewind seeks to the beginning of the file. +func (f *FileReader) Rewind() error { + _, err := f.ReadSeeker.Seek(0, io.SeekStart) + return errors.Wrap(err, "Seek") +} + +// Length returns the length of the file. +func (f *FileReader) Length() int { + return f.Len +} + +// NewFileReader wraps f in a *FileReader. +func NewFileReader(f io.ReadSeeker) (*FileReader, error) { + pos, err := f.Seek(0, io.SeekEnd) + if err != nil { + return nil, errors.Wrap(err, "Seek") + } + + fr := &FileReader{ + ReadSeeker: f, + Len: int(pos), + } + + err = fr.Rewind() + if err != nil { + return nil, err + } + + return fr, nil +} diff --git a/internal/restic/rewind_reader_test.go b/internal/restic/rewind_reader_test.go new file mode 100644 index 000000000..c3c6001f8 --- /dev/null +++ b/internal/restic/rewind_reader_test.go @@ -0,0 +1,154 @@ +package restic + +import ( + "bytes" + "io" + "io/ioutil" + "math/rand" + "os" + "path/filepath" + "testing" + "time" + + "github.com/restic/restic/internal/test" +) + +func TestByteReader(t *testing.T) { + buf := []byte("foobar") + fn := func() RewindReader { + return NewByteReader(buf) + } + testRewindReader(t, fn, buf) +} + +func TestFileReader(t *testing.T) { + buf := []byte("foobar") + + d, cleanup := test.TempDir(t) + defer cleanup() + + filename := filepath.Join(d, "file-reader-test") + err := ioutil.WriteFile(filename, []byte("foobar"), 0600) + if err != nil { + t.Fatal(err) + } + + f, err := os.Open(filename) + if err != nil { + t.Fatal(err) + } + + defer func() { + err := f.Close() + if err != nil { + t.Fatal(err) + } + }() + + fn := func() RewindReader { + rd, err := NewFileReader(f) + if err != nil { + t.Fatal(err) + } + return rd + } + + testRewindReader(t, fn, buf) +} + +func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) { + seed := time.Now().Unix() + t.Logf("seed is %d", seed) + rnd := rand.New(rand.NewSource(seed)) + + type ReaderTestFunc func(t testing.TB, r RewindReader, data []byte) + var tests = []ReaderTestFunc{ + func(t testing.TB, rd RewindReader, data []byte) { + if rd.Length() != len(data) { + t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length()) + } + + buf := make([]byte, len(data)) + _, err := io.ReadFull(rd, buf) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf, data) { + t.Fatalf("wrong data returned") + } + + if rd.Length() != len(data) { + t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length()) + } + + err = rd.Rewind() + if err != nil { + t.Fatal(err) + } + + if rd.Length() != len(data) { + t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length()) + } + + buf2 := make([]byte, len(data)) + _, err = io.ReadFull(rd, buf2) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf2, data) { + t.Fatalf("wrong data returned") + } + + if rd.Length() != len(data) { + t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length()) + } + }, + func(t testing.TB, rd RewindReader, data []byte) { + // read first bytes + buf := make([]byte, rnd.Intn(len(data))) + _, err := io.ReadFull(rd, buf) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf, data[:len(buf)]) { + t.Fatalf("wrong data returned") + } + + err = rd.Rewind() + if err != nil { + t.Fatal(err) + } + + buf2 := make([]byte, rnd.Intn(len(data))) + _, err = io.ReadFull(rd, buf2) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf2, data[:len(buf2)]) { + t.Fatalf("wrong data returned") + } + + // read remainder + buf3 := make([]byte, len(data)-len(buf2)) + _, err = io.ReadFull(rd, buf3) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf3, data[len(buf2):]) { + t.Fatalf("wrong data returned") + } + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + rd := fn() + test(t, rd, data) + }) + } +}