package limiter

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"io"
	"net/http"
	"testing"

	"github.com/restic/restic/internal/test"
)

// TestLimiterWrapping asserts that each limiter accessor hands back a value
// different from its argument exactly when the matching limit is positive,
// and hands back the argument itself when that limit is zero.
func TestLimiterWrapping(t *testing.T) {
	src := bytes.NewReader([]byte{})
	dst := new(bytes.Buffer)

	cases := []struct {
		up   int
		down int
	}{
		{up: 0, down: 0},
		{up: 42, down: 0},
		{up: 0, down: 42},
		{up: 42, down: 42},
	}

	for _, tc := range cases {
		lim := NewStaticLimiter(tc.up*1024, tc.down*1024)
		wantUp := tc.up > 0
		wantDown := tc.down > 0

		// Upstream direction.
		test.Equals(t, lim.Upstream(src) != src, wantUp)
		test.Equals(t, lim.UpstreamWriter(dst) != dst, wantUp)

		// Downstream direction.
		test.Equals(t, lim.Downstream(src) != src, wantDown)
		test.Equals(t, lim.DownstreamWriter(dst) != dst, wantDown)
	}
}

// tracedReadCloser embeds an io.Reader and records whether Close was called.
type tracedReadCloser struct {
	io.Reader
	Closed bool
}

// newTracedReadCloser returns a tracedReadCloser that reads from rd and has
// not been closed yet.
func newTracedReadCloser(rd io.Reader) *tracedReadCloser {
	rc := new(tracedReadCloser)
	rc.Reader = rd
	return rc
}

// Close marks the receiver as closed. It always returns nil.
func (rc *tracedReadCloser) Close() error {
	rc.Closed = true
	return nil
}

// TestRoundTripperReader sends a random payload through the limiter's
// transport to an inner round tripper that echoes the request body back.
// It checks that the payload arrives unchanged and that both the request
// body and the response body end up closed.
func TestRoundTripperReader(t *testing.T) {
	limiter := NewStaticLimiter(42*1024, 42*1024)

	payload := make([]byte, 1234)
	_, err := io.ReadFull(rand.Reader, payload)
	test.OK(t, err)

	send := newTracedReadCloser(bytes.NewReader(payload))
	var recv *tracedReadCloser

	// echo drains and closes the request body, then answers with a copy of
	// what it read, remembering the response body in recv.
	echo := func(req *http.Request) (*http.Response, error) {
		var received bytes.Buffer
		if _, err := io.Copy(&received, req.Body); err != nil {
			return nil, err
		}
		if err := req.Body.Close(); err != nil {
			return nil, err
		}

		recv = newTracedReadCloser(bytes.NewReader(received.Bytes()))
		return &http.Response{Body: recv}, nil
	}

	rt := limiter.Transport(roundTripper(echo))
	res, err := rt.RoundTrip(&http.Request{Body: send})
	test.OK(t, err)

	var out bytes.Buffer
	copied, err := io.Copy(&out, res.Body)
	test.OK(t, err)
	test.Equals(t, int64(len(payload)), copied)
	test.OK(t, res.Body.Close())

	test.Assert(t, send.Closed, "request body not closed")
	test.Assert(t, recv.Closed, "result body not closed")
	test.Assert(t, bytes.Equal(payload, out.Bytes()), "data ping-pong failed")
}

// TestRoundTripperCornerCases runs the limiter's transport with a request
// and response that carry no body, and with an inner round tripper that
// fails, checking that the response and the error respectively come through.
func TestRoundTripperCornerCases(t *testing.T) {
	limiter := NewStaticLimiter(42*1024, 42*1024)

	// Neither the request nor the response has a body.
	bodyless := func(req *http.Request) (*http.Response, error) {
		return &http.Response{}, nil
	}
	emptyRT := limiter.Transport(roundTripper(bodyless))

	res, err := emptyRT.RoundTrip(&http.Request{})
	test.OK(t, err)
	test.Assert(t, res != nil, "round tripper returned no response")

	// The inner round tripper reports an error.
	failing := func(req *http.Request) (*http.Response, error) {
		return nil, fmt.Errorf("error")
	}
	failingRT := limiter.Transport(roundTripper(failing))

	_, err = failingRT.RoundTrip(&http.Request{})
	test.Assert(t, err != nil, "round tripper lost an error")
}