backend: Improve Save()

As mentioned in issue [#1560](https://github.com/restic/restic/pull/1560#issuecomment-364689346)
this changes the signature for `backend.Save()`. It now takes a
parameter of interface type `RewindReader`, so that the backend
implementations or our `RetryBackend` middleware can reset the reader to
the beginning and then retry an upload operation.

The `RewindReader` interface also provides a `Length()` method, which is
used in the backend to get the size of the data to be saved. This
removes several ugly hacks we had to do to pull the size back out of the
`io.Reader` passed to `Save()` before. In the `s3` and `rest` backend
this is actively used.
This commit is contained in:
Alexander Neumann 2018-03-03 14:20:54 +01:00
parent 58306bfabb
commit 99f7fd74e3
29 changed files with 387 additions and 204 deletions

View File

@ -48,7 +48,7 @@ func forgetfulBackend() restic.Backend {
return nil, errors.New("not found")
}
be.SaveFn = func(ctx context.Context, h restic.Handle, rd io.Reader) error {
be.SaveFn = func(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
return nil
}

View File

@ -3,6 +3,7 @@ package azure
import (
"context"
"io"
"io/ioutil"
"net/http"
"os"
"path"
@ -114,19 +115,8 @@ func (be *Backend) Path() string {
return be.prefix
}
// preventCloser wraps an io.Reader to run a function instead of the original Close() function.
type preventCloser struct {
io.Reader
f func()
}
func (wr preventCloser) Close() error {
wr.f()
return nil
}
// Save stores data in the backend at the handle.
func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
if err := h.Valid(); err != nil {
return err
}
@ -137,18 +127,10 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err
be.sem.GetToken()
// wrap the reader so that net/http client cannot close the reader, return
// the token instead.
rd = preventCloser{
Reader: rd,
f: func() {
debug.Log("Close()")
},
}
debug.Log("InsertObject(%v, %v)", be.container.Name, objName)
err = be.container.GetBlobReference(objName).CreateBlockBlobFromReader(rd, nil)
// wrap the reader so that net/http client cannot close the reader
err := be.container.GetBlobReference(objName).CreateBlockBlobFromReader(ioutil.NopCloser(rd), nil)
be.sem.ReleaseToken()
debug.Log("%v, err %#v", objName, err)

View File

@ -185,7 +185,7 @@ func (be *b2Backend) openReader(ctx context.Context, h restic.Handle, length int
}
// Save stores data in the backend at the handle.
func (be *b2Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error {
func (be *b2Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

View File

@ -45,7 +45,7 @@ func (be *ErrorBackend) fail(p float32) bool {
}
// Save stores the data in the backend under the given handle.
func (be *ErrorBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error {
func (be *ErrorBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
if be.fail(be.FailSave) {
return errors.Errorf("Save(%v) random error induced", h)
}

View File

@ -8,7 +8,6 @@ import (
"github.com/cenkalti/backoff"
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/restic"
)
@ -47,23 +46,9 @@ func (be *RetryBackend) retry(ctx context.Context, msg string, f func() error) e
}
// Save stores the data in the backend under the given handle.
func (be *RetryBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error {
seeker, ok := rd.(io.Seeker)
if !ok {
return errors.Errorf("reader %T is not a seeker", rd)
}
pos, err := seeker.Seek(0, io.SeekCurrent)
if err != nil {
return errors.Wrap(err, "Seek")
}
if pos != 0 {
return errors.Errorf("reader is not at the beginning (pos %v)", pos)
}
func (be *RetryBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
return be.retry(ctx, fmt.Sprintf("Save(%v)", h), func() error {
_, err := seeker.Seek(0, io.SeekStart)
err := rd.Rewind()
if err != nil {
return err
}

View File

@ -13,48 +13,11 @@ import (
"github.com/restic/restic/internal/test"
)
func TestBackendRetrySeeker(t *testing.T) {
be := &mock.Backend{
SaveFn: func(ctx context.Context, h restic.Handle, rd io.Reader) error {
return nil
},
}
retryBackend := RetryBackend{
Backend: be,
}
data := test.Random(24, 23*14123)
type wrapReader struct {
io.Reader
}
var rd io.Reader
rd = wrapReader{bytes.NewReader(data)}
err := retryBackend.Save(context.TODO(), restic.Handle{}, rd)
if err == nil {
t.Fatal("did not get expected error for retry backend with non-seeker reader")
}
rd = bytes.NewReader(data)
_, err = io.CopyN(ioutil.Discard, rd, 5)
if err != nil {
t.Fatal(err)
}
err = retryBackend.Save(context.TODO(), restic.Handle{}, rd)
if err == nil {
t.Fatal("did not get expected error for partial reader")
}
}
func TestBackendSaveRetry(t *testing.T) {
buf := bytes.NewBuffer(nil)
errcount := 0
be := &mock.Backend{
SaveFn: func(ctx context.Context, h restic.Handle, rd io.Reader) error {
SaveFn: func(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
if errcount == 0 {
errcount++
_, err := io.CopyN(ioutil.Discard, rd, 120)
@ -75,7 +38,7 @@ func TestBackendSaveRetry(t *testing.T) {
}
data := test.Random(23, 5*1024*1024+11241)
err := retryBackend.Save(context.TODO(), restic.Handle{}, bytes.NewReader(data))
err := retryBackend.Save(context.TODO(), restic.Handle{}, restic.NewByteReader(data))
if err != nil {
t.Fatal(err)
}

View File

@ -207,7 +207,7 @@ func (be *Backend) Path() string {
}
// Save stores data in the backend at the handle.
func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
if err := h.Valid(); err != nil {
return err
}
@ -250,6 +250,7 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err
info, err := be.service.Objects.Insert(be.bucketName,
&storage.Object{
Name: objName,
Size: uint64(rd.Length()),
}).Media(rd, cs).Do()
be.sem.ReleaseToken()

View File

@ -98,7 +98,7 @@ func (b *Local) IsNotExist(err error) bool {
}
// Save stores data in the backend at the handle.
func (b *Local) Save(ctx context.Context, h restic.Handle, rd io.Reader) error {
func (b *Local) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
debug.Log("Save %v", h)
if err := h.Valid(); err != nil {
return err

View File

@ -59,7 +59,7 @@ func (be *MemoryBackend) IsNotExist(err error) bool {
}
// Save adds new Data to the backend.
func (be *MemoryBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error {
func (be *MemoryBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
if err := h.Valid(); err != nil {
return err
}

View File

@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"path"
"strconv"
"strings"
"golang.org/x/net/context/ctxhttp"
@ -105,7 +106,7 @@ func (b *restBackend) Location() string {
}
// Save stores data in the backend at the handle.
func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
if err := h.Valid(); err != nil {
return err
}
@ -114,12 +115,11 @@ func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (
defer cancel()
// make sure that client.Post() cannot close the reader by wrapping it
rd = ioutil.NopCloser(rd)
req, err := http.NewRequest(http.MethodPost, b.Filename(h), rd)
req, err := http.NewRequest(http.MethodPost, b.Filename(h), ioutil.NopCloser(rd))
if err != nil {
return errors.Wrap(err, "NewRequest")
}
req.Header.Set("Content-Length", strconv.Itoa(rd.Length()))
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("Accept", contentTypeV2)

View File

@ -240,7 +240,7 @@ func lenForFile(f *os.File) (int64, error) {
}
// Save stores data in the backend at the handle.
func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
debug.Log("Save %v", h)
if err := h.Valid(); err != nil {
@ -252,27 +252,11 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err
be.sem.GetToken()
defer be.sem.ReleaseToken()
var size int64 = -1
type lenner interface {
Len() int
}
// find size for reader
if f, ok := rd.(*os.File); ok {
size, err = lenForFile(f)
if err != nil {
return err
}
} else if l, ok := rd.(lenner); ok {
size = int64(l.Len())
}
opts := minio.PutObjectOptions{}
opts.ContentType = "application/octet-stream"
debug.Log("PutObject(%v, %v, %v)", be.cfg.Bucket, objName, size)
n, err := be.client.PutObjectWithContext(ctx, be.cfg.Bucket, objName, ioutil.NopCloser(rd), size, opts)
debug.Log("PutObject(%v, %v, %v)", be.cfg.Bucket, objName, rd.Length())
n, err := be.client.PutObjectWithContext(ctx, be.cfg.Bucket, objName, ioutil.NopCloser(rd), int64(rd.Length()), opts)
debug.Log("%v -> %v bytes, err %#v: %v", objName, n, err, err)

View File

@ -282,7 +282,7 @@ func Join(parts ...string) string {
}
// Save stores data in the backend at the handle.
func (r *SFTP) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
func (r *SFTP) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
debug.Log("Save %v", h)
if err := r.clientError(); err != nil {
return err

View File

@ -156,8 +156,8 @@ func (be *beSwift) openReader(ctx context.Context, h restic.Handle, length int,
}
// Save stores data in the backend at the handle.
func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
if err = h.Valid(); err != nil {
func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
if err := h.Valid(); err != nil {
return err
}
@ -171,7 +171,7 @@ func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err
encoding := "binary/octet-stream"
debug.Log("PutObject(%v, %v, %v)", be.container, objName, encoding)
_, err = be.conn.ObjectPut(be.container, objName, rd, true, "", encoding, nil)
_, err := be.conn.ObjectPut(be.container, objName, rd, true, "", encoding, nil)
debug.Log("%v, err %#v", objName, err)
return errors.Wrap(err, "client.PutObject")

View File

@ -14,7 +14,8 @@ func saveRandomFile(t testing.TB, be restic.Backend, length int) ([]byte, restic
data := test.Random(23, length)
id := restic.Hash(data)
handle := restic.Handle{Type: restic.DataFile, Name: id.String()}
if err := be.Save(context.TODO(), handle, bytes.NewReader(data)); err != nil {
err := be.Save(context.TODO(), handle, restic.NewByteReader(data))
if err != nil {
t.Fatalf("Save() error: %+v", err)
}
return data, handle
@ -148,16 +149,11 @@ func (s *Suite) BenchmarkSave(t *testing.B) {
id := restic.Hash(data)
handle := restic.Handle{Type: restic.DataFile, Name: id.String()}
rd := bytes.NewReader(data)
rd := restic.NewByteReader(data)
t.SetBytes(int64(length))
t.ResetTimer()
for i := 0; i < t.N; i++ {
if _, err := rd.Seek(0, 0); err != nil {
t.Fatal(err)
}
if err := be.Save(context.TODO(), handle, rd); err != nil {
t.Fatal(err)
}

View File

@ -10,7 +10,6 @@ import (
"os"
"reflect"
"sort"
"strings"
"testing"
"time"
@ -85,7 +84,7 @@ func (s *Suite) TestConfig(t *testing.T) {
t.Fatalf("did not get expected error for non-existing config")
}
err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, strings.NewReader(testString))
err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, restic.NewByteReader([]byte(testString)))
if err != nil {
t.Fatalf("Save() error: %+v", err)
}
@ -135,7 +134,7 @@ func (s *Suite) TestLoad(t *testing.T) {
id := restic.Hash(data)
handle := restic.Handle{Type: restic.DataFile, Name: id.String()}
err = b.Save(context.TODO(), handle, bytes.NewReader(data))
err = b.Save(context.TODO(), handle, restic.NewByteReader(data))
if err != nil {
t.Fatalf("Save() error: %+v", err)
}
@ -250,7 +249,7 @@ func (s *Suite) TestList(t *testing.T) {
data := test.Random(rand.Int(), rand.Intn(100)+55)
id := restic.Hash(data)
h := restic.Handle{Type: restic.DataFile, Name: id.String()}
err := b.Save(context.TODO(), h, bytes.NewReader(data))
err := b.Save(context.TODO(), h, restic.NewByteReader(data))
if err != nil {
t.Fatal(err)
}
@ -340,7 +339,7 @@ func (s *Suite) TestListCancel(t *testing.T) {
data := []byte(fmt.Sprintf("random test blob %v", i))
id := restic.Hash(data)
h := restic.Handle{Type: restic.DataFile, Name: id.String()}
err := b.Save(context.TODO(), h, bytes.NewReader(data))
err := b.Save(context.TODO(), h, restic.NewByteReader(data))
if err != nil {
t.Fatal(err)
}
@ -443,7 +442,7 @@ func (s *Suite) TestListCancel(t *testing.T) {
}
type errorCloser struct {
io.Reader
io.ReadSeeker
l int
t testing.TB
}
@ -453,10 +452,15 @@ func (ec errorCloser) Close() error {
return errors.New("forbidden method close was called")
}
func (ec errorCloser) Len() int {
func (ec errorCloser) Length() int {
return ec.l
}
func (ec errorCloser) Rewind() error {
_, err := ec.ReadSeeker.Seek(0, io.SeekStart)
return err
}
// TestSave tests saving data in the backend.
func (s *Suite) TestSave(t *testing.T) {
seedRand(t)
@ -480,7 +484,7 @@ func (s *Suite) TestSave(t *testing.T) {
Type: restic.DataFile,
Name: fmt.Sprintf("%s-%d", id, i),
}
err := b.Save(context.TODO(), h, bytes.NewReader(data))
err := b.Save(context.TODO(), h, restic.NewByteReader(data))
test.OK(t, err)
buf, err := backend.LoadAll(context.TODO(), b, h)
@ -532,7 +536,7 @@ func (s *Suite) TestSave(t *testing.T) {
// wrap the tempfile in an errorCloser, so we can detect if the backend
// closes the reader
err = b.Save(context.TODO(), h, errorCloser{t: t, l: length, Reader: tmpfile})
err = b.Save(context.TODO(), h, errorCloser{t: t, l: length, ReadSeeker: tmpfile})
if err != nil {
t.Fatal(err)
}
@ -542,25 +546,10 @@ func (s *Suite) TestSave(t *testing.T) {
t.Fatalf("error removing item: %+v", err)
}
// try again directly with the temp file
if _, err = tmpfile.Seek(588, io.SeekStart); err != nil {
t.Fatal(err)
}
err = b.Save(context.TODO(), h, tmpfile)
if err != nil {
t.Fatal(err)
}
if err = tmpfile.Close(); err != nil {
t.Fatal(err)
}
err = b.Remove(context.TODO(), h)
if err != nil {
t.Fatalf("error removing item: %+v", err)
}
if err = os.Remove(tmpfile.Name()); err != nil {
t.Fatal(err)
}
@ -585,7 +574,7 @@ func (s *Suite) TestSaveFilenames(t *testing.T) {
for i, test := range filenameTests {
h := restic.Handle{Name: test.name, Type: restic.DataFile}
err := b.Save(context.TODO(), h, strings.NewReader(test.data))
err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(test.data)))
if err != nil {
t.Errorf("test %d failed: Save() returned %+v", i, err)
continue
@ -622,7 +611,7 @@ var testStrings = []struct {
func store(t testing.TB, b restic.Backend, tpe restic.FileType, data []byte) restic.Handle {
id := restic.Hash(data)
h := restic.Handle{Name: id.String(), Type: tpe}
err := b.Save(context.TODO(), h, bytes.NewReader(data))
err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(data)))
test.OK(t, err)
return h
}
@ -776,7 +765,7 @@ func (s *Suite) TestBackend(t *testing.T) {
test.Assert(t, !ok, "removed blob still present")
// create blob
err = b.Save(context.TODO(), h, strings.NewReader(ts.data))
err = b.Save(context.TODO(), h, restic.NewByteReader([]byte(ts.data)))
test.OK(t, err)
// list items

View File

@ -24,7 +24,8 @@ func TestLoadAll(t *testing.T) {
data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB)
id := restic.Hash(data)
err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data))
h := restic.Handle{Name: id.String(), Type: restic.DataFile}
err := b.Save(context.TODO(), h, restic.NewByteReader(data))
rtest.OK(t, err)
buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()})
@ -49,7 +50,8 @@ func TestLoadSmallBuffer(t *testing.T) {
data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB)
id := restic.Hash(data)
err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data))
h := restic.Handle{Name: id.String(), Type: restic.DataFile}
err := b.Save(context.TODO(), h, restic.NewByteReader(data))
rtest.OK(t, err)
buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()})
@ -74,7 +76,8 @@ func TestLoadLargeBuffer(t *testing.T) {
data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB)
id := restic.Hash(data)
err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data))
h := restic.Handle{Name: id.String(), Type: restic.DataFile}
err := b.Save(context.TODO(), h, restic.NewByteReader(data))
rtest.OK(t, err)
buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()})

View File

@ -5,7 +5,6 @@ import (
"io"
"sync"
"github.com/pkg/errors"
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/restic"
)
@ -50,35 +49,29 @@ var autoCacheTypes = map[restic.FileType]struct{}{
}
// Save stores a new file in the backend and the cache.
func (b *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
func (b *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
if _, ok := autoCacheTypes[h.Type]; !ok {
return b.Backend.Save(ctx, h, rd)
}
debug.Log("Save(%v): auto-store in the cache", h)
seeker, ok := rd.(io.Seeker)
if !ok {
return errors.New("reader is not a seeker")
}
pos, err := seeker.Seek(0, io.SeekCurrent)
// make sure the reader is at the start
err := rd.Rewind()
if err != nil {
return errors.Wrapf(err, "Seek")
}
if pos != 0 {
return errors.Errorf("reader is not rewind (pos %d)", pos)
return err
}
// first, save in the backend
err = b.Backend.Save(ctx, h, rd)
if err != nil {
return err
}
_, err = seeker.Seek(pos, io.SeekStart)
// next, save in the cache
err = rd.Rewind()
if err != nil {
return errors.Wrapf(err, "Seek")
return err
}
err = b.Cache.Save(h, rd)

View File

@ -28,7 +28,7 @@ func loadAndCompare(t testing.TB, be restic.Backend, h restic.Handle, data []byt
}
func save(t testing.TB, be restic.Backend, h restic.Handle, data []byte) {
err := be.Save(context.TODO(), h, bytes.NewReader(data))
err := be.Save(context.TODO(), h, restic.NewByteReader(data))
if err != nil {
t.Fatal(err)
}

View File

@ -3,7 +3,9 @@ package checker_test
import (
"context"
"io"
"io/ioutil"
"math/rand"
"os"
"path/filepath"
"sort"
"testing"
@ -195,17 +197,47 @@ func TestModifiedIndex(t *testing.T) {
Type: restic.IndexFile,
Name: "90f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd",
}
err := repo.Backend().Load(context.TODO(), h, 0, 0, func(rd io.Reader) error {
// save the index again with a modified name so that the hash doesn't match
// the content any more
h2 := restic.Handle{
Type: restic.IndexFile,
Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd",
tmpfile, err := ioutil.TempFile("", "restic-test-mod-index-")
if err != nil {
t.Fatal(err)
}
defer func() {
err := tmpfile.Close()
if err != nil {
t.Fatal(err)
}
return repo.Backend().Save(context.TODO(), h2, rd)
err = os.Remove(tmpfile.Name())
if err != nil {
t.Fatal(err)
}
}()
// read the file from the backend
err = repo.Backend().Load(context.TODO(), h, 0, 0, func(rd io.Reader) error {
_, err := io.Copy(tmpfile, rd)
return err
})
test.OK(t, err)
// save the index again with a modified name so that the hash doesn't match
// the content any more
h2 := restic.Handle{
Type: restic.IndexFile,
Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd",
}
rd, err := restic.NewFileReader(tmpfile)
if err != nil {
t.Fatal(err)
}
err = repo.Backend().Save(context.TODO(), h2, rd)
if err != nil {
t.Fatal(err)
}
chkr := checker.New(repo)
hints, errs := chkr.LoadIndex(context.TODO())
if len(errs) == 0 {

View File

@ -21,8 +21,23 @@ type rateLimitedBackend struct {
limiter Limiter
}
func (r rateLimitedBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error {
return r.Backend.Save(ctx, h, r.limiter.Upstream(rd))
func (r rateLimitedBackend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
limited := limitedRewindReader{
RewindReader: rd,
limited: r.limiter.Upstream(rd),
}
return r.Backend.Save(ctx, h, limited)
}
type limitedRewindReader struct {
restic.RewindReader
limited io.Reader
}
func (l limitedRewindReader) Read(b []byte) (int, error) {
return l.limited.Read(b)
}
func (r rateLimitedBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64, consumer func(rd io.Reader) error) error {

View File

@ -12,7 +12,7 @@ import (
type Backend struct {
CloseFn func() error
IsNotExistFn func(err error) bool
SaveFn func(ctx context.Context, h restic.Handle, rd io.Reader) error
SaveFn func(ctx context.Context, h restic.Handle, rd restic.RewindReader) error
OpenReaderFn func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error)
StatFn func(ctx context.Context, h restic.Handle) (restic.FileInfo, error)
ListFn func(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error
@ -56,7 +56,7 @@ func (m *Backend) IsNotExist(err error) bool {
}
// Save data in the backend.
func (m *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error {
func (m *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
if m.SaveFn == nil {
return errors.New("not implemented")
}

View File

@ -127,7 +127,7 @@ func TestUnpackReadSeeker(t *testing.T) {
id := restic.Hash(packData)
handle := restic.Handle{Type: restic.DataFile, Name: id.String()}
rtest.OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData)))
rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData)))
verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize)
}
@ -140,6 +140,6 @@ func TestShortPack(t *testing.T) {
id := restic.Hash(packData)
handle := restic.Handle{Type: restic.DataFile, Name: id.String()}
rtest.OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData)))
rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData)))
verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize)
}

View File

@ -1,7 +1,6 @@
package repository
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -250,7 +249,7 @@ func AddKey(ctx context.Context, s *Repository, password string, template *crypt
Name: restic.Hash(buf).String(),
}
err = s.be.Save(ctx, h, bytes.NewReader(buf))
err = s.be.Save(ctx, h, restic.NewByteReader(buf))
if err != nil {
return nil, err
}

View File

@ -3,7 +3,6 @@ package repository
import (
"context"
"crypto/sha256"
"io"
"os"
"sync"
@ -19,7 +18,7 @@ import (
// Saver implements saving data in a backend.
type Saver interface {
Save(context.Context, restic.Handle, io.Reader) error
Save(context.Context, restic.Handle, restic.RewindReader) error
}
// Packer holds a pack.Packer together with a hash writer.
@ -96,15 +95,15 @@ func (r *Repository) savePacker(ctx context.Context, t restic.BlobType, p *Packe
return err
}
_, err = p.tmpfile.Seek(0, 0)
if err != nil {
return errors.Wrap(err, "Seek")
}
id := restic.IDFromHash(p.hw.Sum(nil))
h := restic.Handle{Type: restic.DataFile, Name: id.String()}
err = r.be.Save(ctx, h, p.tmpfile)
rd, err := restic.NewFileReader(p.tmpfile)
if err != nil {
return err
}
err = r.be.Save(ctx, h, rd)
if err != nil {
debug.Log("Save(%v) error: %v", h, err)
return err

View File

@ -50,11 +50,17 @@ func randomID(rd io.Reader) restic.ID {
const maxBlobSize = 1 << 20
func saveFile(t testing.TB, be Saver, f *os.File, id restic.ID) {
func saveFile(t testing.TB, be Saver, length int, f *os.File, id restic.ID) {
h := restic.Handle{Type: restic.DataFile, Name: id.String()}
t.Logf("save file %v", h)
if err := be.Save(context.TODO(), h, f); err != nil {
rd, err := restic.NewFileReader(f)
if err != nil {
t.Fatal(err)
}
err = be.Save(context.TODO(), h, rd)
if err != nil {
t.Fatal(err)
}
@ -101,12 +107,8 @@ func fillPacks(t testing.TB, rnd *randReader, be Saver, pm *packerManager, buf [
t.Fatal(err)
}
if _, err = packer.tmpfile.Seek(0, 0); err != nil {
t.Fatal(err)
}
packID := restic.IDFromHash(packer.hw.Sum(nil))
saveFile(t, be, packer.tmpfile, packID)
saveFile(t, be, int(packer.Size()), packer.tmpfile, packID)
}
return bytes
@ -122,7 +124,7 @@ func flushRemainingPacks(t testing.TB, rnd *randReader, be Saver, pm *packerMana
bytes += int(n)
packID := restic.IDFromHash(packer.hw.Sum(nil))
saveFile(t, be, packer.tmpfile, packID)
saveFile(t, be, int(packer.Size()), packer.tmpfile, packID)
}
}
@ -147,7 +149,7 @@ func BenchmarkPackerManager(t *testing.B) {
rnd := newRandReader(rand.NewSource(23))
be := &mock.Backend{
SaveFn: func(context.Context, restic.Handle, io.Reader) error { return nil },
SaveFn: func(context.Context, restic.Handle, restic.RewindReader) error { return nil },
}
blobBuf := make([]byte, maxBlobSize)

View File

@ -282,7 +282,7 @@ func (r *Repository) SaveUnpacked(ctx context.Context, t restic.FileType, p []by
id = restic.Hash(ciphertext)
h := restic.Handle{Type: t, Name: id.String()}
err = r.be.Save(ctx, h, bytes.NewReader(ciphertext))
err = r.be.Save(ctx, h, restic.NewByteReader(ciphertext))
if err != nil {
debug.Log("error saving blob %v: %v", h, err)
return restic.ID{}, err
@ -456,11 +456,7 @@ func (r *Repository) LoadIndex(ctx context.Context) error {
}
}
if err := <-errCh; err != nil {
return err
}
return nil
return <-errCh
}
// LoadIndex loads the index id from backend and returns it.

View File

@ -20,8 +20,8 @@ type Backend interface {
// Close the backend
Close() error
// Save stores the data in the backend under the given handle.
Save(ctx context.Context, h Handle, rd io.Reader) error
// Save stores the data from rd under the given handle.
Save(ctx context.Context, h Handle, rd RewindReader) error
// Load runs fn with a reader that yields the contents of the file at h at the
// given offset. If length is larger than zero, only a portion of the file

View File

@ -0,0 +1,90 @@
package restic
import (
"bytes"
"io"
"github.com/restic/restic/internal/errors"
)
// RewindReader allows resetting the Reader to the beginning of the data.
type RewindReader interface {
io.Reader
// Rewind rewinds the reader so the same data can be read again from the
// start.
Rewind() error
// Length returns the number of bytes that can be read from the Reader
// after calling Rewind.
Length() int
}
// ByteReader implements a RewindReader for a byte slice.
type ByteReader struct {
*bytes.Reader
Len int
}
// Rewind restarts the reader from the beginning of the data.
func (b *ByteReader) Rewind() error {
_, err := b.Reader.Seek(0, io.SeekStart)
return err
}
// Length returns the number of bytes read from the reader after Rewind is
// called.
func (b *ByteReader) Length() int {
return b.Len
}
// statically ensure that *ByteReader implements RewindReader.
var _ RewindReader = &ByteReader{}
// NewByteReader prepares a ByteReader that can then be used to read buf.
func NewByteReader(buf []byte) *ByteReader {
return &ByteReader{
Reader: bytes.NewReader(buf),
Len: len(buf),
}
}
// statically ensure that *FileReader implements RewindReader.
var _ RewindReader = &FileReader{}
// FileReader implements a RewindReader for an open file.
type FileReader struct {
io.ReadSeeker
Len int
}
// Rewind seeks to the beginning of the file.
func (f *FileReader) Rewind() error {
_, err := f.ReadSeeker.Seek(0, io.SeekStart)
return errors.Wrap(err, "Seek")
}
// Length returns the length of the file.
func (f *FileReader) Length() int {
return f.Len
}
// NewFileReader wraps f in a *FileReader.
func NewFileReader(f io.ReadSeeker) (*FileReader, error) {
pos, err := f.Seek(0, io.SeekEnd)
if err != nil {
return nil, errors.Wrap(err, "Seek")
}
fr := &FileReader{
ReadSeeker: f,
Len: int(pos),
}
err = fr.Rewind()
if err != nil {
return nil, err
}
return fr, nil
}

View File

@ -0,0 +1,154 @@
package restic
import (
"bytes"
"io"
"io/ioutil"
"math/rand"
"os"
"path/filepath"
"testing"
"time"
"github.com/restic/restic/internal/test"
)
func TestByteReader(t *testing.T) {
buf := []byte("foobar")
fn := func() RewindReader {
return NewByteReader(buf)
}
testRewindReader(t, fn, buf)
}
func TestFileReader(t *testing.T) {
buf := []byte("foobar")
d, cleanup := test.TempDir(t)
defer cleanup()
filename := filepath.Join(d, "file-reader-test")
err := ioutil.WriteFile(filename, []byte("foobar"), 0600)
if err != nil {
t.Fatal(err)
}
f, err := os.Open(filename)
if err != nil {
t.Fatal(err)
}
defer func() {
err := f.Close()
if err != nil {
t.Fatal(err)
}
}()
fn := func() RewindReader {
rd, err := NewFileReader(f)
if err != nil {
t.Fatal(err)
}
return rd
}
testRewindReader(t, fn, buf)
}
func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) {
seed := time.Now().Unix()
t.Logf("seed is %d", seed)
rnd := rand.New(rand.NewSource(seed))
type ReaderTestFunc func(t testing.TB, r RewindReader, data []byte)
var tests = []ReaderTestFunc{
func(t testing.TB, rd RewindReader, data []byte) {
if rd.Length() != len(data) {
t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length())
}
buf := make([]byte, len(data))
_, err := io.ReadFull(rd, buf)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf, data) {
t.Fatalf("wrong data returned")
}
if rd.Length() != len(data) {
t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length())
}
err = rd.Rewind()
if err != nil {
t.Fatal(err)
}
if rd.Length() != len(data) {
t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length())
}
buf2 := make([]byte, len(data))
_, err = io.ReadFull(rd, buf2)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf2, data) {
t.Fatalf("wrong data returned")
}
if rd.Length() != len(data) {
t.Fatalf("wrong length returned, want %d, got %d", len(data), rd.Length())
}
},
func(t testing.TB, rd RewindReader, data []byte) {
// read first bytes
buf := make([]byte, rnd.Intn(len(data)))
_, err := io.ReadFull(rd, buf)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf, data[:len(buf)]) {
t.Fatalf("wrong data returned")
}
err = rd.Rewind()
if err != nil {
t.Fatal(err)
}
buf2 := make([]byte, rnd.Intn(len(data)))
_, err = io.ReadFull(rd, buf2)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf2, data[:len(buf2)]) {
t.Fatalf("wrong data returned")
}
// read remainder
buf3 := make([]byte, len(data)-len(buf2))
_, err = io.ReadFull(rd, buf3)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf3, data[len(buf2):]) {
t.Fatalf("wrong data returned")
}
},
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
rd := fn()
test(t, rd, data)
})
}
}