diff --git a/cmd/restic/global.go b/cmd/restic/global.go index 1563ad4c9..600ee48aa 100644 --- a/cmd/restic/global.go +++ b/cmd/restic/global.go @@ -60,14 +60,11 @@ type GlobalOptions struct { JSON bool CacheDir string NoCache bool - CACerts []string - InsecureTLS bool - TLSClientCert string CleanupCache bool Compression repository.CompressionMode - LimitUploadKb int - LimitDownloadKb int + backend.TransportOptions + limiter.Limits ctx context.Context password string @@ -117,13 +114,13 @@ func init() { f.BoolVarP(&globalOptions.JSON, "json", "", false, "set output mode to JSON for commands that support it") f.StringVar(&globalOptions.CacheDir, "cache-dir", "", "set the cache `directory`. (default: use system default cache directory)") f.BoolVar(&globalOptions.NoCache, "no-cache", false, "do not use a local cache") - f.StringSliceVar(&globalOptions.CACerts, "cacert", nil, "`file` to load root certificates from (default: use system certificates)") - f.StringVar(&globalOptions.TLSClientCert, "tls-client-cert", "", "path to a `file` containing PEM encoded TLS client certificate and private key") + f.StringSliceVar(&globalOptions.RootCertFilenames, "cacert", nil, "`file` to load root certificates from (default: use system certificates)") + f.StringVar(&globalOptions.TLSClientCertKeyFilename, "tls-client-cert", "", "path to a `file` containing PEM encoded TLS client certificate and private key") f.BoolVar(&globalOptions.InsecureTLS, "insecure-tls", false, "skip TLS certificate verification when connecting to the repo (insecure)") f.BoolVar(&globalOptions.CleanupCache, "cleanup-cache", false, "auto remove old cache directories") f.Var(&globalOptions.Compression, "compression", "compression mode (only available for repo format version 2), one of (auto|off|max)") - f.IntVar(&globalOptions.LimitUploadKb, "limit-upload", 0, "limits uploads to a maximum rate in KiB/s. (default: unlimited)") - f.IntVar(&globalOptions.LimitDownloadKb, "limit-download", 0, "limits downloads to a maximum rate in KiB/s. (default: unlimited)") + f.IntVar(&globalOptions.Limits.UploadKb, "limit-upload", 0, "limits uploads to a maximum rate in KiB/s. (default: unlimited)") + f.IntVar(&globalOptions.Limits.DownloadKb, "limit-download", 0, "limits downloads to a maximum rate in KiB/s. (default: unlimited)") f.StringSliceVarP(&globalOptions.Options, "option", "o", []string{}, "set extended option (`key=value`, can be specified multiple times)") // Use our "generate" command instead of the cobra provided "completion" command cmdRoot.CompletionOptions.DisableDefaultCmd = true @@ -681,18 +678,13 @@ func open(s string, gopts GlobalOptions, opts options.Options) (restic.Backend, return nil, err } - tropts := backend.TransportOptions{ - RootCertFilenames: globalOptions.CACerts, - TLSClientCertKeyFilename: globalOptions.TLSClientCert, - InsecureTLS: globalOptions.InsecureTLS, - } - rt, err := backend.Transport(tropts) + rt, err := backend.Transport(globalOptions.TransportOptions) if err != nil { return nil, err } // wrap the transport so that the throughput via HTTP is limited - lim := limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb) + lim := limiter.NewStaticLimiter(gopts.Limits) rt = lim.Transport(rt) switch loc.Scheme { @@ -762,12 +754,7 @@ func create(s string, opts options.Options) (restic.Backend, error) { return nil, err } - tropts := backend.TransportOptions{ - RootCertFilenames: globalOptions.CACerts, - TLSClientCertKeyFilename: globalOptions.TLSClientCert, - InsecureTLS: globalOptions.InsecureTLS, - } - rt, err := backend.Transport(tropts) + rt, err := backend.Transport(globalOptions.TransportOptions) if err != nil { return nil, err } diff --git a/internal/limiter/limiter_backend_test.go b/internal/limiter/limiter_backend_test.go index e8f0ae17d..10e216be9 100644 --- a/internal/limiter/limiter_backend_test.go +++ b/internal/limiter/limiter_backend_test.go @@ -36,7 +36,7 @@ func TestLimitBackendSave(t *testing.T) { } return nil } - limiter := NewStaticLimiter(42*1024, 42*1024) + limiter := NewStaticLimiter(Limits{42 * 1024, 42 * 1024}) limbe := LimitBackend(be, limiter) rd := restic.NewByteReader(data, nil) @@ -82,7 +82,7 @@ func TestLimitBackendLoad(t *testing.T) { } return newTracedReadCloser(src), nil } - limiter := NewStaticLimiter(42*1024, 42*1024) + limiter := NewStaticLimiter(Limits{42 * 1024, 42 * 1024}) limbe := LimitBackend(be, limiter) err := limbe.Load(context.TODO(), testHandle, 0, 0, func(rd io.Reader) error { diff --git a/internal/limiter/static_limiter.go b/internal/limiter/static_limiter.go index 04a9b9d34..9fb8fbf24 100644 --- a/internal/limiter/static_limiter.go +++ b/internal/limiter/static_limiter.go @@ -12,20 +12,27 @@ type staticLimiter struct { downstream *ratelimit.Bucket } +// Limits represents static upload and download limits. +// For both, zero means unlimited. +type Limits struct { + UploadKb int + DownloadKb int +} + // NewStaticLimiter constructs a Limiter with a fixed (static) upload and // download rate cap -func NewStaticLimiter(uploadKb, downloadKb int) Limiter { +func NewStaticLimiter(l Limits) Limiter { var ( upstreamBucket *ratelimit.Bucket downstreamBucket *ratelimit.Bucket ) - if uploadKb > 0 { - upstreamBucket = ratelimit.NewBucketWithRate(toByteRate(uploadKb), int64(toByteRate(uploadKb))) + if l.UploadKb > 0 { + upstreamBucket = ratelimit.NewBucketWithRate(toByteRate(l.UploadKb), int64(toByteRate(l.UploadKb))) } - if downloadKb > 0 { - downstreamBucket = ratelimit.NewBucketWithRate(toByteRate(downloadKb), int64(toByteRate(downloadKb))) + if l.DownloadKb > 0 { + downstreamBucket = ratelimit.NewBucketWithRate(toByteRate(l.DownloadKb), int64(toByteRate(l.DownloadKb))) } return staticLimiter{ diff --git a/internal/limiter/static_limiter_test.go b/internal/limiter/static_limiter_test.go index 69712c932..564b6a00a 100644 --- a/internal/limiter/static_limiter_test.go +++ b/internal/limiter/static_limiter_test.go @@ -15,22 +15,19 @@ func TestLimiterWrapping(t *testing.T) { reader := bytes.NewReader([]byte{}) writer := new(bytes.Buffer) - for _, limits := range []struct { - upstream int - downstream int - }{ + for _, limits := range []Limits{ {0, 0}, {42, 0}, {0, 42}, {42, 42}, } { - limiter := NewStaticLimiter(limits.upstream*1024, limits.downstream*1024) + limiter := NewStaticLimiter(limits) - mustWrapUpstream := limits.upstream > 0 + mustWrapUpstream := limits.UploadKb > 0 test.Equals(t, limiter.Upstream(reader) != reader, mustWrapUpstream) test.Equals(t, limiter.UpstreamWriter(writer) != writer, mustWrapUpstream) - mustWrapDownstream := limits.downstream > 0 + mustWrapDownstream := limits.DownloadKb > 0 test.Equals(t, limiter.Downstream(reader) != reader, mustWrapDownstream) test.Equals(t, limiter.DownstreamWriter(writer) != writer, mustWrapDownstream) } @@ -51,7 +48,7 @@ func (r *tracedReadCloser) Close() error { } func TestRoundTripperReader(t *testing.T) { - limiter := NewStaticLimiter(42*1024, 42*1024) + limiter := NewStaticLimiter(Limits{42 * 1024, 42 * 1024}) data := make([]byte, 1234) _, err := io.ReadFull(rand.Reader, data) test.OK(t, err) @@ -89,7 +86,7 @@ func TestRoundTripperReader(t *testing.T) { } func TestRoundTripperCornerCases(t *testing.T) { - limiter := NewStaticLimiter(42*1024, 42*1024) + limiter := NewStaticLimiter(Limits{42 * 1024, 42 * 1024}) rt := limiter.Transport(roundTripper(func(req *http.Request) (*http.Response, error) { return &http.Response{}, nil