From bcb852a8d0a674c5131668baae0826578f2e588a Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Fri, 1 Jan 2021 12:42:33 +0100 Subject: [PATCH] hashing: support WriteTo in the reader --- internal/hashing/reader.go | 46 +++++++++++++++++------ internal/hashing/reader_test.go | 66 ++++++++++++++++++++++++++------- 2 files changed, 87 insertions(+), 25 deletions(-) diff --git a/internal/hashing/reader.go b/internal/hashing/reader.go index a499f4a63..ea45dcd24 100644 --- a/internal/hashing/reader.go +++ b/internal/hashing/reader.go @@ -5,25 +5,47 @@ import ( "io" ) -// Reader hashes all data read from the underlying reader. -type Reader struct { - r io.Reader +// ReadSumer hashes all data read from the underlying reader. +type ReadSumer interface { + io.Reader + // Sum returns the hash of the data read so far. + Sum(d []byte) []byte +} + +type reader struct { + io.Reader h hash.Hash } -// NewReader returns a new Reader that uses the hash h. -func NewReader(r io.Reader, h hash.Hash) *Reader { - return &Reader{ - h: h, - r: io.TeeReader(r, h), - } +type readWriterTo struct { + reader + writerTo io.WriterTo } -func (h *Reader) Read(p []byte) (int, error) { - return h.r.Read(p) +// NewReader returns a new ReadSummer that uses the hash h. If the underlying +// reader supports WriteTo then the returned reader will do so too. +func NewReader(r io.Reader, h hash.Hash) ReadSumer { + rs := reader{ + Reader: io.TeeReader(r, h), + h: h, + } + + if _, ok := r.(io.WriterTo); ok { + return &readWriterTo{ + reader: rs, + writerTo: r.(io.WriterTo), + } + } + + return &rs } // Sum returns the hash of the data read so far. -func (h *Reader) Sum(d []byte) []byte { +func (h *reader) Sum(d []byte) []byte { return h.h.Sum(d) } + +// WriteTo reads all data into the passed writer +func (h *readWriterTo) WriteTo(w io.Writer) (int64, error) { + return h.writerTo.WriteTo(NewWriter(w, h.h)) +} diff --git a/internal/hashing/reader_test.go b/internal/hashing/reader_test.go index d17f264de..d7bdc2e02 100644 --- a/internal/hashing/reader_test.go +++ b/internal/hashing/reader_test.go @@ -7,8 +7,26 @@ import ( "io" "io/ioutil" "testing" + + rtest "github.com/restic/restic/internal/test" ) +// only expose Read method +type onlyReader struct { + io.Reader +} + +type traceWriterTo struct { + io.Reader + writerTo io.WriterTo + Traced bool +} + +func (r *traceWriterTo) WriteTo(w io.Writer) (n int64, err error) { + r.Traced = true + return r.writerTo.WriteTo(w) +} + func TestReader(t *testing.T) { tests := []int{5, 23, 2<<18 + 23, 1 << 20} @@ -21,22 +39,44 @@ func TestReader(t *testing.T) { expectedHash := sha256.Sum256(data) - rd := NewReader(bytes.NewReader(data), sha256.New()) - n, err := io.Copy(ioutil.Discard, rd) - if err != nil { - t.Fatal(err) - } + for _, test := range []struct { + innerWriteTo, outerWriteTo bool + }{{false, false}, {false, true}, {true, false}, {true, true}} { + // test both code paths in WriteTo + src := bytes.NewReader(data) + rawSrc := &traceWriterTo{Reader: src, writerTo: src} + innerSrc := io.Reader(rawSrc) + if !test.innerWriteTo { + innerSrc = &onlyReader{Reader: rawSrc} + } - if n != int64(size) { - t.Errorf("Reader: invalid number of bytes written: got %d, expected %d", - n, size) - } + rd := NewReader(innerSrc, sha256.New()) + // test both Read and WriteTo + outerSrc := io.Reader(rd) + if !test.outerWriteTo { + outerSrc = &onlyReader{Reader: outerSrc} + } - resultingHash := rd.Sum(nil) + n, err := io.Copy(ioutil.Discard, outerSrc) + if err != nil { + t.Fatal(err) + } - if !bytes.Equal(expectedHash[:], resultingHash) { - t.Errorf("Reader: hashes do not match: expected %02x, got %02x", - expectedHash, resultingHash) + if n != int64(size) { + t.Errorf("Reader: invalid number of bytes written: got %d, expected %d", + n, size) + } + + resultingHash := rd.Sum(nil) + + if !bytes.Equal(expectedHash[:], resultingHash) { + t.Errorf("Reader: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + } + + rtest.Assert(t, rawSrc.Traced == (test.innerWriteTo && test.outerWriteTo), + "unexpected/missing writeTo call innerWriteTo %v outerWriteTo %v", + test.innerWriteTo, test.outerWriteTo) } } }