diff --git a/internal/archiver/archiver_duplication_test.go b/internal/archiver/archiver_duplication_test.go index 69a2ce21c..2538dfec0 100644 --- a/internal/archiver/archiver_duplication_test.go +++ b/internal/archiver/archiver_duplication_test.go @@ -48,7 +48,7 @@ func forgetfulBackend() restic.Backend { return nil, errors.New("not found") } - be.SaveFn = func(ctx context.Context, h restic.Handle, rd io.Reader) error { + be.SaveFn = func(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { return nil } diff --git a/internal/backend/azure/azure.go b/internal/backend/azure/azure.go index c3c1fb70d..0b3f0b6ea 100644 --- a/internal/backend/azure/azure.go +++ b/internal/backend/azure/azure.go @@ -3,6 +3,7 @@ package azure import ( "context" "io" + "io/ioutil" "net/http" "os" "path" @@ -114,19 +115,8 @@ func (be *Backend) Path() string { return be.prefix } -// preventCloser wraps an io.Reader to run a function instead of the original Close() function. -type preventCloser struct { - io.Reader - f func() -} - -func (wr preventCloser) Close() error { - wr.f() - return nil -} - // Save stores data in the backend at the handle. -func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if err := h.Valid(); err != nil { return err } @@ -137,18 +127,10 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err be.sem.GetToken() - // wrap the reader so that net/http client cannot close the reader, return - // the token instead. - rd = preventCloser{ - Reader: rd, - f: func() { - debug.Log("Close()") - }, - } - debug.Log("InsertObject(%v, %v)", be.container.Name, objName) - err = be.container.GetBlobReference(objName).CreateBlockBlobFromReader(rd, nil) + // wrap the reader so that net/http client cannot close the reader + err := be.container.GetBlobReference(objName).CreateBlockBlobFromReader(ioutil.NopCloser(rd), nil) be.sem.ReleaseToken() debug.Log("%v, err %#v", objName, err) diff --git a/internal/backend/b2/b2.go b/internal/backend/b2/b2.go index 7ad077cdf..4a2651223 100644 --- a/internal/backend/b2/b2.go +++ b/internal/backend/b2/b2.go @@ -185,7 +185,7 @@ func (be *b2Backend) openReader(ctx context.Context, h restic.Handle, length int } // Save stores data in the backend at the handle. -func (be *b2Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { +func (be *b2Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/internal/backend/backend_error.go b/internal/backend/backend_error.go index ee4b68b9b..77a931858 100644 --- a/internal/backend/backend_error.go +++ b/internal/backend/backend_error.go @@ -45,7 +45,7 @@ func (be *ErrorBackend) fail(p float32) bool { } // Save stores the data in the backend under the given handle. -func (be *ErrorBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { +func (be *ErrorBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if be.fail(be.FailSave) { return errors.Errorf("Save(%v) random error induced", h) } diff --git a/internal/backend/backend_retry.go b/internal/backend/backend_retry.go index 00274e43e..9f8834892 100644 --- a/internal/backend/backend_retry.go +++ b/internal/backend/backend_retry.go @@ -8,7 +8,6 @@ import ( "github.com/cenkalti/backoff" "github.com/restic/restic/internal/debug" - "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/restic" ) @@ -47,23 +46,9 @@ func (be *RetryBackend) retry(ctx context.Context, msg string, f func() error) e } // Save stores the data in the backend under the given handle. -func (be *RetryBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { - seeker, ok := rd.(io.Seeker) - if !ok { - return errors.Errorf("reader %T is not a seeker", rd) - } - - pos, err := seeker.Seek(0, io.SeekCurrent) - if err != nil { - return errors.Wrap(err, "Seek") - } - - if pos != 0 { - return errors.Errorf("reader is not at the beginning (pos %v)", pos) - } - +func (be *RetryBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { return be.retry(ctx, fmt.Sprintf("Save(%v)", h), func() error { - _, err := seeker.Seek(0, io.SeekStart) + err := rd.Rewind() if err != nil { return err } diff --git a/internal/backend/backend_retry_test.go b/internal/backend/backend_retry_test.go index 832547c77..277f96756 100644 --- a/internal/backend/backend_retry_test.go +++ b/internal/backend/backend_retry_test.go @@ -13,48 +13,11 @@ import ( "github.com/restic/restic/internal/test" ) -func TestBackendRetrySeeker(t *testing.T) { - be := &mock.Backend{ - SaveFn: func(ctx context.Context, h restic.Handle, rd io.Reader) error { - return nil - }, - } - - retryBackend := RetryBackend{ - Backend: be, - } - - data := test.Random(24, 23*14123) - - type wrapReader struct { - io.Reader - } - - var rd io.Reader - rd = wrapReader{bytes.NewReader(data)} - - err := retryBackend.Save(context.TODO(), restic.Handle{}, rd) - if err == nil { - t.Fatal("did not get expected error for retry backend with non-seeker reader") - } - - rd = bytes.NewReader(data) - _, err = io.CopyN(ioutil.Discard, rd, 5) - if err != nil { - t.Fatal(err) - } - - err = retryBackend.Save(context.TODO(), restic.Handle{}, rd) - if err == nil { - t.Fatal("did not get expected error for partial reader") - } -} - func TestBackendSaveRetry(t *testing.T) { buf := bytes.NewBuffer(nil) errcount := 0 be := &mock.Backend{ - SaveFn: func(ctx context.Context, h restic.Handle, rd io.Reader) error { + SaveFn: func(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if errcount == 0 { errcount++ _, err := io.CopyN(ioutil.Discard, rd, 120) @@ -75,7 +38,7 @@ func TestBackendSaveRetry(t *testing.T) { } data := test.Random(23, 5*1024*1024+11241) - err := retryBackend.Save(context.TODO(), restic.Handle{}, bytes.NewReader(data)) + err := retryBackend.Save(context.TODO(), restic.Handle{}, restic.NewByteReader(data)) if err != nil { t.Fatal(err) } diff --git a/internal/backend/gs/gs.go b/internal/backend/gs/gs.go index bbcf152bc..7cdced796 100644 --- a/internal/backend/gs/gs.go +++ b/internal/backend/gs/gs.go @@ -207,7 +207,7 @@ func (be *Backend) Path() string { } // Save stores data in the backend at the handle. -func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if err := h.Valid(); err != nil { return err } @@ -250,6 +250,7 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err info, err := be.service.Objects.Insert(be.bucketName, &storage.Object{ Name: objName, + Size: uint64(rd.Length()), }).Media(rd, cs).Do() be.sem.ReleaseToken() diff --git a/internal/backend/local/local.go b/internal/backend/local/local.go index 2f82e2a42..d1b7d6788 100644 --- a/internal/backend/local/local.go +++ b/internal/backend/local/local.go @@ -98,7 +98,7 @@ func (b *Local) IsNotExist(err error) bool { } // Save stores data in the backend at the handle. -func (b *Local) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { +func (b *Local) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { debug.Log("Save %v", h) if err := h.Valid(); err != nil { return err diff --git a/internal/backend/mem/mem_backend.go b/internal/backend/mem/mem_backend.go index a64ef774d..a8244be43 100644 --- a/internal/backend/mem/mem_backend.go +++ b/internal/backend/mem/mem_backend.go @@ -59,7 +59,7 @@ func (be *MemoryBackend) IsNotExist(err error) bool { } // Save adds new Data to the backend. -func (be *MemoryBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { +func (be *MemoryBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if err := h.Valid(); err != nil { return err } diff --git a/internal/backend/rest/rest.go b/internal/backend/rest/rest.go index 2f5b35675..58a8b83d8 100644 --- a/internal/backend/rest/rest.go +++ b/internal/backend/rest/rest.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "path" + "strconv" "strings" "golang.org/x/net/context/ctxhttp" @@ -105,7 +106,7 @@ func (b *restBackend) Location() string { } // Save stores data in the backend at the handle. -func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if err := h.Valid(); err != nil { return err } @@ -114,12 +115,11 @@ func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) ( defer cancel() // make sure that client.Post() cannot close the reader by wrapping it - rd = ioutil.NopCloser(rd) - - req, err := http.NewRequest(http.MethodPost, b.Filename(h), rd) + req, err := http.NewRequest(http.MethodPost, b.Filename(h), ioutil.NopCloser(rd)) if err != nil { return errors.Wrap(err, "NewRequest") } + req.Header.Set("Content-Length", strconv.Itoa(rd.Length())) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Accept", contentTypeV2) diff --git a/internal/backend/s3/s3.go b/internal/backend/s3/s3.go index d36679bf7..420729b0e 100644 --- a/internal/backend/s3/s3.go +++ b/internal/backend/s3/s3.go @@ -240,7 +240,7 @@ func lenForFile(f *os.File) (int64, error) { } // Save stores data in the backend at the handle. -func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { debug.Log("Save %v", h) if err := h.Valid(); err != nil { @@ -252,27 +252,11 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err be.sem.GetToken() defer be.sem.ReleaseToken() - var size int64 = -1 - - type lenner interface { - Len() int - } - - // find size for reader - if f, ok := rd.(*os.File); ok { - size, err = lenForFile(f) - if err != nil { - return err - } - } else if l, ok := rd.(lenner); ok { - size = int64(l.Len()) - } - opts := minio.PutObjectOptions{} opts.ContentType = "application/octet-stream" - debug.Log("PutObject(%v, %v, %v)", be.cfg.Bucket, objName, size) - n, err := be.client.PutObjectWithContext(ctx, be.cfg.Bucket, objName, ioutil.NopCloser(rd), size, opts) + debug.Log("PutObject(%v, %v, %v)", be.cfg.Bucket, objName, rd.Length()) + n, err := be.client.PutObjectWithContext(ctx, be.cfg.Bucket, objName, ioutil.NopCloser(rd), int64(rd.Length()), opts) debug.Log("%v -> %v bytes, err %#v: %v", objName, n, err, err) diff --git a/internal/backend/sftp/sftp.go b/internal/backend/sftp/sftp.go index 58fe38c24..8f7855a37 100644 --- a/internal/backend/sftp/sftp.go +++ b/internal/backend/sftp/sftp.go @@ -282,7 +282,7 @@ func Join(parts ...string) string { } // Save stores data in the backend at the handle. -func (r *SFTP) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (r *SFTP) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { debug.Log("Save %v", h) if err := r.clientError(); err != nil { return err diff --git a/internal/backend/swift/swift.go b/internal/backend/swift/swift.go index 115a8d0e3..8a17450d8 100644 --- a/internal/backend/swift/swift.go +++ b/internal/backend/swift/swift.go @@ -156,8 +156,8 @@ func (be *beSwift) openReader(ctx context.Context, h restic.Handle, length int, } // Save stores data in the backend at the handle. -func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { - if err = h.Valid(); err != nil { +func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { + if err := h.Valid(); err != nil { return err } @@ -171,7 +171,7 @@ func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err encoding := "binary/octet-stream" debug.Log("PutObject(%v, %v, %v)", be.container, objName, encoding) - _, err = be.conn.ObjectPut(be.container, objName, rd, true, "", encoding, nil) + _, err := be.conn.ObjectPut(be.container, objName, rd, true, "", encoding, nil) debug.Log("%v, err %#v", objName, err) return errors.Wrap(err, "client.PutObject") diff --git a/internal/backend/test/benchmarks.go b/internal/backend/test/benchmarks.go index 2c3dbff2e..302768f2e 100644 --- a/internal/backend/test/benchmarks.go +++ b/internal/backend/test/benchmarks.go @@ -14,7 +14,8 @@ func saveRandomFile(t testing.TB, be restic.Backend, length int) ([]byte, restic data := test.Random(23, length) id := restic.Hash(data) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - if err := be.Save(context.TODO(), handle, bytes.NewReader(data)); err != nil { + err := be.Save(context.TODO(), handle, restic.NewByteReader(data)) + if err != nil { t.Fatalf("Save() error: %+v", err) } return data, handle @@ -148,16 +149,11 @@ func (s *Suite) BenchmarkSave(t *testing.B) { id := restic.Hash(data) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - rd := bytes.NewReader(data) - + rd := restic.NewByteReader(data) t.SetBytes(int64(length)) t.ResetTimer() for i := 0; i < t.N; i++ { - if _, err := rd.Seek(0, 0); err != nil { - t.Fatal(err) - } - if err := be.Save(context.TODO(), handle, rd); err != nil { t.Fatal(err) } diff --git a/internal/backend/test/tests.go b/internal/backend/test/tests.go index d861c9589..c1127103d 100644 --- a/internal/backend/test/tests.go +++ b/internal/backend/test/tests.go @@ -10,7 +10,6 @@ import ( "os" "reflect" "sort" - "strings" "testing" "time" @@ -85,7 +84,7 @@ func (s *Suite) TestConfig(t *testing.T) { t.Fatalf("did not get expected error for non-existing config") } - err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, strings.NewReader(testString)) + err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, restic.NewByteReader([]byte(testString))) if err != nil { t.Fatalf("Save() error: %+v", err) } @@ -135,7 +134,7 @@ func (s *Suite) TestLoad(t *testing.T) { id := restic.Hash(data) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - err = b.Save(context.TODO(), handle, bytes.NewReader(data)) + err = b.Save(context.TODO(), handle, restic.NewByteReader(data)) if err != nil { t.Fatalf("Save() error: %+v", err) } @@ -250,7 +249,7 @@ func (s *Suite) TestList(t *testing.T) { data := test.Random(rand.Int(), rand.Intn(100)+55) id := restic.Hash(data) h := restic.Handle{Type: restic.DataFile, Name: id.String()} - err := b.Save(context.TODO(), h, bytes.NewReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) if err != nil { t.Fatal(err) } @@ -340,7 +339,7 @@ func (s *Suite) TestListCancel(t *testing.T) { data := []byte(fmt.Sprintf("random test blob %v", i)) id := restic.Hash(data) h := restic.Handle{Type: restic.DataFile, Name: id.String()} - err := b.Save(context.TODO(), h, bytes.NewReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) if err != nil { t.Fatal(err) } @@ -443,7 +442,7 @@ func (s *Suite) TestListCancel(t *testing.T) { } type errorCloser struct { - io.Reader + io.ReadSeeker l int t testing.TB } @@ -453,10 +452,15 @@ func (ec errorCloser) Close() error { return errors.New("forbidden method close was called") } -func (ec errorCloser) Len() int { +func (ec errorCloser) Length() int { return ec.l } +func (ec errorCloser) Rewind() error { + _, err := ec.ReadSeeker.Seek(0, io.SeekStart) + return err +} + // TestSave tests saving data in the backend. func (s *Suite) TestSave(t *testing.T) { seedRand(t) @@ -480,7 +484,7 @@ func (s *Suite) TestSave(t *testing.T) { Type: restic.DataFile, Name: fmt.Sprintf("%s-%d", id, i), } - err := b.Save(context.TODO(), h, bytes.NewReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) test.OK(t, err) buf, err := backend.LoadAll(context.TODO(), b, h) @@ -532,7 +536,7 @@ func (s *Suite) TestSave(t *testing.T) { // wrap the tempfile in an errorCloser, so we can detect if the backend // closes the reader - err = b.Save(context.TODO(), h, errorCloser{t: t, l: length, Reader: tmpfile}) + err = b.Save(context.TODO(), h, errorCloser{t: t, l: length, ReadSeeker: tmpfile}) if err != nil { t.Fatal(err) } @@ -542,25 +546,10 @@ func (s *Suite) TestSave(t *testing.T) { t.Fatalf("error removing item: %+v", err) } - // try again directly with the temp file - if _, err = tmpfile.Seek(588, io.SeekStart); err != nil { - t.Fatal(err) - } - - err = b.Save(context.TODO(), h, tmpfile) - if err != nil { - t.Fatal(err) - } - if err = tmpfile.Close(); err != nil { t.Fatal(err) } - err = b.Remove(context.TODO(), h) - if err != nil { - t.Fatalf("error removing item: %+v", err) - } - if err = os.Remove(tmpfile.Name()); err != nil { t.Fatal(err) } @@ -585,7 +574,7 @@ func (s *Suite) TestSaveFilenames(t *testing.T) { for i, test := range filenameTests { h := restic.Handle{Name: test.name, Type: restic.DataFile} - err := b.Save(context.TODO(), h, strings.NewReader(test.data)) + err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(test.data))) if err != nil { t.Errorf("test %d failed: Save() returned %+v", i, err) continue @@ -622,7 +611,7 @@ var testStrings = []struct { func store(t testing.TB, b restic.Backend, tpe restic.FileType, data []byte) restic.Handle { id := restic.Hash(data) h := restic.Handle{Name: id.String(), Type: tpe} - err := b.Save(context.TODO(), h, bytes.NewReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(data))) test.OK(t, err) return h } @@ -776,7 +765,7 @@ func (s *Suite) TestBackend(t *testing.T) { test.Assert(t, !ok, "removed blob still present") // create blob - err = b.Save(context.TODO(), h, strings.NewReader(ts.data)) + err = b.Save(context.TODO(), h, restic.NewByteReader([]byte(ts.data))) test.OK(t, err) // list items diff --git a/internal/backend/utils_test.go b/internal/backend/utils_test.go index ed7488a57..74929fd0b 100644 --- a/internal/backend/utils_test.go +++ b/internal/backend/utils_test.go @@ -24,7 +24,8 @@ func TestLoadAll(t *testing.T) { data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB) id := restic.Hash(data) - err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) + h := restic.Handle{Name: id.String(), Type: restic.DataFile} + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) rtest.OK(t, err) buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) @@ -49,7 +50,8 @@ func TestLoadSmallBuffer(t *testing.T) { data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB) id := restic.Hash(data) - err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) + h := restic.Handle{Name: id.String(), Type: restic.DataFile} + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) rtest.OK(t, err) buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) @@ -74,7 +76,8 @@ func TestLoadLargeBuffer(t *testing.T) { data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB) id := restic.Hash(data) - err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) + h := restic.Handle{Name: id.String(), Type: restic.DataFile} + err := b.Save(context.TODO(), h, restic.NewByteReader(data)) rtest.OK(t, err) buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) diff --git a/internal/cache/backend.go b/internal/cache/backend.go index 08639e462..b6c163b29 100644 --- a/internal/cache/backend.go +++ b/internal/cache/backend.go @@ -5,7 +5,6 @@ import ( "io" "sync" - "github.com/pkg/errors" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/restic" ) @@ -50,35 +49,29 @@ var autoCacheTypes = map[restic.FileType]struct{}{ } // Save stores a new file in the backend and the cache. -func (b *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { +func (b *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if _, ok := autoCacheTypes[h.Type]; !ok { return b.Backend.Save(ctx, h, rd) } debug.Log("Save(%v): auto-store in the cache", h) - seeker, ok := rd.(io.Seeker) - if !ok { - return errors.New("reader is not a seeker") - } - - pos, err := seeker.Seek(0, io.SeekCurrent) + // make sure the reader is at the start + err := rd.Rewind() if err != nil { - return errors.Wrapf(err, "Seek") - } - - if pos != 0 { - return errors.Errorf("reader is not rewind (pos %d)", pos) + return err } + // first, save in the backend err = b.Backend.Save(ctx, h, rd) if err != nil { return err } - _, err = seeker.Seek(pos, io.SeekStart) + // next, save in the cache + err = rd.Rewind() if err != nil { - return errors.Wrapf(err, "Seek") + return err } err = b.Cache.Save(h, rd) diff --git a/internal/cache/backend_test.go b/internal/cache/backend_test.go index dbcd53326..993319b0d 100644 --- a/internal/cache/backend_test.go +++ b/internal/cache/backend_test.go @@ -28,7 +28,7 @@ func loadAndCompare(t testing.TB, be restic.Backend, h restic.Handle, data []byt } func save(t testing.TB, be restic.Backend, h restic.Handle, data []byte) { - err := be.Save(context.TODO(), h, bytes.NewReader(data)) + err := be.Save(context.TODO(), h, restic.NewByteReader(data)) if err != nil { t.Fatal(err) } diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index f1b0fe938..601407636 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -3,7 +3,9 @@ package checker_test import ( "context" "io" + "io/ioutil" "math/rand" + "os" "path/filepath" "sort" "testing" @@ -195,17 +197,47 @@ func TestModifiedIndex(t *testing.T) { Type: restic.IndexFile, Name: "90f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", } - err := repo.Backend().Load(context.TODO(), h, 0, 0, func(rd io.Reader) error { - // save the index again with a modified name so that the hash doesn't match - // the content any more - h2 := restic.Handle{ - Type: restic.IndexFile, - Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", + + tmpfile, err := ioutil.TempFile("", "restic-test-mod-index-") + if err != nil { + t.Fatal(err) + } + defer func() { + err := tmpfile.Close() + if err != nil { + t.Fatal(err) } - return repo.Backend().Save(context.TODO(), h2, rd) + + err = os.Remove(tmpfile.Name()) + if err != nil { + t.Fatal(err) + } + }() + + // read the file from the backend + err = repo.Backend().Load(context.TODO(), h, 0, 0, func(rd io.Reader) error { + _, err := io.Copy(tmpfile, rd) + return err }) test.OK(t, err) + // save the index again with a modified name so that the hash doesn't match + // the content any more + h2 := restic.Handle{ + Type: restic.IndexFile, + Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", + } + + rd, err := restic.NewFileReader(tmpfile) + if err != nil { + t.Fatal(err) + } + + err = repo.Backend().Save(context.TODO(), h2, rd) + if err != nil { + t.Fatal(err) + } + chkr := checker.New(repo) hints, errs := chkr.LoadIndex(context.TODO()) if len(errs) == 0 { diff --git a/internal/limiter/limiter_backend.go b/internal/limiter/limiter_backend.go index 963a084dd..b2351a8fd 100644 --- a/internal/limiter/limiter_backend.go +++ b/internal/limiter/limiter_backend.go @@ -21,8 +21,23 @@ type rateLimitedBackend struct { limiter Limiter } -func (r rateLimitedBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { - return r.Backend.Save(ctx, h, r.limiter.Upstream(rd)) +func (r rateLimitedBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { + limited := limitedRewindReader{ + RewindReader: rd, + limited: r.limiter.Upstream(rd), + } + + return r.Backend.Save(ctx, h, limited) +} + +type limitedRewindReader struct { + restic.RewindReader + + limited io.Reader +} + +func (l limitedRewindReader) Read(b []byte) (int, error) { + return l.limited.Read(b) } func (r rateLimitedBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64, consumer func(rd io.Reader) error) error { diff --git a/internal/mock/backend.go b/internal/mock/backend.go index 14288c3f6..930fdb3ee 100644 --- a/internal/mock/backend.go +++ b/internal/mock/backend.go @@ -12,7 +12,7 @@ import ( type Backend struct { CloseFn func() error IsNotExistFn func(err error) bool - SaveFn func(ctx context.Context, h restic.Handle, rd io.Reader) error + SaveFn func(ctx context.Context, h restic.Handle, rd restic.RewindReader) error OpenReaderFn func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) StatFn func(ctx context.Context, h restic.Handle) (restic.FileInfo, error) ListFn func(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error @@ -56,7 +56,7 @@ func (m *Backend) IsNotExist(err error) bool { } // Save data in the backend. -func (m *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { +func (m *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if m.SaveFn == nil { return errors.New("not implemented") } diff --git a/internal/pack/pack_test.go b/internal/pack/pack_test.go index a61b28a2b..12e3600bb 100644 --- a/internal/pack/pack_test.go +++ b/internal/pack/pack_test.go @@ -127,7 +127,7 @@ func TestUnpackReadSeeker(t *testing.T) { id := restic.Hash(packData) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - rtest.OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData))) + rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData))) verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize) } @@ -140,6 +140,6 @@ func TestShortPack(t *testing.T) { id := restic.Hash(packData) handle := restic.Handle{Type: restic.DataFile, Name: id.String()} - rtest.OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData))) + rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData))) verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize) } diff --git a/internal/repository/key.go b/internal/repository/key.go index 63c35b8e8..44e3d24e0 100644 --- a/internal/repository/key.go +++ b/internal/repository/key.go @@ -1,7 +1,6 @@ package repository import ( - "bytes" "context" "encoding/json" "fmt" @@ -250,7 +249,7 @@ func AddKey(ctx context.Context, s *Repository, password string, template *crypt Name: restic.Hash(buf).String(), } - err = s.be.Save(ctx, h, bytes.NewReader(buf)) + err = s.be.Save(ctx, h, restic.NewByteReader(buf)) if err != nil { return nil, err } diff --git a/internal/repository/packer_manager.go b/internal/repository/packer_manager.go index d4bba2cf3..4884e0885 100644 --- a/internal/repository/packer_manager.go +++ b/internal/repository/packer_manager.go @@ -3,7 +3,6 @@ package repository import ( "context" "crypto/sha256" - "io" "os" "sync" @@ -19,7 +18,7 @@ import ( // Saver implements saving data in a backend. type Saver interface { - Save(context.Context, restic.Handle, io.Reader) error + Save(context.Context, restic.Handle, restic.RewindReader) error } // Packer holds a pack.Packer together with a hash writer. @@ -96,15 +95,15 @@ func (r *Repository) savePacker(ctx context.Context, t restic.BlobType, p *Packe return err } - _, err = p.tmpfile.Seek(0, 0) - if err != nil { - return errors.Wrap(err, "Seek") - } - id := restic.IDFromHash(p.hw.Sum(nil)) h := restic.Handle{Type: restic.DataFile, Name: id.String()} - err = r.be.Save(ctx, h, p.tmpfile) + rd, err := restic.NewFileReader(p.tmpfile) + if err != nil { + return err + } + + err = r.be.Save(ctx, h, rd) if err != nil { debug.Log("Save(%v) error: %v", h, err) return err diff --git a/internal/repository/packer_manager_test.go b/internal/repository/packer_manager_test.go index aface097b..e1f067f7c 100644 --- a/internal/repository/packer_manager_test.go +++ b/internal/repository/packer_manager_test.go @@ -50,11 +50,17 @@ func randomID(rd io.Reader) restic.ID { const maxBlobSize = 1 << 20 -func saveFile(t testing.TB, be Saver, f *os.File, id restic.ID) { +func saveFile(t testing.TB, be Saver, length int, f *os.File, id restic.ID) { h := restic.Handle{Type: restic.DataFile, Name: id.String()} t.Logf("save file %v", h) - if err := be.Save(context.TODO(), h, f); err != nil { + rd, err := restic.NewFileReader(f) + if err != nil { + t.Fatal(err) + } + + err = be.Save(context.TODO(), h, rd) + if err != nil { t.Fatal(err) } @@ -101,12 +107,8 @@ func fillPacks(t testing.TB, rnd *randReader, be Saver, pm *packerManager, buf [ t.Fatal(err) } - if _, err = packer.tmpfile.Seek(0, 0); err != nil { - t.Fatal(err) - } - packID := restic.IDFromHash(packer.hw.Sum(nil)) - saveFile(t, be, packer.tmpfile, packID) + saveFile(t, be, int(packer.Size()), packer.tmpfile, packID) } return bytes @@ -122,7 +124,7 @@ func flushRemainingPacks(t testing.TB, rnd *randReader, be Saver, pm *packerMana bytes += int(n) packID := restic.IDFromHash(packer.hw.Sum(nil)) - saveFile(t, be, packer.tmpfile, packID) + saveFile(t, be, int(packer.Size()), packer.tmpfile, packID) } } @@ -147,7 +149,7 @@ func BenchmarkPackerManager(t *testing.B) { rnd := newRandReader(rand.NewSource(23)) be := &mock.Backend{ - SaveFn: func(context.Context, restic.Handle, io.Reader) error { return nil }, + SaveFn: func(context.Context, restic.Handle, restic.RewindReader) error { return nil }, } blobBuf := make([]byte, maxBlobSize) diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 3fafe0efb..0c1236f42 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -282,7 +282,7 @@ func (r *Repository) SaveUnpacked(ctx context.Context, t restic.FileType, p []by id = restic.Hash(ciphertext) h := restic.Handle{Type: t, Name: id.String()} - err = r.be.Save(ctx, h, bytes.NewReader(ciphertext)) + err = r.be.Save(ctx, h, restic.NewByteReader(ciphertext)) if err != nil { debug.Log("error saving blob %v: %v", h, err) return restic.ID{}, err @@ -456,11 +456,7 @@ func (r *Repository) LoadIndex(ctx context.Context) error { } } - if err := <-errCh; err != nil { - return err - } - - return nil + return <-errCh } // LoadIndex loads the index id from backend and returns it. diff --git a/internal/restic/backend.go b/internal/restic/backend.go index c5dc25087..b2fd46b97 100644 --- a/internal/restic/backend.go +++ b/internal/restic/backend.go @@ -20,8 +20,8 @@ type Backend interface { // Close the backend Close() error - // Save stores the data in the backend under the given handle. - Save(ctx context.Context, h Handle, rd io.Reader) error + // Save stores the data from rd under the given handle. + Save(ctx context.Context, h Handle, rd RewindReader) error // Load runs fn with a reader that yields the contents of the file at h at the // given offset. If length is larger than zero, only a portion of the file diff --git a/internal/restic/rewind_reader.go b/internal/restic/rewind_reader.go new file mode 100644 index 000000000..aa851d99f --- /dev/null +++ b/internal/restic/rewind_reader.go @@ -0,0 +1,90 @@ +package restic + +import ( + "bytes" + "io" + + "github.com/restic/restic/internal/errors" +) + +// RewindReader allows resetting the Reader to the beginning of the data. +type RewindReader interface { + io.Reader + + // Rewind rewinds the reader so the same data can be read again from the + // start. + Rewind() error + + // Length returns the number of bytes that can be read from the Reader + // after calling Rewind. + Length() int +} + +// ByteReader implements a RewindReader for a byte slice. +type ByteReader struct { + *bytes.Reader + Len int +} + +// Rewind restarts the reader from the beginning of the data. +func (b *ByteReader) Rewind() error { + _, err := b.Reader.Seek(0, io.SeekStart) + return err +} + +// Length returns the number of bytes read from the reader after Rewind is +// called. +func (b *ByteReader) Length() int { + return b.Len +} + +// statically ensure that *ByteReader implements RewindReader. +var _ RewindReader = &ByteReader{} + +// NewByteReader prepares a ByteReader that can then be used to read buf. +func NewByteReader(buf []byte) *ByteReader { + return &ByteReader{ + Reader: bytes.NewReader(buf), + Len: len(buf), + } +} + +// statically ensure that *FileReader implements RewindReader. +var _ RewindReader = &FileReader{} + +// FileReader implements a RewindReader for an open file. +type FileReader struct { + io.ReadSeeker + Len int +} + +// Rewind seeks to the beginning of the file. +func (f *FileReader) Rewind() error { + _, err := f.ReadSeeker.Seek(0, io.SeekStart) + return errors.Wrap(err, "Seek") +} + +// Length returns the length of the file. +func (f *FileReader) Length() int { + return f.Len +} + +// NewFileReader wraps f in a *FileReader. +func NewFileReader(f io.ReadSeeker) (*FileReader, error) { + pos, err := f.Seek(0, io.SeekEnd) + if err != nil { + return nil, errors.Wrap(err, "Seek") + } + + fr := &FileReader{ + ReadSeeker: f, + Len: int(pos), + } + + err = fr.Rewind() + if err != nil { + return nil, err + } + + return fr, nil +} diff --git a/internal/restic/rewind_reader_test.go b/internal/restic/rewind_reader_test.go new file mode 100644 index 000000000..c3c6001f8 --- /dev/null +++ b/internal/restic/rewind_reader_test.go @@ -0,0 +1,154 @@ +package restic + +import ( + "bytes" + "io" + "io/ioutil" + "math/rand" + "os" + "path/filepath" + "testing" + "time" + + "github.com/restic/restic/internal/test" +) + +func TestByteReader(t *testing.T) { + buf := []byte("foobar") + fn := func() RewindReader { + return NewByteReader(buf) + } + testRewindReader(t, fn, buf) +} + +func TestFileReader(t *testing.T) { + buf := []byte("foobar") + + d, cleanup := test.TempDir(t) + defer cleanup() + + filename := filepath.Join(d, "file-reader-test") + err := ioutil.WriteFile(filename, []byte("foobar"), 0600) + if err != nil { + t.Fatal(err) + } + + f, err := os.Open(filename) + if err != nil { + t.Fatal(err) + } + + defer func() { + err := f.Close() + if err != nil { + t.Fatal(err) + } + }() + + fn := func() RewindReader { + rd, err := NewFileReader(f) + if err != nil { + t.Fatal(err) + } + return rd + } + + testRewindReader(t, fn, buf) +} + +func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) { + seed := time.Now().Unix() + t.Logf("seed is %d", seed) + rnd := rand.New(rand.NewSource(seed)) + + type ReaderTestFunc func(t testing.TB, r RewindReader, data []byte) + var tests = []ReaderTestFunc{ + func(t testing.TB, rd RewindReader, data []byte) { + if rd.Length() != len(data) { + t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length()) + } + + buf := make([]byte, len(data)) + _, err := io.ReadFull(rd, buf) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf, data) { + t.Fatalf("wrong data returned") + } + + if rd.Length() != len(data) { + t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length()) + } + + err = rd.Rewind() + if err != nil { + t.Fatal(err) + } + + if rd.Length() != len(data) { + t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length()) + } + + buf2 := make([]byte, len(data)) + _, err = io.ReadFull(rd, buf2) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf2, data) { + t.Fatalf("wrong data returned") + } + + if rd.Length() != len(data) { + t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length()) + } + }, + func(t testing.TB, rd RewindReader, data []byte) { + // read first bytes + buf := make([]byte, rnd.Intn(len(data))) + _, err := io.ReadFull(rd, buf) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf, data[:len(buf)]) { + t.Fatalf("wrong data returned") + } + + err = rd.Rewind() + if err != nil { + t.Fatal(err) + } + + buf2 := make([]byte, rnd.Intn(len(data))) + _, err = io.ReadFull(rd, buf2) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf2, data[:len(buf2)]) { + t.Fatalf("wrong data returned") + } + + // read remainder + buf3 := make([]byte, len(data)-len(buf2)) + _, err = io.ReadFull(rd, buf3) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf3, data[len(buf2):]) { + t.Fatalf("wrong data returned") + } + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + rd := fn() + test(t, rd, data) + }) + } +}