From f750aa8dfbfa5451591f5f774feedaf64a4be4bd Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Sun, 1 Oct 2023 22:45:18 +0200 Subject: [PATCH] backend/limiter: replace juju/ratelimit with x/time/rate --- go.mod | 2 +- go.sum | 4 +- internal/backend/limiter/static_limiter.go | 60 +++++++++++++++---- .../backend/limiter/static_limiter_test.go | 33 ++++++++++ 4 files changed, 85 insertions(+), 14 deletions(-) diff --git a/go.mod b/go.mod index 269fa2b41..9b230ca69 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/go-ole/go-ole v1.3.0 github.com/google/go-cmp v0.5.9 github.com/hashicorp/golang-lru/v2 v2.0.7 - github.com/juju/ratelimit v1.0.2 github.com/klauspost/compress v1.16.7 github.com/kurin/blazer v0.5.4-0.20230113224640-3887e1ec64b5 github.com/minio/minio-go/v7 v7.0.61 @@ -33,6 +32,7 @@ require ( golang.org/x/sys v0.12.0 golang.org/x/term v0.12.0 golang.org/x/text v0.13.0 + golang.org/x/time v0.3.0 google.golang.org/api v0.138.0 ) diff --git a/go.sum b/go.sum index 8bd30432a..64fc2af39 100644 --- a/go.sum +++ b/go.sum @@ -120,8 +120,6 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/juju/ratelimit v1.0.2 h1:sRxmtRiajbvrcLQT7S+JbqU0ntsb9W2yhSdNN8tWfaI= -github.com/juju/ratelimit v1.0.2/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk= github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= @@ -271,6 +269,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= diff --git a/internal/backend/limiter/static_limiter.go b/internal/backend/limiter/static_limiter.go index 9fb8fbf24..c09a12de0 100644 --- a/internal/backend/limiter/static_limiter.go +++ b/internal/backend/limiter/static_limiter.go @@ -1,15 +1,16 @@ package limiter import ( + "context" "io" "net/http" - "github.com/juju/ratelimit" + "golang.org/x/time/rate" ) type staticLimiter struct { - upstream *ratelimit.Bucket - downstream *ratelimit.Bucket + upstream *rate.Limiter + downstream *rate.Limiter } // Limits represents static upload and download limits. @@ -23,16 +24,16 @@ type Limits struct { // download rate cap func NewStaticLimiter(l Limits) Limiter { var ( - upstreamBucket *ratelimit.Bucket - downstreamBucket *ratelimit.Bucket + upstreamBucket *rate.Limiter + downstreamBucket *rate.Limiter ) if l.UploadKb > 0 { - upstreamBucket = ratelimit.NewBucketWithRate(toByteRate(l.UploadKb), int64(toByteRate(l.UploadKb))) + upstreamBucket = rate.NewLimiter(rate.Limit(toByteRate(l.UploadKb)), int(toByteRate(l.UploadKb))) } if l.DownloadKb > 0 { - downstreamBucket = ratelimit.NewBucketWithRate(toByteRate(l.DownloadKb), int64(toByteRate(l.DownloadKb))) + downstreamBucket = rate.NewLimiter(rate.Limit(toByteRate(l.DownloadKb)), int(toByteRate(l.DownloadKb))) } return staticLimiter{ @@ -95,18 +96,55 @@ func (l staticLimiter) Transport(rt http.RoundTripper) http.RoundTripper { }) } -func (l staticLimiter) limitReader(r io.Reader, b *ratelimit.Bucket) io.Reader { +func (l staticLimiter) limitReader(r io.Reader, b *rate.Limiter) io.Reader { if b == nil { return r } - return ratelimit.Reader(r, b) + return &rateLimitedReader{r, b} } -func (l staticLimiter) limitWriter(w io.Writer, b *ratelimit.Bucket) io.Writer { +type rateLimitedReader struct { + reader io.Reader + bucket *rate.Limiter +} + +func (r *rateLimitedReader) Read(p []byte) (int, error) { + n, err := r.reader.Read(p) + if err := consumeTokens(n, r.bucket); err != nil { + return n, err + } + return n, err +} + +func (l staticLimiter) limitWriter(w io.Writer, b *rate.Limiter) io.Writer { if b == nil { return w } - return ratelimit.Writer(w, b) + return &rateLimitedWriter{w, b} +} + +type rateLimitedWriter struct { + writer io.Writer + bucket *rate.Limiter +} + +func (w *rateLimitedWriter) Write(buf []byte) (int, error) { + if err := consumeTokens(len(buf), w.bucket); err != nil { + return 0, err + } + return w.writer.Write(buf) +} + +func consumeTokens(tokens int, bucket *rate.Limiter) error { + // bucket allows waiting for at most Burst() tokens at once + maxWait := bucket.Burst() + for tokens > maxWait { + if err := bucket.WaitN(context.Background(), maxWait); err != nil { + return err + } + tokens -= maxWait + } + return bucket.WaitN(context.Background(), tokens) } func toByteRate(val int) float64 { diff --git a/internal/backend/limiter/static_limiter_test.go b/internal/backend/limiter/static_limiter_test.go index 564b6a00a..8a839518f 100644 --- a/internal/backend/limiter/static_limiter_test.go +++ b/internal/backend/limiter/static_limiter_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/restic/restic/internal/test" + "golang.org/x/time/rate" ) func TestLimiterWrapping(t *testing.T) { @@ -33,6 +34,38 @@ func TestLimiterWrapping(t *testing.T) { } } +func TestReadLimiter(t *testing.T) { + reader := bytes.NewReader(make([]byte, 300)) + limiter := rate.NewLimiter(rate.Limit(10000), int(100)) + limReader := rateLimitedReader{reader, limiter} + + n, err := limReader.Read([]byte{}) + test.OK(t, err) + test.Equals(t, n, 0) + + n, err = limReader.Read(make([]byte, 300)) + test.OK(t, err) + test.Equals(t, n, 300) + + n, err = limReader.Read([]byte{}) + test.Equals(t, err, io.EOF) + test.Equals(t, n, 0) +} + +func TestWriteLimiter(t *testing.T) { + writer := &bytes.Buffer{} + limiter := rate.NewLimiter(rate.Limit(10000), int(100)) + limReader := rateLimitedWriter{writer, limiter} + + n, err := limReader.Write([]byte{}) + test.OK(t, err) + test.Equals(t, n, 0) + + n, err = limReader.Write(make([]byte, 300)) + test.OK(t, err) + test.Equals(t, n, 300) +} + type tracedReadCloser struct { io.Reader Closed bool