mirror of
https://github.com/octoleo/restic.git
synced 2024-11-22 21:05:10 +00:00
Improve limiting HTTP based backends
This commit is contained in:
parent
c686dd0448
commit
18ecd9df30
@ -323,16 +323,11 @@ func OpenRepository(opts GlobalOptions) (*repository.Repository, error) {
|
|||||||
return nil, errors.Fatal("Please specify repository location (-r)")
|
return nil, errors.Fatal("Please specify repository location (-r)")
|
||||||
}
|
}
|
||||||
|
|
||||||
be, err := open(opts.Repo, opts.extended)
|
be, err := open(opts.Repo, opts, opts.extended)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.LimitUploadKb > 0 || opts.LimitDownloadKb > 0 {
|
|
||||||
debug.Log("rate limiting backend to %d KiB/s upload and %d KiB/s download", opts.LimitUploadKb, opts.LimitDownloadKb)
|
|
||||||
be = limiter.LimitBackend(be, limiter.NewStaticLimiter(opts.LimitUploadKb, opts.LimitDownloadKb))
|
|
||||||
}
|
|
||||||
|
|
||||||
be = backend.NewRetryBackend(be, 10, func(msg string, err error, d time.Duration) {
|
be = backend.NewRetryBackend(be, 10, func(msg string, err error, d time.Duration) {
|
||||||
Warnf("%v returned error, retrying after %v: %v\n", msg, d, err)
|
Warnf("%v returned error, retrying after %v: %v\n", msg, d, err)
|
||||||
})
|
})
|
||||||
@ -532,7 +527,7 @@ func parseConfig(loc location.Location, opts options.Options) (interface{}, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Open the backend specified by a location config.
|
// Open the backend specified by a location config.
|
||||||
func open(s string, opts options.Options) (restic.Backend, error) {
|
func open(s string, gopts GlobalOptions, opts options.Options) (restic.Backend, error) {
|
||||||
debug.Log("parsing location %v", s)
|
debug.Log("parsing location %v", s)
|
||||||
loc, err := location.Parse(s)
|
loc, err := location.Parse(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -551,11 +546,18 @@ func open(s string, opts options.Options) (restic.Backend, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wrap the transport so that the throughput via HTTP is limited
|
||||||
|
rt = limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb).Transport(rt)
|
||||||
|
|
||||||
switch loc.Scheme {
|
switch loc.Scheme {
|
||||||
case "local":
|
case "local":
|
||||||
be, err = local.Open(cfg.(local.Config))
|
be, err = local.Open(cfg.(local.Config))
|
||||||
|
// wrap the backend in a LimitBackend so that the throughput is limited
|
||||||
|
be = limiter.LimitBackend(be, limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb))
|
||||||
case "sftp":
|
case "sftp":
|
||||||
be, err = sftp.Open(cfg.(sftp.Config), SuspendSignalHandler, InstallSignalHandler)
|
be, err = sftp.Open(cfg.(sftp.Config), SuspendSignalHandler, InstallSignalHandler)
|
||||||
|
// wrap the backend in a LimitBackend so that the throughput is limited
|
||||||
|
be = limiter.LimitBackend(be, limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb))
|
||||||
case "s3":
|
case "s3":
|
||||||
be, err = s3.Open(cfg.(s3.Config), rt)
|
be, err = s3.Open(cfg.(s3.Config), rt)
|
||||||
case "gs":
|
case "gs":
|
||||||
|
@ -2,6 +2,7 @@ package limiter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Limiter defines an interface that implementors can use to rate limit I/O
|
// Limiter defines an interface that implementors can use to rate limit I/O
|
||||||
@ -14,4 +15,7 @@ type Limiter interface {
|
|||||||
// Downstream returns a rate limited reader that is intended to be used
|
// Downstream returns a rate limited reader that is intended to be used
|
||||||
// for downloads.
|
// for downloads.
|
||||||
Downstream(r io.Reader) io.Reader
|
Downstream(r io.Reader) io.Reader
|
||||||
|
|
||||||
|
// Transport returns an http.RoundTripper limited with the limiter.
|
||||||
|
Transport(http.RoundTripper) http.RoundTripper
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package limiter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/juju/ratelimit"
|
"github.com/juju/ratelimit"
|
||||||
)
|
)
|
||||||
@ -41,6 +42,39 @@ func (l staticLimiter) Downstream(r io.Reader) io.Reader {
|
|||||||
return l.limit(r, l.downstream)
|
return l.limit(r, l.downstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type roundTripper func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return rt(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l staticLimiter) roundTripper(rt http.RoundTripper, req *http.Request) (*http.Response, error) {
|
||||||
|
if req.Body != nil {
|
||||||
|
req.Body = limitedReadCloser{
|
||||||
|
limited: l.Upstream(req.Body),
|
||||||
|
original: req.Body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := rt.RoundTrip(req)
|
||||||
|
|
||||||
|
if res != nil && res.Body != nil {
|
||||||
|
res.Body = limitedReadCloser{
|
||||||
|
limited: l.Downstream(res.Body),
|
||||||
|
original: res.Body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport returns an HTTP transport limited with the limiter l.
|
||||||
|
func (l staticLimiter) Transport(rt http.RoundTripper) http.RoundTripper {
|
||||||
|
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return l.roundTripper(rt, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (l staticLimiter) limit(r io.Reader, b *ratelimit.Bucket) io.Reader {
|
func (l staticLimiter) limit(r io.Reader, b *ratelimit.Bucket) io.Reader {
|
||||||
if b == nil {
|
if b == nil {
|
||||||
return r
|
return r
|
||||||
|
Loading…
Reference in New Issue
Block a user