diff --git a/cmd/restic/cmd_cat.go b/cmd/restic/cmd_cat.go index e735daf88..0dffc0508 100644 --- a/cmd/restic/cmd_cat.go +++ b/cmd/restic/cmd_cat.go @@ -74,7 +74,7 @@ func runCat(gopts GlobalOptions, args []string) error { fmt.Println(string(buf)) return nil case "index": - buf, err := repo.LoadAndDecrypt(gopts.ctx, restic.IndexFile, id) + buf, err := repo.LoadAndDecrypt(gopts.ctx, nil, restic.IndexFile, id) if err != nil { return err } @@ -99,7 +99,7 @@ func runCat(gopts GlobalOptions, args []string) error { return nil case "key": h := restic.Handle{Type: restic.KeyFile, Name: id.String()} - buf, err := backend.LoadAll(gopts.ctx, repo.Backend(), h) + buf, err := backend.LoadAll(gopts.ctx, nil, repo.Backend(), h) if err != nil { return err } @@ -150,7 +150,7 @@ func runCat(gopts GlobalOptions, args []string) error { switch tpe { case "pack": h := restic.Handle{Type: restic.DataFile, Name: id.String()} - buf, err := backend.LoadAll(gopts.ctx, repo.Backend(), h) + buf, err := backend.LoadAll(gopts.ctx, nil, repo.Backend(), h) if err != nil { return err } diff --git a/internal/backend/test/tests.go b/internal/backend/test/tests.go index dec1e0bee..7e9f7f5ab 100644 --- a/internal/backend/test/tests.go +++ b/internal/backend/test/tests.go @@ -79,7 +79,7 @@ func (s *Suite) TestConfig(t *testing.T) { var testString = "Config" // create config and read it back - _, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.ConfigFile}) + _, err := backend.LoadAll(context.TODO(), nil, b, restic.Handle{Type: restic.ConfigFile}) if err == nil { t.Fatalf("did not get expected error for non-existing config") } @@ -93,7 +93,7 @@ func (s *Suite) TestConfig(t *testing.T) { // same config for _, name := range []string{"", "foo", "bar", "0000000000000000000000000000000000000000000000000000000000000000"} { h := restic.Handle{Type: restic.ConfigFile, Name: name} - buf, err := backend.LoadAll(context.TODO(), b, h) + buf, err := backend.LoadAll(context.TODO(), nil, b, h) if err != nil { t.Fatalf("unable to read config with name %q: %+v", name, err) } @@ -491,7 +491,7 @@ func (s *Suite) TestSave(t *testing.T) { err := b.Save(context.TODO(), h, restic.NewByteReader(data)) test.OK(t, err) - buf, err := backend.LoadAll(context.TODO(), b, h) + buf, err := backend.LoadAll(context.TODO(), nil, b, h) test.OK(t, err) if len(buf) != len(data) { t.Fatalf("number of bytes does not match, want %v, got %v", len(data), len(buf)) @@ -584,7 +584,7 @@ func (s *Suite) TestSaveFilenames(t *testing.T) { continue } - buf, err := backend.LoadAll(context.TODO(), b, h) + buf, err := backend.LoadAll(context.TODO(), nil, b, h) if err != nil { t.Errorf("test %d failed: Load() returned %+v", i, err) continue @@ -734,7 +734,7 @@ func (s *Suite) TestBackend(t *testing.T) { // test Load() h := restic.Handle{Type: tpe, Name: ts.id} - buf, err := backend.LoadAll(context.TODO(), b, h) + buf, err := backend.LoadAll(context.TODO(), nil, b, h) test.OK(t, err) test.Equals(t, ts.data, string(buf)) diff --git a/internal/backend/utils.go b/internal/backend/utils.go index 222f210e5..1665aedc6 100644 --- a/internal/backend/utils.go +++ b/internal/backend/utils.go @@ -1,20 +1,33 @@ package backend import ( + "bytes" "context" "io" - "io/ioutil" "github.com/restic/restic/internal/restic" ) -// LoadAll reads all data stored in the backend for the handle. -func LoadAll(ctx context.Context, be restic.Backend, h restic.Handle) (buf []byte, err error) { - err = be.Load(ctx, h, 0, 0, func(rd io.Reader) (ierr error) { - buf, ierr = ioutil.ReadAll(rd) - return ierr +// LoadAll reads all data stored in the backend for the handle into the given +// buffer, which is truncated. If the buffer is not large enough or nil, a new +// one is allocated. +func LoadAll(ctx context.Context, buf []byte, be restic.Backend, h restic.Handle) ([]byte, error) { + err := be.Load(ctx, h, 0, 0, func(rd io.Reader) error { + // make sure this is idempotent, in case an error occurs this function may be called multiple times! + wr := bytes.NewBuffer(buf[:0]) + _, cerr := io.Copy(wr, rd) + if cerr != nil { + return cerr + } + buf = wr.Bytes() + return nil }) - return buf, err + + if err != nil { + return nil, err + } + + return buf, nil } // LimitedReadCloser wraps io.LimitedReader and exposes the Close() method. diff --git a/internal/backend/utils_test.go b/internal/backend/utils_test.go index 74929fd0b..a29add676 100644 --- a/internal/backend/utils_test.go +++ b/internal/backend/utils_test.go @@ -19,6 +19,7 @@ const MiB = 1 << 20 func TestLoadAll(t *testing.T) { b := mem.New() + var buf []byte for i := 0; i < 20; i++ { data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB) @@ -28,7 +29,7 @@ func TestLoadAll(t *testing.T) { err := b.Save(context.TODO(), h, restic.NewByteReader(data)) rtest.OK(t, err) - buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) + buf, err := backend.LoadAll(context.TODO(), buf, b, restic.Handle{Type: restic.DataFile, Name: id.String()}) rtest.OK(t, err) if len(buf) != len(data) { @@ -43,55 +44,66 @@ func TestLoadAll(t *testing.T) { } } -func TestLoadSmallBuffer(t *testing.T) { - b := mem.New() - - for i := 0; i < 20; i++ { - data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB) - - id := restic.Hash(data) - h := restic.Handle{Name: id.String(), Type: restic.DataFile} - err := b.Save(context.TODO(), h, restic.NewByteReader(data)) - rtest.OK(t, err) - - buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) - rtest.OK(t, err) - - if len(buf) != len(data) { - t.Errorf("length of returned buffer does not match, want %d, got %d", len(data), len(buf)) - continue - } - - if !bytes.Equal(buf, data) { - t.Errorf("wrong data returned") - continue - } +func save(t testing.TB, be restic.Backend, buf []byte) restic.Handle { + id := restic.Hash(buf) + h := restic.Handle{Name: id.String(), Type: restic.DataFile} + err := be.Save(context.TODO(), h, restic.NewByteReader(buf)) + if err != nil { + t.Fatal(err) } + return h } -func TestLoadLargeBuffer(t *testing.T) { +func TestLoadAllAppend(t *testing.T) { b := mem.New() - for i := 0; i < 20; i++ { - data := rtest.Random(23+i, rand.Intn(MiB)+500*KiB) + h1 := save(t, b, []byte("foobar test string")) + randomData := rtest.Random(23, rand.Intn(MiB)+500*KiB) + h2 := save(t, b, randomData) - id := restic.Hash(data) - h := restic.Handle{Name: id.String(), Type: restic.DataFile} - err := b.Save(context.TODO(), h, restic.NewByteReader(data)) - rtest.OK(t, err) + var tests = []struct { + handle restic.Handle + buf []byte + want []byte + }{ + { + handle: h1, + buf: nil, + want: []byte("foobar test string"), + }, + { + handle: h1, + buf: []byte("xxx"), + want: []byte("foobar test string"), + }, + { + handle: h2, + buf: nil, + want: randomData, + }, + { + handle: h2, + buf: make([]byte, 0, 200), + want: randomData, + }, + { + handle: h2, + buf: []byte("foobarbaz"), + want: randomData, + }, + } - buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()}) - rtest.OK(t, err) + for _, test := range tests { + t.Run("", func(t *testing.T) { + buf, err := backend.LoadAll(context.TODO(), test.buf, b, test.handle) + if err != nil { + t.Fatal(err) + } - if len(buf) != len(data) { - t.Errorf("length of returned buffer does not match, want %d, got %d", len(data), len(buf)) - continue - } - - if !bytes.Equal(buf, data) { - t.Errorf("wrong data returned") - continue - } + if !bytes.Equal(buf, test.want) { + t.Errorf("wrong data returned, want %q, got %q", test.want, buf) + } + }) } } diff --git a/internal/cache/backend_test.go b/internal/cache/backend_test.go index b4cc431ac..872ddbde1 100644 --- a/internal/cache/backend_test.go +++ b/internal/cache/backend_test.go @@ -17,7 +17,7 @@ import ( ) func loadAndCompare(t testing.TB, be restic.Backend, h restic.Handle, data []byte) { - buf, err := backend.LoadAll(context.TODO(), be, h) + buf, err := backend.LoadAll(context.TODO(), nil, be, h) if err != nil { t.Fatal(err) } @@ -147,7 +147,7 @@ func TestErrorBackend(t *testing.T) { loadTest := func(wg *sync.WaitGroup, be restic.Backend) { defer wg.Done() - buf, err := backend.LoadAll(context.TODO(), be, h) + buf, err := backend.LoadAll(context.TODO(), nil, be, h) if err == testErr { return } diff --git a/internal/repository/index.go b/internal/repository/index.go index 8d6d64c3e..ef6661dfc 100644 --- a/internal/repository/index.go +++ b/internal/repository/index.go @@ -552,7 +552,7 @@ func DecodeOldIndex(buf []byte) (idx *Index, err error) { func LoadIndexWithDecoder(ctx context.Context, repo restic.Repository, id restic.ID, fn func([]byte) (*Index, error)) (idx *Index, err error) { debug.Log("Loading index %v", id) - buf, err := repo.LoadAndDecrypt(ctx, restic.IndexFile, id) + buf, err := repo.LoadAndDecrypt(ctx, nil, restic.IndexFile, id) if err != nil { return nil, err } diff --git a/internal/repository/key.go b/internal/repository/key.go index 46e3b912f..62558c0b3 100644 --- a/internal/repository/key.go +++ b/internal/repository/key.go @@ -184,7 +184,7 @@ func SearchKey(ctx context.Context, s *Repository, password string, maxKeys int, // LoadKey loads a key from the backend. func LoadKey(ctx context.Context, s *Repository, name string) (k *Key, err error) { h := restic.Handle{Type: restic.KeyFile, Name: name} - data, err := backend.LoadAll(ctx, s.be, h) + data, err := backend.LoadAll(ctx, nil, s.be, h) if err != nil { return nil, err } diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 880006cf5..6ab8ca595 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -9,7 +9,6 @@ import ( "io" "os" - "github.com/restic/restic/internal/backend" "github.com/restic/restic/internal/cache" "github.com/restic/restic/internal/crypto" "github.com/restic/restic/internal/debug" @@ -67,15 +66,29 @@ func (r *Repository) PrefixLength(t restic.FileType) (int, error) { return restic.PrefixLength(r.be, t) } -// LoadAndDecrypt loads and decrypts data identified by t and id from the -// backend. -func (r *Repository) LoadAndDecrypt(ctx context.Context, t restic.FileType, id restic.ID) (buf []byte, err error) { +// LoadAndDecrypt loads and decrypts the file with the given type and ID, using +// the supplied buffer (which must be empty). If the buffer is nil, a new +// buffer will be allocated and returned. +func (r *Repository) LoadAndDecrypt(ctx context.Context, buf []byte, t restic.FileType, id restic.ID) ([]byte, error) { + if len(buf) != 0 { + panic("buf is not empty") + } + debug.Log("load %v with id %v", t, id) h := restic.Handle{Type: t, Name: id.String()} - buf, err = backend.LoadAll(ctx, r.be, h) + err := r.be.Load(ctx, h, 0, 0, func(rd io.Reader) error { + // make sure this call is idempotent, in case an error occurs + wr := bytes.NewBuffer(buf[:0]) + _, cerr := io.Copy(wr, rd) + if cerr != nil { + return cerr + } + buf = wr.Bytes() + return nil + }) + if err != nil { - debug.Log("error loading %v: %v", h, err) return nil, err } @@ -188,7 +201,7 @@ func (r *Repository) loadBlob(ctx context.Context, id restic.ID, t restic.BlobTy // LoadJSONUnpacked decrypts the data and afterwards calls json.Unmarshal on // the item. func (r *Repository) LoadJSONUnpacked(ctx context.Context, t restic.FileType, id restic.ID, item interface{}) (err error) { - buf, err := r.LoadAndDecrypt(ctx, t, id) + buf, err := r.LoadAndDecrypt(ctx, nil, t, id) if err != nil { return err } diff --git a/internal/repository/repository_test.go b/internal/repository/repository_test.go index 8ea203d59..43d70c533 100644 --- a/internal/repository/repository_test.go +++ b/internal/repository/repository_test.go @@ -244,7 +244,7 @@ func BenchmarkLoadAndDecrypt(b *testing.B) { b.SetBytes(int64(length)) for i := 0; i < b.N; i++ { - data, err := repo.LoadAndDecrypt(context.TODO(), restic.DataFile, storageID) + data, err := repo.LoadAndDecrypt(context.TODO(), nil, restic.DataFile, storageID) rtest.OK(b, err) if len(data) != length { b.Errorf("wanted %d bytes, got %d", length, len(data)) diff --git a/internal/restic/repository.go b/internal/restic/repository.go index ff8f38034..46d7379db 100644 --- a/internal/restic/repository.go +++ b/internal/restic/repository.go @@ -39,8 +39,11 @@ type Repository interface { SaveUnpacked(context.Context, FileType, []byte) (ID, error) SaveJSONUnpacked(context.Context, FileType, interface{}) (ID, error) - LoadJSONUnpacked(context.Context, FileType, ID, interface{}) error - LoadAndDecrypt(context.Context, FileType, ID) ([]byte, error) + LoadJSONUnpacked(ctx context.Context, t FileType, id ID, dest interface{}) error + // LoadAndDecrypt loads and decrypts the file with the given type and ID, + // using the supplied buffer (which must be empty). If the buffer is nil, a + // new buffer will be allocated and returned. + LoadAndDecrypt(ctx context.Context, buf []byte, t FileType, id ID) (data []byte, err error) LoadBlob(context.Context, BlobType, ID, []byte) (int, error) SaveBlob(context.Context, BlobType, []byte, ID) (ID, error)