2
2
mirror of https://github.com/octoleo/restic.git synced 2025-01-10 01:37:48 +00:00
restic/internal/limiter/static_limiter.go

108 lines
2.3 KiB
Go
Raw Normal View History

package limiter
import (
"io"
2017-12-29 11:43:49 +00:00
"net/http"
"github.com/juju/ratelimit"
)
// staticLimiter implements Limiter with fixed rate caps. Each direction has
// its own token bucket; a nil bucket means that direction is not limited.
type staticLimiter struct {
	upstream, downstream *ratelimit.Bucket // upload / download token buckets
}
// NewStaticLimiter constructs a Limiter with a fixed (static) upload and
// download rate cap, both expressed in KiB/s. A value of zero or less leaves
// the corresponding direction unlimited.
func NewStaticLimiter(uploadKb, downloadKb int) Limiter {
	// bucketFor builds the token bucket for one direction, or returns nil
	// when no cap was requested.
	bucketFor := func(kb int) *ratelimit.Bucket {
		if kb <= 0 {
			return nil
		}
		rate := toByteRate(kb)
		// The bucket refills at `rate` bytes per second and holds at most
		// one second's worth of tokens.
		return ratelimit.NewBucketWithRate(rate, int64(rate))
	}

	return staticLimiter{
		upstream:   bucketFor(uploadKb),
		downstream: bucketFor(downloadKb),
	}
}
func (l staticLimiter) Upstream(r io.Reader) io.Reader {
2018-05-22 18:48:17 +00:00
return l.limitReader(r, l.upstream)
}
// UpstreamWriter wraps w so that data written through it counts against the
// upload cap. When no upload cap is configured, w itself is returned.
func (l staticLimiter) UpstreamWriter(w io.Writer) io.Writer {
	limited := l.limitWriter(w, l.upstream)
	return limited
}
func (l staticLimiter) Downstream(r io.Reader) io.Reader {
2018-05-22 18:48:17 +00:00
return l.limitReader(r, l.downstream)
}
// DownstreamWriter wraps w so that data written through it counts against the
// download cap. When no download cap is configured, w itself is returned.
func (l staticLimiter) DownstreamWriter(w io.Writer) io.Writer {
	limited := l.limitWriter(w, l.downstream)
	return limited
}
2017-12-29 11:43:49 +00:00
// roundTripper adapts an ordinary function to the http.RoundTripper
// interface, in the same spirit as http.HandlerFunc.
type roundTripper func(*http.Request) (*http.Response, error)

// RoundTrip invokes the underlying function with r and returns its results.
func (fn roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	res, err := fn(r)
	return res, err
}
func (l staticLimiter) roundTripper(rt http.RoundTripper, req *http.Request) (*http.Response, error) {
2021-04-24 09:46:06 +00:00
type readCloser struct {
io.Reader
io.Closer
}
2017-12-29 11:43:49 +00:00
if req.Body != nil {
2021-04-24 09:46:06 +00:00
req.Body = &readCloser{
Reader: l.Upstream(req.Body),
Closer: req.Body,
2017-12-29 11:43:49 +00:00
}
}
res, err := rt.RoundTrip(req)
if res != nil && res.Body != nil {
2021-04-24 09:46:06 +00:00
res.Body = &readCloser{
Reader: l.Downstream(res.Body),
Closer: res.Body,
2017-12-29 11:43:49 +00:00
}
}
return res, err
}
// Transport returns an http.RoundTripper that forwards every request to rt,
// with request and response bodies throttled by the limiter l.
func (l staticLimiter) Transport(rt http.RoundTripper) http.RoundTripper {
	limited := func(req *http.Request) (*http.Response, error) {
		return l.roundTripper(rt, req)
	}
	return roundTripper(limited)
}
2018-05-22 18:48:17 +00:00
// limitReader throttles reads from r using the token bucket b. A nil bucket
// means "unlimited", in which case r is passed through untouched.
func (l staticLimiter) limitReader(r io.Reader, b *ratelimit.Bucket) io.Reader {
	if b != nil {
		return ratelimit.Reader(r, b)
	}
	return r
}
2018-05-22 18:48:17 +00:00
// limitWriter throttles writes to w using the token bucket b. A nil bucket
// means "unlimited", in which case w is passed through untouched.
func (l staticLimiter) limitWriter(w io.Writer, b *ratelimit.Bucket) io.Writer {
	if b != nil {
		return ratelimit.Writer(w, b)
	}
	return w
}
// toByteRate converts a rate given in KiB (per second) into bytes by
// multiplying by 1024.
func toByteRate(val int) float64 {
	const bytesPerKiB = 1024
	return bytesPerKiB * float64(val)
}