From 6b5b29dbee334d825dcc2d89c3de8b5661d2e9e0 Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Sun, 3 Jan 2021 01:23:22 +0100 Subject: [PATCH] limiter: add unit tests --- internal/limiter/limiter_backend_test.go | 109 +++++++++++++++++++++++ internal/limiter/static_limiter_test.go | 108 ++++++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 internal/limiter/limiter_backend_test.go create mode 100644 internal/limiter/static_limiter_test.go diff --git a/internal/limiter/limiter_backend_test.go b/internal/limiter/limiter_backend_test.go new file mode 100644 index 000000000..9bac9c70a --- /dev/null +++ b/internal/limiter/limiter_backend_test.go @@ -0,0 +1,109 @@ +package limiter + +import ( + "bytes" + "context" + "crypto/rand" + "fmt" + "io" + "testing" + + "github.com/restic/restic/internal/mock" + "github.com/restic/restic/internal/restic" + rtest "github.com/restic/restic/internal/test" +) + +func randomBytes(t *testing.T, size int) []byte { + data := make([]byte, size) + _, err := io.ReadFull(rand.Reader, data) + rtest.OK(t, err) + return data +} + +func TestLimitBackendSave(t *testing.T) { + testHandle := restic.Handle{Type: restic.PackFile, Name: "test"} + data := randomBytes(t, 1234) + + be := mock.NewBackend() + be.SaveFn = func(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { + buf := new(bytes.Buffer) + _, err := io.Copy(buf, rd) + if err != nil { + return nil + } + if !bytes.Equal(data, buf.Bytes()) { + return fmt.Errorf("data mismatch") + } + return nil + } + limiter := NewStaticLimiter(42*1024, 42*1024) + limbe := LimitBackend(be, limiter) + + rd := restic.NewByteReader(data) + err := limbe.Save(context.TODO(), testHandle, rd) + rtest.OK(t, err) +} + +type tracedReadWriteToCloser struct { + io.Reader + io.WriterTo + Traced bool +} + +func newTracedReadWriteToCloser(rd *bytes.Reader) *tracedReadWriteToCloser { + return &tracedReadWriteToCloser{Reader: rd, WriterTo: rd} +} + +func (r *tracedReadWriteToCloser) WriteTo(w io.Writer) (n int64, err error) { + r.Traced = true + return r.WriterTo.WriteTo(w) +} + +func (r *tracedReadWriteToCloser) Close() error { + return nil +} + +func TestLimitBackendLoad(t *testing.T) { + testHandle := restic.Handle{Type: restic.PackFile, Name: "test"} + data := randomBytes(t, 1234) + + for _, test := range []struct { + innerWriteTo, outerWriteTo bool + }{{false, false}, {false, true}, {true, false}, {true, true}} { + be := mock.NewBackend() + src := newTracedReadWriteToCloser(bytes.NewReader(data)) + be.OpenReaderFn = func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { + if length != 0 || offset != 0 { + return nil, fmt.Errorf("Not supported") + } + // test both code paths in WriteTo of limitedReadCloser + if test.innerWriteTo { + return src, nil + } + return newTracedReadCloser(src), nil + } + limiter := NewStaticLimiter(42*1024, 42*1024) + limbe := LimitBackend(be, limiter) + + err := limbe.Load(context.TODO(), testHandle, 0, 0, func(rd io.Reader) error { + dataRead := new(bytes.Buffer) + // test both Read and WriteTo + if !test.outerWriteTo { + rd = newTracedReadCloser(rd) + } + _, err := io.Copy(dataRead, rd) + if err != nil { + return err + } + if !bytes.Equal(data, dataRead.Bytes()) { + return fmt.Errorf("read broken data") + } + + return nil + }) + rtest.OK(t, err) + rtest.Assert(t, src.Traced == (test.innerWriteTo && test.outerWriteTo), + "unexpected/missing writeTo call innerWriteTo %v outerWriteTo %v", + test.innerWriteTo, test.outerWriteTo) + } +} diff --git a/internal/limiter/static_limiter_test.go b/internal/limiter/static_limiter_test.go new file mode 100644 index 000000000..bd3c62ccb --- /dev/null +++ b/internal/limiter/static_limiter_test.go @@ -0,0 +1,108 @@ +package limiter + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "net/http" + "testing" + + "github.com/restic/restic/internal/test" +) + +func TestLimiterWrapping(t *testing.T) { + reader := bytes.NewReader([]byte{}) + writer := new(bytes.Buffer) + + for _, limits := range []struct { + upstream int + downstream int + }{ + {0, 0}, + {42, 0}, + {0, 42}, + {42, 42}, + } { + limiter := NewStaticLimiter(limits.upstream*1024, limits.downstream*1024) + + mustWrapUpstream := limits.upstream > 0 + test.Equals(t, limiter.Upstream(reader) != reader, mustWrapUpstream) + test.Equals(t, limiter.UpstreamWriter(writer) != writer, mustWrapUpstream) + + mustWrapDownstream := limits.downstream > 0 + test.Equals(t, limiter.Downstream(reader) != reader, mustWrapDownstream) + test.Equals(t, limiter.DownstreamWriter(writer) != writer, mustWrapDownstream) + } +} + +type tracedReadCloser struct { + io.Reader + Closed bool +} + +func newTracedReadCloser(rd io.Reader) *tracedReadCloser { + return &tracedReadCloser{Reader: rd} +} + +func (r *tracedReadCloser) Close() error { + r.Closed = true + return nil +} + +func TestRoundTripperReader(t *testing.T) { + limiter := NewStaticLimiter(42*1024, 42*1024) + data := make([]byte, 1234) + _, err := io.ReadFull(rand.Reader, data) + test.OK(t, err) + + var send *tracedReadCloser = newTracedReadCloser(bytes.NewReader(data)) + var recv *tracedReadCloser + + rt := limiter.Transport(roundTripper(func(req *http.Request) (*http.Response, error) { + buf := new(bytes.Buffer) + _, err := io.Copy(buf, req.Body) + if err != nil { + return nil, err + } + err = req.Body.Close() + if err != nil { + return nil, err + } + + recv = newTracedReadCloser(bytes.NewReader(buf.Bytes())) + return &http.Response{Body: recv}, nil + })) + + res, err := rt.RoundTrip(&http.Request{Body: send}) + test.OK(t, err) + + out := new(bytes.Buffer) + n, err := io.Copy(out, res.Body) + test.OK(t, err) + test.Equals(t, int64(len(data)), n) + test.OK(t, res.Body.Close()) + + test.Assert(t, send.Closed, "request body not closed") + test.Assert(t, recv.Closed, "result body not closed") + test.Assert(t, bytes.Equal(data, out.Bytes()), "data ping-pong failed") +} + +func TestRoundTripperCornerCases(t *testing.T) { + limiter := NewStaticLimiter(42*1024, 42*1024) + + rt := limiter.Transport(roundTripper(func(req *http.Request) (*http.Response, error) { + return &http.Response{}, nil + })) + + res, err := rt.RoundTrip(&http.Request{}) + test.OK(t, err) + test.Assert(t, res != nil, "round tripper returned no response") + + rt = limiter.Transport(roundTripper(func(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("error") + })) + + _, err = rt.RoundTrip(&http.Request{}) + test.Assert(t, err != nil, "round tripper lost an error") +}