package limiter

import (
	"context"
	"io"
	"net/http"

	"golang.org/x/time/rate"
)

// staticLimiter throttles traffic with two fixed token buckets, one per
// direction. A nil bucket means that direction is unlimited.
type staticLimiter struct {
	upstream   *rate.Limiter
	downstream *rate.Limiter
}

// Limits represents static upload and download limits, in KiB (1024 bytes)
// per second. For both, zero (or any negative value) means unlimited.
type Limits struct {
	UploadKb   int
	DownloadKb int
}

// NewStaticLimiter constructs a Limiter with a fixed (static) upload and
// download rate cap. Each limited direction gets a token bucket (one token
// per byte) whose refill rate and burst size both equal the configured limit
// converted to bytes.
func NewStaticLimiter(l Limits) Limiter {
	var (
		upstreamBucket   *rate.Limiter
		downstreamBucket *rate.Limiter
	)
	if l.UploadKb > 0 {
		upstreamBucket = rate.NewLimiter(rate.Limit(toByteRate(l.UploadKb)), int(toByteRate(l.UploadKb)))
	}
	if l.DownloadKb > 0 {
		downstreamBucket = rate.NewLimiter(rate.Limit(toByteRate(l.DownloadKb)), int(toByteRate(l.DownloadKb)))
	}
	return staticLimiter{
		upstream:   upstreamBucket,
		downstream: downstreamBucket,
	}
}

// Upstream wraps r so that reads from it are throttled by the upload limit.
func (l staticLimiter) Upstream(r io.Reader) io.Reader {
	return l.limitReader(r, l.upstream)
}

// UpstreamWriter wraps w so that writes to it are throttled by the upload
// limit.
func (l staticLimiter) UpstreamWriter(w io.Writer) io.Writer {
	return l.limitWriter(w, l.upstream)
}

// Downstream wraps r so that reads from it are throttled by the download
// limit.
func (l staticLimiter) Downstream(r io.Reader) io.Reader {
	return l.limitReader(r, l.downstream)
}

// DownstreamWriter wraps w so that writes to it are throttled by the
// download limit.
func (l staticLimiter) DownstreamWriter(w io.Writer) io.Writer {
	return l.limitWriter(w, l.downstream)
}

// roundTripper adapts a plain function to the http.RoundTripper interface.
type roundTripper func(*http.Request) (*http.Response, error)

// RoundTrip calls rt(req).
func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	return rt(req)
}

// roundTripper performs one round trip through rt. The request body is
// throttled by the upload limit and the response body by the download limit.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the request body is replaced on a shallow copy rather than on req itself.
func (l staticLimiter) roundTripper(rt http.RoundTripper, req *http.Request) (*http.Response, error) {
	type readCloser struct {
		io.Reader
		io.Closer
	}

	// http.NoBody is skipped: the transport recognizes it by identity, and a
	// wrapper would hide that. With no upload limit there is nothing to wrap,
	// so the body keeps its concrete type for the transport's fast paths.
	if l.upstream != nil && req.Body != nil && req.Body != http.NoBody {
		limitBody := func(b io.ReadCloser) io.ReadCloser {
			return &readCloser{Reader: l.Upstream(b), Closer: b}
		}

		limited := req.WithContext(req.Context()) // shallow copy of req
		limited.Body = limitBody(req.Body)
		if getBody := req.GetBody; getBody != nil {
			// The transport calls GetBody to replay the body when it retries
			// a request; keep those replays throttled too.
			limited.GetBody = func() (io.ReadCloser, error) {
				b, err := getBody()
				if err != nil || b == nil || b == http.NoBody {
					return b, err
				}
				return limitBody(b), nil
			}
		}
		req = limited
	}

	res, err := rt.RoundTrip(req)
	if res != nil && res.Body != nil {
		res.Body = &readCloser{
			Reader: l.Downstream(res.Body),
			Closer: res.Body,
		}
	}
	return res, err
}

// Transport returns an HTTP transport limited with the limiter l.
func (l staticLimiter) Transport(rt http.RoundTripper) http.RoundTripper { return roundTripper(func(req *http.Request) (*http.Response, error) { return l.roundTripper(rt, req) }) } func (l staticLimiter) limitReader(r io.Reader, b *rate.Limiter) io.Reader { if b == nil { return r } return &rateLimitedReader{r, b} } type rateLimitedReader struct { reader io.Reader bucket *rate.Limiter } func (r *rateLimitedReader) Read(p []byte) (int, error) { n, err := r.reader.Read(p) if err := consumeTokens(n, r.bucket); err != nil { return n, err } return n, err } func (l staticLimiter) limitWriter(w io.Writer, b *rate.Limiter) io.Writer { if b == nil { return w } return &rateLimitedWriter{w, b} } type rateLimitedWriter struct { writer io.Writer bucket *rate.Limiter } func (w *rateLimitedWriter) Write(buf []byte) (int, error) { if err := consumeTokens(len(buf), w.bucket); err != nil { return 0, err } return w.writer.Write(buf) } func consumeTokens(tokens int, bucket *rate.Limiter) error { // bucket allows waiting for at most Burst() tokens at once maxWait := bucket.Burst() for tokens > maxWait { if err := bucket.WaitN(context.Background(), maxWait); err != nil { return err } tokens -= maxWait } return bucket.WaitN(context.Background(), tokens) } func toByteRate(val int) float64 { return float64(val) * 1024. }