diff --git a/cmd/restic/global.go b/cmd/restic/global.go index 3a66323dc..96d1f9e36 100644 --- a/cmd/restic/global.go +++ b/cmd/restic/global.go @@ -561,17 +561,18 @@ func open(s string, gopts GlobalOptions, opts options.Options) (restic.Backend, } // wrap the transport so that the throughput via HTTP is limited - rt = limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb).Transport(rt) + lim := limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb) + rt = lim.Transport(rt) switch loc.Scheme { case "local": be, err = local.Open(cfg.(local.Config)) // wrap the backend in a LimitBackend so that the throughput is limited - be = limiter.LimitBackend(be, limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb)) + be = limiter.LimitBackend(be, lim) case "sftp": be, err = sftp.Open(cfg.(sftp.Config)) // wrap the backend in a LimitBackend so that the throughput is limited - be = limiter.LimitBackend(be, limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb)) + be = limiter.LimitBackend(be, lim) case "s3": be, err = s3.Open(cfg.(s3.Config), rt) case "gs": @@ -585,7 +586,7 @@ func open(s string, gopts GlobalOptions, opts options.Options) (restic.Backend, case "rest": be, err = rest.Open(cfg.(rest.Config), rt) case "rclone": - be, err = rclone.Open(cfg.(rclone.Config)) + be, err = rclone.Open(cfg.(rclone.Config), lim) default: return nil, errors.Fatalf("invalid backend: %q", loc.Scheme) @@ -648,7 +649,7 @@ func create(s string, opts options.Options) (restic.Backend, error) { case "rest": return rest.Create(cfg.(rest.Config), rt) case "rclone": - return rclone.Open(cfg.(rclone.Config)) + return rclone.Open(cfg.(rclone.Config), nil) } debug.Log("invalid repository scheme: %v", s) diff --git a/internal/backend/rclone/backend.go b/internal/backend/rclone/backend.go index e19cfa6c8..31426718e 100644 --- a/internal/backend/rclone/backend.go +++ b/internal/backend/rclone/backend.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "fmt" + "io" "math/rand" "net" "net/http" @@ -18,6 +19,7 @@ import ( "github.com/restic/restic/internal/backend/rest" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/errors" + "github.com/restic/restic/internal/limiter" "golang.org/x/net/context/ctxhttp" "golang.org/x/net/http2" ) @@ -81,8 +83,38 @@ func run(command string, args ...string) (*StdioConn, *exec.Cmd, *sync.WaitGroup return c, cmd, &wg, bg, nil } +// wrappedConn adds bandwidth limiting capabilities to the StdioConn by +// wrapping the Read/Write methods. +type wrappedConn struct { + *StdioConn + io.Reader + io.Writer +} + +func (c wrappedConn) Read(p []byte) (int, error) { + return c.Reader.Read(p) +} + +func (c wrappedConn) Write(p []byte) (int, error) { + return c.Writer.Write(p) +} + +func wrapConn(c *StdioConn, lim limiter.Limiter) wrappedConn { + wc := wrappedConn{ + StdioConn: c, + Reader: c, + Writer: c, + } + if lim != nil { + wc.Reader = lim.Downstream(c) + wc.Writer = lim.UpstreamWriter(c) + } + + return wc +} + // New initializes a Backend and starts the process. -func New(cfg Config) (*Backend, error) { +func New(cfg Config, lim limiter.Limiter) (*Backend, error) { var ( args []string err error @@ -118,11 +150,16 @@ func New(cfg Config) (*Backend, error) { arg0, args := args[0], args[1:] debug.Log("running command: %v %v", arg0, args) - conn, cmd, wg, bg, err := run(arg0, args...) + stdioConn, cmd, wg, bg, err := run(arg0, args...) if err != nil { return nil, err } + var conn net.Conn = stdioConn + if lim != nil { + conn = wrapConn(stdioConn, lim) + } + dialCount := 0 tr := &http2.Transport{ AllowHTTP: true, // this is not really HTTP, just stdin/stdout @@ -141,7 +178,7 @@ func New(cfg Config) (*Backend, error) { tr: tr, cmd: cmd, waitCh: waitCh, - conn: conn, + conn: stdioConn, wg: wg, } @@ -202,8 +239,8 @@ func New(cfg Config) (*Backend, error) { } // Open starts an rclone process with the given config. -func Open(cfg Config) (*Backend, error) { - be, err := New(cfg) +func Open(cfg Config, lim limiter.Limiter) (*Backend, error) { + be, err := New(cfg, lim) if err != nil { return nil, err } @@ -229,7 +266,7 @@ func Open(cfg Config) (*Backend, error) { // Create initializes a new restic repo with clone. func Create(cfg Config) (*Backend, error) { - be, err := New(cfg) + be, err := New(cfg, nil) if err != nil { return nil, err } diff --git a/internal/backend/rclone/backend_test.go b/internal/backend/rclone/backend_test.go index 16281035d..b7f3cebb4 100644 --- a/internal/backend/rclone/backend_test.go +++ b/internal/backend/rclone/backend_test.go @@ -39,7 +39,7 @@ func newTestSuite(t testing.TB) *test.Suite { Open: func(config interface{}) (restic.Backend, error) { t.Logf("Open()") cfg := config.(rclone.Config) - return rclone.Open(cfg) + return rclone.Open(cfg, nil) }, // CleanupFn removes data created during the tests. diff --git a/internal/limiter/limiter.go b/internal/limiter/limiter.go index abdbeaf75..410bc7f64 100644 --- a/internal/limiter/limiter.go +++ b/internal/limiter/limiter.go @@ -12,6 +12,10 @@ type Limiter interface { // uploads. Upstream(r io.Reader) io.Reader + // UpstreamWriter returns a rate limited writer that is intended to be used + // in uploads. + UpstreamWriter(w io.Writer) io.Writer + // Downstream returns a rate limited reader that is intended to be used // for downloads. Downstream(r io.Reader) io.Reader diff --git a/internal/limiter/static_limiter.go b/internal/limiter/static_limiter.go index c2ff96ce4..5df7a84da 100644 --- a/internal/limiter/static_limiter.go +++ b/internal/limiter/static_limiter.go @@ -35,11 +35,15 @@ func NewStaticLimiter(uploadKb, downloadKb int) Limiter { } func (l staticLimiter) Upstream(r io.Reader) io.Reader { - return l.limit(r, l.upstream) + return l.limitReader(r, l.upstream) +} + +func (l staticLimiter) UpstreamWriter(w io.Writer) io.Writer { + return l.limitWriter(w, l.upstream) } func (l staticLimiter) Downstream(r io.Reader) io.Reader { - return l.limit(r, l.downstream) + return l.limitReader(r, l.downstream) } type roundTripper func(*http.Request) (*http.Response, error) @@ -75,13 +79,20 @@ func (l staticLimiter) Transport(rt http.RoundTripper) http.RoundTripper { }) } -func (l staticLimiter) limit(r io.Reader, b *ratelimit.Bucket) io.Reader { +func (l staticLimiter) limitReader(r io.Reader, b *ratelimit.Bucket) io.Reader { if b == nil { return r } return ratelimit.Reader(r, b) } +func (l staticLimiter) limitWriter(w io.Writer, b *ratelimit.Bucket) io.Writer { + if b == nil { + return w + } + return ratelimit.Writer(w, b) +} + func toByteRate(val int) float64 { return float64(val) * 1024. }