diff --git a/repository/index_test.go b/repository/index_test.go index 66b7af41c..a796e35e3 100644 --- a/repository/index_test.go +++ b/repository/index_test.go @@ -3,7 +3,9 @@ package repository_test import ( "bytes" "crypto/rand" + "fmt" "io" + "path/filepath" "testing" "github.com/restic/restic/backend" diff --git a/repository/repository.go b/repository/repository.go index 8c3d81fa9..c30cce0b7 100644 --- a/repository/repository.go +++ b/repository/repository.go @@ -456,44 +456,82 @@ func (r *Repository) SetIndex(i *Index) { r.idx = i } -// SaveIndex saves all new packs in the index in the backend, returned is the -// storage ID. -func (r *Repository) SaveIndex() (backend.ID, error) { - debug.Log("Repo.SaveIndex", "Saving index") +// BlobWriter encrypts and saves the data written to it in a backend. After +// Close() was called, ID() returns the backend.ID. +type BlobWriter struct { + id backend.ID + blob backend.Blob + hw *backend.HashingWriter + ewr io.WriteCloser + t backend.Type + closed bool +} - // create blob +// CreateEncryptedBlob returns a BlobWriter that encrypts and saves the data +// written to it in the backend. After Close() was called, ID() returns the +// backend.ID. +func (r *Repository) CreateEncryptedBlob(t backend.Type) (*BlobWriter, error) { blob, err := r.be.Create() if err != nil { - return backend.ID{}, err + return nil, err } - debug.Log("Repo.SaveIndex", "create new pack %p", blob) - // hash hw := backend.NewHashingWriter(blob, sha256.New()) // encrypt blob ewr := crypto.EncryptTo(r.key, hw) - err = r.idx.Encode(ewr) + return &BlobWriter{t: t, blob: blob, hw: hw, ewr: ewr}, nil +} + +func (bw *BlobWriter) Write(buf []byte) (int, error) { + return bw.ewr.Write(buf) +} + +// Close finalizes the blob in the backend, afterwards ID() can be used to retrieve the ID. +func (bw *BlobWriter) Close() error { + if bw.closed { + return errors.New("BlobWriter already closed") + } + bw.closed = true + + err := bw.ewr.Close() + if err != nil { + return err + } + + copy(bw.id[:], bw.hw.Sum(nil)) + return bw.blob.Finalize(bw.t, bw.id.String()) +} + +// ID returns the Id the blob has been written to after Close() was called. +func (bw *BlobWriter) ID() backend.ID { + return bw.id +} + +// SaveIndex saves all new packs in the index in the backend, returned is the +// storage ID. +func (r *Repository) SaveIndex() (backend.ID, error) { + debug.Log("Repo.SaveIndex", "Saving index") + + blob, err := r.CreateEncryptedBlob(backend.Index) if err != nil { return backend.ID{}, err } - err = ewr.Close() + err = r.idx.Encode(blob) if err != nil { return backend.ID{}, err } - // finalize blob in the backend - sid := backend.ID{} - copy(sid[:], hw.Sum(nil)) - - err = blob.Finalize(backend.Index, sid.String()) + err = blob.Close() if err != nil { return backend.ID{}, err } + sid := blob.ID() + debug.Log("Repo.SaveIndex", "Saved index as %v", sid.Str()) return sid, nil @@ -554,28 +592,61 @@ func LoadIndex(repo *Repository, id string) (*Index, error) { return nil, err } +// decryptReadCloser couples an underlying reader with a DecryptReader and +// implements io.ReadCloser. On Close(), both readers are closed. +type decryptReadCloser struct { + r io.ReadCloser + dr io.ReadCloser +} + +func newDecryptReadCloser(key *crypto.Key, rd io.ReadCloser) (io.ReadCloser, error) { + dr, err := crypto.DecryptFrom(key, rd) + if err != nil { + return nil, err + } + + return &decryptReadCloser{r: rd, dr: dr}, nil +} + +func (dr *decryptReadCloser) Read(buf []byte) (int, error) { + return dr.dr.Read(buf) +} + +func (dr *decryptReadCloser) Close() error { + err := dr.dr.Close() + if err != nil { + return err + } + + return dr.r.Close() +} + +// GetDecryptReader opens the file id stored in the backend and returns a +// reader that yields the decrypted content. The reader must be closed. +func (r *Repository) GetDecryptReader(t backend.Type, id string) (io.ReadCloser, error) { + rd, err := r.be.Get(t, id) + if err != nil { + return nil, err + } + + return newDecryptReadCloser(r.key, rd) +} + func loadIndex(repo *Repository, id string, oldFormat bool) (*Index, error) { debug.Log("loadIndex", "Loading index %v", id[:8]) - rd, err := repo.be.Get(backend.Index, id) + rd, err := repo.GetDecryptReader(backend.Index, id) + if err != nil { + return nil, err + } defer rd.Close() - if err != nil { - return nil, err - } - - // decrypt - decryptRd, err := crypto.DecryptFrom(repo.key, rd) - defer decryptRd.Close() - if err != nil { - return nil, err - } var idx *Index if !oldFormat { - idx, _, err = DecodeIndex(decryptRd) + idx, _, err = DecodeIndex(rd) } else { - idx, _, err = DecodeOldIndex(decryptRd) + idx, _, err = DecodeOldIndex(rd) } if err != nil {