diff --git a/internal/backend/azure/azure.go b/internal/backend/azure/azure.go index 5a2e7d240..82bad1e6b 100644 --- a/internal/backend/azure/azure.go +++ b/internal/backend/azure/azure.go @@ -3,6 +3,7 @@ package azure import ( "context" "encoding/base64" + "hash" "io" "net/http" "os" @@ -112,6 +113,11 @@ func (be *Backend) Location() string { return be.Join(be.container.Name, be.prefix) } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *Backend) Hasher() hash.Hash { + return nil +} + // Path returns the path in the bucket that is used for this backend. func (be *Backend) Path() string { return be.prefix diff --git a/internal/backend/azure/azure_test.go b/internal/backend/azure/azure_test.go index 8d6284cf5..a9ed94cd2 100644 --- a/internal/backend/azure/azure_test.go +++ b/internal/backend/azure/azure_test.go @@ -172,7 +172,7 @@ func TestUploadLargeFile(t *testing.T) { t.Logf("hash of %d bytes: %v", len(data), id) - err = be.Save(ctx, h, restic.NewByteReader(data)) + err = be.Save(ctx, h, restic.NewByteReader(data, be.Hasher())) if err != nil { t.Fatal(err) } diff --git a/internal/backend/b2/b2.go b/internal/backend/b2/b2.go index 704a85980..8d4116452 100644 --- a/internal/backend/b2/b2.go +++ b/internal/backend/b2/b2.go @@ -2,6 +2,7 @@ package b2 import ( "context" + "hash" "io" "net/http" "path" @@ -137,6 +138,11 @@ func (be *b2Backend) Location() string { return be.cfg.Bucket } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *b2Backend) Hasher() hash.Hash { + return nil +} + // IsNotExist returns true if the error is caused by a non-existing file. 
func (be *b2Backend) IsNotExist(err error) bool { return b2.IsNotExist(errors.Cause(err)) diff --git a/internal/backend/backend_retry_test.go b/internal/backend/backend_retry_test.go index a746032c7..4013f4ea5 100644 --- a/internal/backend/backend_retry_test.go +++ b/internal/backend/backend_retry_test.go @@ -36,7 +36,7 @@ func TestBackendSaveRetry(t *testing.T) { retryBackend := NewRetryBackend(be, 10, nil) data := test.Random(23, 5*1024*1024+11241) - err := retryBackend.Save(context.TODO(), restic.Handle{}, restic.NewByteReader(data)) + err := retryBackend.Save(context.TODO(), restic.Handle{}, restic.NewByteReader(data, be.Hasher())) if err != nil { t.Fatal(err) } @@ -256,7 +256,7 @@ func TestBackendCanceledContext(t *testing.T) { _, err = retryBackend.Stat(ctx, h) assertIsCanceled(t, err) - err = retryBackend.Save(ctx, h, restic.NewByteReader([]byte{})) + err = retryBackend.Save(ctx, h, restic.NewByteReader([]byte{}, nil)) assertIsCanceled(t, err) err = retryBackend.Remove(ctx, h) assertIsCanceled(t, err) diff --git a/internal/backend/dryrun/dry_backend.go b/internal/backend/dryrun/dry_backend.go index 2b0735d66..8412bd26a 100644 --- a/internal/backend/dryrun/dry_backend.go +++ b/internal/backend/dryrun/dry_backend.go @@ -2,6 +2,7 @@ package dryrun import ( "context" + "hash" "io" "github.com/restic/restic/internal/debug" @@ -58,6 +59,10 @@ func (be *Backend) Close() error { return be.b.Close() } +func (be *Backend) Hasher() hash.Hash { + return be.b.Hasher() +} + func (be *Backend) IsNotExist(err error) bool { return be.b.IsNotExist(err) } diff --git a/internal/backend/dryrun/dry_backend_test.go b/internal/backend/dryrun/dry_backend_test.go index c3cabf801..1b512ad20 100644 --- a/internal/backend/dryrun/dry_backend_test.go +++ b/internal/backend/dryrun/dry_backend_test.go @@ -71,7 +71,7 @@ func TestDry(t *testing.T) { handle := restic.Handle{Type: restic.PackFile, Name: step.fname} switch step.op { case "save": - err = step.be.Save(ctx, handle, 
restic.NewByteReader([]byte(step.content))) + err = step.be.Save(ctx, handle, restic.NewByteReader([]byte(step.content), step.be.Hasher())) case "test": boolRes, err = step.be.Test(ctx, handle) if boolRes != (step.content != "") { diff --git a/internal/backend/gs/gs.go b/internal/backend/gs/gs.go index 70af32d25..30bb0b4ee 100644 --- a/internal/backend/gs/gs.go +++ b/internal/backend/gs/gs.go @@ -3,6 +3,7 @@ package gs import ( "context" + "hash" "io" "net/http" "os" @@ -188,6 +189,11 @@ func (be *Backend) Location() string { return be.Join(be.bucketName, be.prefix) } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *Backend) Hasher() hash.Hash { + return nil +} + // Path returns the path in the bucket that is used for this backend. func (be *Backend) Path() string { return be.prefix diff --git a/internal/backend/local/local.go b/internal/backend/local/local.go index 33f81f182..7b1e35310 100644 --- a/internal/backend/local/local.go +++ b/internal/backend/local/local.go @@ -2,6 +2,7 @@ package local import ( "context" + "hash" "io" "io/ioutil" "os" @@ -77,6 +78,11 @@ func (b *Local) Location() string { return b.Path } +// Hasher may return a hash function for calculating a content hash for the backend +func (b *Local) Hasher() hash.Hash { + return nil +} + // IsNotExist returns true if the error is caused by a non existing file. 
func (b *Local) IsNotExist(err error) bool { return errors.Is(err, os.ErrNotExist) diff --git a/internal/backend/mem/mem_backend.go b/internal/backend/mem/mem_backend.go index 719ca46a5..0227ee6ed 100644 --- a/internal/backend/mem/mem_backend.go +++ b/internal/backend/mem/mem_backend.go @@ -3,6 +3,7 @@ package mem import ( "bytes" "context" + "hash" "io" "io/ioutil" "sync" @@ -214,6 +215,11 @@ func (be *MemoryBackend) Location() string { return "RAM" } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *MemoryBackend) Hasher() hash.Hash { + return nil +} + // Delete removes all data in the backend. func (be *MemoryBackend) Delete(ctx context.Context) error { be.m.Lock() diff --git a/internal/backend/rest/rest.go b/internal/backend/rest/rest.go index 55732e871..c7675cba1 100644 --- a/internal/backend/rest/rest.go +++ b/internal/backend/rest/rest.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "hash" "io" "io/ioutil" "net/http" @@ -109,6 +110,11 @@ func (b *Backend) Location() string { return b.url.String() } +// Hasher may return a hash function for calculating a content hash for the backend +func (b *Backend) Hasher() hash.Hash { + return nil +} + // Save stores data in the backend at the handle. func (b *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if err := h.Valid(); err != nil { diff --git a/internal/backend/s3/s3.go b/internal/backend/s3/s3.go index 539ff5a4e..d94e7be84 100644 --- a/internal/backend/s3/s3.go +++ b/internal/backend/s3/s3.go @@ -3,6 +3,7 @@ package s3 import ( "context" "fmt" + "hash" "io" "io/ioutil" "net/http" @@ -250,6 +251,11 @@ func (be *Backend) Location() string { return be.Join(be.cfg.Bucket, be.cfg.Prefix) } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *Backend) Hasher() hash.Hash { + return nil +} + // Path returns the path in the bucket that is used for this backend. 
func (be *Backend) Path() string { return be.cfg.Prefix diff --git a/internal/backend/sftp/sftp.go b/internal/backend/sftp/sftp.go index 535cb3bfb..110239277 100644 --- a/internal/backend/sftp/sftp.go +++ b/internal/backend/sftp/sftp.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "hash" "io" "os" "os/exec" @@ -240,6 +241,11 @@ func (r *SFTP) Location() string { return r.p } +// Hasher may return a hash function for calculating a content hash for the backend +func (r *SFTP) Hasher() hash.Hash { + return nil +} + // Join joins the given paths and cleans them afterwards. This always uses // forward slashes, which is required by sftp. func Join(parts ...string) string { diff --git a/internal/backend/swift/swift.go b/internal/backend/swift/swift.go index 92b6567e3..02a246203 100644 --- a/internal/backend/swift/swift.go +++ b/internal/backend/swift/swift.go @@ -3,6 +3,7 @@ package swift import ( "context" "fmt" + "hash" "io" "net/http" "path" @@ -115,6 +116,11 @@ func (be *beSwift) Location() string { return be.container } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *beSwift) Hasher() hash.Hash { + return nil +} + // Load runs fn with a reader that yields the contents of the file at h at the // given offset. 
func (be *beSwift) Load(ctx context.Context, h restic.Handle, length int, offset int64, fn func(rd io.Reader) error) error { diff --git a/internal/backend/test/benchmarks.go b/internal/backend/test/benchmarks.go index db8c5e750..b977eb682 100644 --- a/internal/backend/test/benchmarks.go +++ b/internal/backend/test/benchmarks.go @@ -14,7 +14,7 @@ func saveRandomFile(t testing.TB, be restic.Backend, length int) ([]byte, restic data := test.Random(23, length) id := restic.Hash(data) handle := restic.Handle{Type: restic.PackFile, Name: id.String()} - err := be.Save(context.TODO(), handle, restic.NewByteReader(data)) + err := be.Save(context.TODO(), handle, restic.NewByteReader(data, be.Hasher())) if err != nil { t.Fatalf("Save() error: %+v", err) } @@ -148,7 +148,7 @@ func (s *Suite) BenchmarkSave(t *testing.B) { id := restic.Hash(data) handle := restic.Handle{Type: restic.PackFile, Name: id.String()} - rd := restic.NewByteReader(data) + rd := restic.NewByteReader(data, be.Hasher()) t.SetBytes(int64(length)) t.ResetTimer() diff --git a/internal/backend/test/tests.go b/internal/backend/test/tests.go index ae9b58677..ebb209a3e 100644 --- a/internal/backend/test/tests.go +++ b/internal/backend/test/tests.go @@ -84,7 +84,7 @@ func (s *Suite) TestConfig(t *testing.T) { t.Fatalf("did not get expected error for non-existing config") } - err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, restic.NewByteReader([]byte(testString))) + err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, restic.NewByteReader([]byte(testString), b.Hasher())) if err != nil { t.Fatalf("Save() error: %+v", err) } @@ -134,7 +134,7 @@ func (s *Suite) TestLoad(t *testing.T) { id := restic.Hash(data) handle := restic.Handle{Type: restic.PackFile, Name: id.String()} - err = b.Save(context.TODO(), handle, restic.NewByteReader(data)) + err = b.Save(context.TODO(), handle, restic.NewByteReader(data, b.Hasher())) if err != nil { t.Fatalf("Save() error: %+v", err) } @@ -253,7 
+253,7 @@ func (s *Suite) TestList(t *testing.T) { data := test.Random(rand.Int(), rand.Intn(100)+55) id := restic.Hash(data) h := restic.Handle{Type: restic.PackFile, Name: id.String()} - err := b.Save(context.TODO(), h, restic.NewByteReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data, b.Hasher())) if err != nil { t.Fatal(err) } @@ -343,7 +343,7 @@ func (s *Suite) TestListCancel(t *testing.T) { data := []byte(fmt.Sprintf("random test blob %v", i)) id := restic.Hash(data) h := restic.Handle{Type: restic.PackFile, Name: id.String()} - err := b.Save(context.TODO(), h, restic.NewByteReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data, b.Hasher())) if err != nil { t.Fatal(err) } @@ -447,6 +447,7 @@ type errorCloser struct { io.ReadSeeker l int64 t testing.TB + h []byte } func (ec errorCloser) Close() error { @@ -458,6 +459,10 @@ func (ec errorCloser) Length() int64 { return ec.l } +func (ec errorCloser) Hash() []byte { + return ec.h +} + func (ec errorCloser) Rewind() error { _, err := ec.ReadSeeker.Seek(0, io.SeekStart) return err @@ -486,7 +491,7 @@ func (s *Suite) TestSave(t *testing.T) { Type: restic.PackFile, Name: fmt.Sprintf("%s-%d", id, i), } - err := b.Save(context.TODO(), h, restic.NewByteReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data, b.Hasher())) test.OK(t, err) buf, err := backend.LoadAll(context.TODO(), nil, b, h) @@ -538,7 +543,19 @@ func (s *Suite) TestSave(t *testing.T) { // wrap the tempfile in an errorCloser, so we can detect if the backend // closes the reader - err = b.Save(context.TODO(), h, errorCloser{t: t, l: int64(length), ReadSeeker: tmpfile}) + var beHash []byte + if b.Hasher() != nil { + beHasher := b.Hasher() + // must never fail according to interface + _, _ = beHasher.Write(data) + beHash = beHasher.Sum(nil) + } + err = b.Save(context.TODO(), h, errorCloser{ + t: t, + l: int64(length), + ReadSeeker: tmpfile, + h: beHash, + }) if err != nil { t.Fatal(err) } @@ 
-583,7 +600,7 @@ func (s *Suite) TestSaveError(t *testing.T) { // test that incomplete uploads fail h := restic.Handle{Type: restic.PackFile, Name: id.String()} - err := b.Save(context.TODO(), h, &incompleteByteReader{ByteReader: *restic.NewByteReader(data)}) + err := b.Save(context.TODO(), h, &incompleteByteReader{ByteReader: *restic.NewByteReader(data, b.Hasher())}) // try to delete possible leftovers _ = s.delayedRemove(t, b, h) if err == nil { @@ -610,7 +627,7 @@ func (s *Suite) TestSaveFilenames(t *testing.T) { for i, test := range filenameTests { h := restic.Handle{Name: test.name, Type: restic.PackFile} - err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(test.data))) + err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(test.data), b.Hasher())) if err != nil { t.Errorf("test %d failed: Save() returned %+v", i, err) continue @@ -647,7 +664,7 @@ var testStrings = []struct { func store(t testing.TB, b restic.Backend, tpe restic.FileType, data []byte) restic.Handle { id := restic.Hash(data) h := restic.Handle{Name: id.String(), Type: tpe} - err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(data))) + err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(data), b.Hasher())) test.OK(t, err) return h } @@ -801,7 +818,7 @@ func (s *Suite) TestBackend(t *testing.T) { test.Assert(t, !ok, "removed blob still present") // create blob - err = b.Save(context.TODO(), h, restic.NewByteReader([]byte(ts.data))) + err = b.Save(context.TODO(), h, restic.NewByteReader([]byte(ts.data), b.Hasher())) test.OK(t, err) // list items diff --git a/internal/backend/utils_test.go b/internal/backend/utils_test.go index e6bcfa4dc..1030537bc 100644 --- a/internal/backend/utils_test.go +++ b/internal/backend/utils_test.go @@ -26,7 +26,7 @@ func TestLoadAll(t *testing.T) { id := restic.Hash(data) h := restic.Handle{Name: id.String(), Type: restic.PackFile} - err := b.Save(context.TODO(), h, restic.NewByteReader(data)) + err := b.Save(context.TODO(), h, 
restic.NewByteReader(data, b.Hasher())) rtest.OK(t, err) buf, err := backend.LoadAll(context.TODO(), buf, b, restic.Handle{Type: restic.PackFile, Name: id.String()}) @@ -47,7 +47,7 @@ func TestLoadAll(t *testing.T) { func save(t testing.TB, be restic.Backend, buf []byte) restic.Handle { id := restic.Hash(buf) h := restic.Handle{Name: id.String(), Type: restic.PackFile} - err := be.Save(context.TODO(), h, restic.NewByteReader(buf)) + err := be.Save(context.TODO(), h, restic.NewByteReader(buf, be.Hasher())) if err != nil { t.Fatal(err) } diff --git a/internal/cache/backend_test.go b/internal/cache/backend_test.go index 872ddbde1..79b838eb2 100644 --- a/internal/cache/backend_test.go +++ b/internal/cache/backend_test.go @@ -32,7 +32,7 @@ func loadAndCompare(t testing.TB, be restic.Backend, h restic.Handle, data []byt } func save(t testing.TB, be restic.Backend, h restic.Handle, data []byte) { - err := be.Save(context.TODO(), h, restic.NewByteReader(data)) + err := be.Save(context.TODO(), h, restic.NewByteReader(data, be.Hasher())) if err != nil { t.Fatal(err) } diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index ade2813c7..6d17a4593 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -16,6 +16,7 @@ import ( "github.com/restic/restic/internal/archiver" "github.com/restic/restic/internal/checker" "github.com/restic/restic/internal/errors" + "github.com/restic/restic/internal/hashing" "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/test" @@ -218,10 +219,16 @@ func TestModifiedIndex(t *testing.T) { t.Fatal(err) } }() + wr := io.Writer(tmpfile) + var hw *hashing.Writer + if repo.Backend().Hasher() != nil { + hw = hashing.NewWriter(wr, repo.Backend().Hasher()) + wr = hw + } // read the file from the backend err = repo.Backend().Load(context.TODO(), h, 0, 0, func(rd io.Reader) error { - _, err := io.Copy(tmpfile, rd) + _, err := 
io.Copy(wr, rd) return err }) test.OK(t, err) @@ -233,7 +240,11 @@ func TestModifiedIndex(t *testing.T) { Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", } - rd, err := restic.NewFileReader(tmpfile) + var hash []byte + if hw != nil { + hash = hw.Sum(nil) + } + rd, err := restic.NewFileReader(tmpfile, hash) if err != nil { t.Fatal(err) } diff --git a/internal/limiter/limiter_backend_test.go b/internal/limiter/limiter_backend_test.go index 9bac9c70a..e8f0ae17d 100644 --- a/internal/limiter/limiter_backend_test.go +++ b/internal/limiter/limiter_backend_test.go @@ -39,7 +39,7 @@ func TestLimitBackendSave(t *testing.T) { limiter := NewStaticLimiter(42*1024, 42*1024) limbe := LimitBackend(be, limiter) - rd := restic.NewByteReader(data) + rd := restic.NewByteReader(data, nil) err := limbe.Save(context.TODO(), testHandle, rd) rtest.OK(t, err) } diff --git a/internal/mock/backend.go b/internal/mock/backend.go index e3759acbf..9f6036fdb 100644 --- a/internal/mock/backend.go +++ b/internal/mock/backend.go @@ -2,6 +2,7 @@ package mock import ( "context" + "hash" "io" "github.com/restic/restic/internal/errors" @@ -20,6 +21,7 @@ type Backend struct { TestFn func(ctx context.Context, h restic.Handle) (bool, error) DeleteFn func(ctx context.Context) error LocationFn func() string + HasherFn func() hash.Hash } // NewBackend returns new mock Backend instance @@ -46,6 +48,15 @@ func (m *Backend) Location() string { return m.LocationFn() } +// Hasher may return a hash function for calculating a content hash for the backend +func (m *Backend) Hasher() hash.Hash { + if m.HasherFn == nil { + return nil + } + + return m.HasherFn() +} + // IsNotExist returns true if the error is caused by a missing file. 
func (m *Backend) IsNotExist(err error) bool { if m.IsNotExistFn == nil { diff --git a/internal/pack/pack_test.go b/internal/pack/pack_test.go index 02413247f..c789e472b 100644 --- a/internal/pack/pack_test.go +++ b/internal/pack/pack_test.go @@ -127,7 +127,7 @@ func TestUnpackReadSeeker(t *testing.T) { id := restic.Hash(packData) handle := restic.Handle{Type: restic.PackFile, Name: id.String()} - rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData))) + rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData, b.Hasher()))) verifyBlobs(t, bufs, k, restic.ReaderAt(context.TODO(), b, handle), packSize) } @@ -140,6 +140,6 @@ func TestShortPack(t *testing.T) { id := restic.Hash(packData) handle := restic.Handle{Type: restic.PackFile, Name: id.String()} - rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData))) + rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData, b.Hasher()))) verifyBlobs(t, bufs, k, restic.ReaderAt(context.TODO(), b, handle), packSize) } diff --git a/internal/repository/key.go b/internal/repository/key.go index 2c60f59e8..5de154195 100644 --- a/internal/repository/key.go +++ b/internal/repository/key.go @@ -279,7 +279,7 @@ func AddKey(ctx context.Context, s *Repository, password, username, hostname str Name: restic.Hash(buf).String(), } - err = s.be.Save(ctx, h, restic.NewByteReader(buf)) + err = s.be.Save(ctx, h, restic.NewByteReader(buf, s.be.Hasher())) if err != nil { return nil, err } diff --git a/internal/repository/packer_manager.go b/internal/repository/packer_manager.go index 491f888bc..163d5b254 100644 --- a/internal/repository/packer_manager.go +++ b/internal/repository/packer_manager.go @@ -2,6 +2,8 @@ package repository import ( "context" + "hash" + "io" "os" "sync" @@ -20,12 +22,14 @@ import ( // Saver implements saving data in a backend. 
type Saver interface { Save(context.Context, restic.Handle, restic.RewindReader) error + Hasher() hash.Hash } // Packer holds a pack.Packer together with a hash writer. type Packer struct { *pack.Packer hw *hashing.Writer + beHw *hashing.Writer tmpfile *os.File } @@ -71,10 +75,19 @@ func (r *packerManager) findPacker() (packer *Packer, err error) { return nil, errors.Wrap(err, "fs.TempFile") } - hw := hashing.NewWriter(tmpfile, sha256.New()) + w := io.Writer(tmpfile) + beHasher := r.be.Hasher() + var beHw *hashing.Writer + if beHasher != nil { + beHw = hashing.NewWriter(w, beHasher) + w = beHw + } + + hw := hashing.NewWriter(w, sha256.New()) p := pack.NewPacker(r.key, hw) packer = &Packer{ Packer: p, + beHw: beHw, hw: hw, tmpfile: tmpfile, } @@ -101,8 +114,11 @@ func (r *Repository) savePacker(ctx context.Context, t restic.BlobType, p *Packe id := restic.IDFromHash(p.hw.Sum(nil)) h := restic.Handle{Type: restic.PackFile, Name: id.String()} - - rd, err := restic.NewFileReader(p.tmpfile) + var beHash []byte + if p.beHw != nil { + beHash = p.beHw.Sum(nil) + } + rd, err := restic.NewFileReader(p.tmpfile, beHash) if err != nil { return err } diff --git a/internal/repository/packer_manager_test.go b/internal/repository/packer_manager_test.go index 93cbf74b0..1a810ab61 100644 --- a/internal/repository/packer_manager_test.go +++ b/internal/repository/packer_manager_test.go @@ -33,11 +33,11 @@ func min(a, b int) int { return b } -func saveFile(t testing.TB, be Saver, length int, f *os.File, id restic.ID) { +func saveFile(t testing.TB, be Saver, length int, f *os.File, id restic.ID, hash []byte) { h := restic.Handle{Type: restic.PackFile, Name: id.String()} t.Logf("save file %v", h) - rd, err := restic.NewFileReader(f) + rd, err := restic.NewFileReader(f, hash) if err != nil { t.Fatal(err) } @@ -90,7 +90,11 @@ func fillPacks(t testing.TB, rnd *rand.Rand, be Saver, pm *packerManager, buf [] } packID := restic.IDFromHash(packer.hw.Sum(nil)) - saveFile(t, be, 
int(packer.Size()), packer.tmpfile, packID) + var beHash []byte + if packer.beHw != nil { + beHash = packer.beHw.Sum(nil) + } + saveFile(t, be, int(packer.Size()), packer.tmpfile, packID, beHash) } return bytes @@ -106,7 +110,11 @@ func flushRemainingPacks(t testing.TB, be Saver, pm *packerManager) (bytes int) bytes += int(n) packID := restic.IDFromHash(packer.hw.Sum(nil)) - saveFile(t, be, int(packer.Size()), packer.tmpfile, packID) + var beHash []byte + if packer.beHw != nil { + beHash = packer.beHw.Sum(nil) + } + saveFile(t, be, int(packer.Size()), packer.tmpfile, packID, beHash) } } diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 4344fff2d..2901e768a 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -322,7 +322,7 @@ func (r *Repository) SaveUnpacked(ctx context.Context, t restic.FileType, p []by } h := restic.Handle{Type: t, Name: id.String()} - err = r.be.Save(ctx, h, restic.NewByteReader(ciphertext)) + err = r.be.Save(ctx, h, restic.NewByteReader(ciphertext, r.be.Hasher())) if err != nil { debug.Log("error saving blob %v: %v", h, err) return restic.ID{}, err diff --git a/internal/restic/backend.go b/internal/restic/backend.go index cda5c30c7..41292470a 100644 --- a/internal/restic/backend.go +++ b/internal/restic/backend.go @@ -2,6 +2,7 @@ package restic import ( "context" + "hash" "io" ) @@ -17,6 +18,9 @@ type Backend interface { // repository. Location() string + // Hasher may return a hash function for calculating a content hash for the backend + Hasher() hash.Hash + // Test a boolean value whether a File with the name and type exists. 
Test(ctx context.Context, h Handle) (bool, error) diff --git a/internal/restic/rewind_reader.go b/internal/restic/rewind_reader.go index acbb29678..20339f222 100644 --- a/internal/restic/rewind_reader.go +++ b/internal/restic/rewind_reader.go @@ -2,6 +2,7 @@ package restic import ( "bytes" + "hash" "io" "github.com/restic/restic/internal/errors" @@ -18,12 +19,16 @@ type RewindReader interface { // Length returns the number of bytes that can be read from the Reader // after calling Rewind. Length() int64 + + // Hash returns a hash of the data if requested by the backend. + Hash() []byte } // ByteReader implements a RewindReader for a byte slice. type ByteReader struct { *bytes.Reader - Len int64 + Len int64 + hash []byte } // Rewind restarts the reader from the beginning of the data. @@ -38,14 +43,26 @@ func (b *ByteReader) Length() int64 { return b.Len } +// Hash returns a hash of the data if requested by the backend. +func (b *ByteReader) Hash() []byte { + return b.hash +} + // statically ensure that *ByteReader implements RewindReader. var _ RewindReader = &ByteReader{} // NewByteReader prepares a ByteReader that can then be used to read buf. -func NewByteReader(buf []byte) *ByteReader { +func NewByteReader(buf []byte, hasher hash.Hash) *ByteReader { + var hash []byte + if hasher != nil { + // must never fail according to interface + _, _ = hasher.Write(buf) + hash = hasher.Sum(nil) + } return &ByteReader{ Reader: bytes.NewReader(buf), Len: int64(len(buf)), + hash: hash, } } @@ -55,7 +72,8 @@ var _ RewindReader = &FileReader{} // FileReader implements a RewindReader for an open file. type FileReader struct { io.ReadSeeker - Len int64 + Len int64 + hash []byte } // Rewind seeks to the beginning of the file. @@ -69,8 +87,13 @@ func (f *FileReader) Length() int64 { return f.Len } +// Hash returns a hash of the data if requested by the backend. +func (f *FileReader) Hash() []byte { + return f.hash +} + // NewFileReader wraps f in a *FileReader.
-func NewFileReader(f io.ReadSeeker) (*FileReader, error) { +func NewFileReader(f io.ReadSeeker, hash []byte) (*FileReader, error) { pos, err := f.Seek(0, io.SeekEnd) if err != nil { return nil, errors.Wrap(err, "Seek") @@ -79,6 +102,7 @@ func NewFileReader(f io.ReadSeeker) (*FileReader, error) { fr := &FileReader{ ReadSeeker: f, Len: pos, + hash: hash, } err = fr.Rewind() diff --git a/internal/restic/rewind_reader_test.go b/internal/restic/rewind_reader_test.go index 53f0a4424..0e15ee686 100644 --- a/internal/restic/rewind_reader_test.go +++ b/internal/restic/rewind_reader_test.go @@ -2,6 +2,8 @@ package restic import ( "bytes" + "crypto/md5" + "hash" "io" "io/ioutil" "math/rand" @@ -15,10 +17,12 @@ import ( func TestByteReader(t *testing.T) { buf := []byte("foobar") - fn := func() RewindReader { - return NewByteReader(buf) + for _, hasher := range []hash.Hash{nil, md5.New()} { + fn := func() RewindReader { + return NewByteReader(buf, hasher) + } + testRewindReader(t, fn, buf) } - testRewindReader(t, fn, buf) } func TestFileReader(t *testing.T) { @@ -28,7 +32,7 @@ func TestFileReader(t *testing.T) { defer cleanup() filename := filepath.Join(d, "file-reader-test") - err := ioutil.WriteFile(filename, []byte("foobar"), 0600) + err := ioutil.WriteFile(filename, buf, 0600) if err != nil { t.Fatal(err) } @@ -45,15 +49,23 @@ func TestFileReader(t *testing.T) { } }() - fn := func() RewindReader { - rd, err := NewFileReader(f) - if err != nil { - t.Fatal(err) + for _, hasher := range []hash.Hash{nil, md5.New()} { + fn := func() RewindReader { + var hash []byte + if hasher != nil { + // must never fail according to interface + _, _ = hasher.Write(buf) + hash = hasher.Sum(nil) + } + rd, err := NewFileReader(f, hash) + if err != nil { + t.Fatal(err) + } + return rd } - return rd - } - testRewindReader(t, fn, buf) + testRewindReader(t, fn, buf) + } } func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) { @@ -104,6 +116,15 @@ func testRewindReader(t 
*testing.T, fn func() RewindReader, data []byte) { if rd.Length() != int64(len(data)) { t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length()) } + + if rd.Hash() != nil { + hasher := md5.New() + // must never fail according to interface + _, _ = hasher.Write(buf2) + if !bytes.Equal(rd.Hash(), hasher.Sum(nil)) { + t.Fatal("hash does not match data") + } + } }, func(t testing.TB, rd RewindReader, data []byte) { // read first bytes