From ae170e2b38d3de7ee8b8c65e363dbd9184a8ff07 Mon Sep 17 00:00:00 2001 From: greatroar <@> Date: Sat, 24 Apr 2021 11:46:06 +0200 Subject: [PATCH] Simplify internal/limiter --- internal/limiter/limiter_backend.go | 37 +++++++++-------------------- internal/limiter/static_limiter.go | 17 ++++++++----- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/internal/limiter/limiter_backend.go b/internal/limiter/limiter_backend.go index d074a5a0e..f1b508327 100644 --- a/internal/limiter/limiter_backend.go +++ b/internal/limiter/limiter_backend.go @@ -42,44 +42,29 @@ func (l limitedRewindReader) Read(b []byte) (int, error) { func (r rateLimitedBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64, consumer func(rd io.Reader) error) error { return r.Backend.Load(ctx, h, length, offset, func(rd io.Reader) error { - return consumer(newDownstreamLimitedReadCloser(rd, r.limiter, nil)) + return consumer(newDownstreamLimitedReader(rd, r.limiter)) }) } -type limitedReadCloser struct { +type limitedReader struct { io.Reader - original io.ReadCloser -} - -type limitedReadWriteToCloser struct { - limitedReadCloser writerTo io.WriterTo limiter Limiter } -func newDownstreamLimitedReadCloser(rd io.Reader, limiter Limiter, original io.ReadCloser) io.ReadCloser { - lrd := limitedReadCloser{ - Reader: limiter.Downstream(rd), - original: original, - } - if _, ok := rd.(io.WriterTo); ok { - return &limitedReadWriteToCloser{ - limitedReadCloser: lrd, - writerTo: rd.(io.WriterTo), - limiter: limiter, +func newDownstreamLimitedReader(rd io.Reader, limiter Limiter) io.Reader { + lrd := limiter.Downstream(rd) + if wt, ok := rd.(io.WriterTo); ok { + lrd = &limitedReader{ + Reader: lrd, + writerTo: wt, + limiter: limiter, } } - return &lrd + return lrd } -func (l limitedReadCloser) Close() error { - if l.original == nil { - return nil - } - return l.original.Close() -} - -func (l limitedReadWriteToCloser) WriteTo(w io.Writer) (int64, error) { +func (l *limitedReader) WriteTo(w io.Writer) (int64, error) { return l.writerTo.WriteTo(l.limiter.DownstreamWriter(w)) } diff --git a/internal/limiter/static_limiter.go b/internal/limiter/static_limiter.go index e9b2b8285..04a9b9d34 100644 --- a/internal/limiter/static_limiter.go +++ b/internal/limiter/static_limiter.go @@ -57,19 +57,24 @@ func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { } func (l staticLimiter) roundTripper(rt http.RoundTripper, req *http.Request) (*http.Response, error) { + type readCloser struct { + io.Reader + io.Closer + } + if req.Body != nil { - req.Body = limitedReadCloser{ - Reader: l.Upstream(req.Body), - original: req.Body, + req.Body = &readCloser{ + Reader: l.Upstream(req.Body), + Closer: req.Body, } } res, err := rt.RoundTrip(req) if res != nil && res.Body != nil { - res.Body = limitedReadCloser{ - Reader: l.Downstream(res.Body), - original: res.Body, + res.Body = &readCloser{ + Reader: l.Downstream(res.Body), + Closer: res.Body, } }