diff --git a/Gopkg.lock b/Gopkg.lock index a7b7f672f..b1f76dd31 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -88,8 +88,8 @@ [[projects]] name = "github.com/kurin/blazer" packages = ["b2","base","internal/b2types","internal/blog"] - revision = "cad56a04490fe20c43548d70a5a9af2be53ff14e" - version = "v0.2.0" + revision = "e269a1a17bb6aec278c06a57cb7e8f8d0d333e04" + version = "v0.2.1" [[projects]] branch = "master" @@ -214,6 +214,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "ea711bd1a9bfc8902b973a4de3a840f42536b9091fd8558980f44d6ca1622227" + inputs-digest = "abc33af201086afac21e33a2a7987a473daa6a229c3699ca13761f4d4fd7f52e" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index ab0f5c1bd..1f58bcd4d 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -72,3 +72,7 @@ [[constraint]] branch = "master" name = "golang.org/x/sys" + +[[constraint]] + name = "github.com/kurin/blazer" + branch = "master" diff --git a/vendor/github.com/kurin/blazer/.travis.yml b/vendor/github.com/kurin/blazer/.travis.yml index 47102cb0d..5ca626e7f 100644 --- a/vendor/github.com/kurin/blazer/.travis.yml +++ b/vendor/github.com/kurin/blazer/.travis.yml @@ -8,4 +8,6 @@ branches: - master before_script: go run internal/bin/cleanup/cleanup.go -script: go test -v ./base ./b2 ./x/... +script: + - go test -v ./base ./b2 ./x/... + - go vet -v ./base ./b2 ./x/... diff --git a/vendor/github.com/kurin/blazer/b2/integration_test.go b/vendor/github.com/kurin/blazer/b2/integration_test.go index 07ca3c68b..2321e3e90 100644 --- a/vendor/github.com/kurin/blazer/b2/integration_test.go +++ b/vendor/github.com/kurin/blazer/b2/integration_test.go @@ -27,6 +27,8 @@ import ( "sync/atomic" "testing" "time" + + "github.com/kurin/blazer/x/transport" ) const ( @@ -744,10 +746,10 @@ func (rt *rtCounter) RoundTrip(r *http.Request) (*http.Response, error) { } func TestAttrsNoRoundtrip(t *testing.T) { - rt := &rtCounter{rt: transport} - transport = rt + rt := &rtCounter{rt: defaultTransport} + defaultTransport = rt defer func() { - transport = rt.rt + defaultTransport = rt.rt }() ctx := context.Background() @@ -767,7 +769,7 @@ func TestAttrsNoRoundtrip(t *testing.T) { t.Fatal(err) } if len(objs) != 1 { - t.Fatal("unexpected objects: got %d, want 1", len(objs)) + t.Fatalf("unexpected objects: got %d, want 1", len(objs)) } trips := rt.trips @@ -842,7 +844,7 @@ func listObjects(ctx context.Context, f func(context.Context, int, *Cursor) ([]* return ch } -var transport = http.DefaultTransport +var defaultTransport = http.DefaultTransport type eofTripper struct { rt http.RoundTripper @@ -919,9 +921,10 @@ func startLiveTest(ctx context.Context, t *testing.T) (*Bucket, func()) { t.Skipf("B2_ACCOUNT_ID or B2_SECRET_KEY unset; skipping integration tests") return nil, nil } - ccport := &ccTripper{rt: transport, t: t} + ccport := &ccTripper{rt: defaultTransport, t: t} tport := eofTripper{rt: ccport, t: t} - client, err := NewClient(ctx, id, key, FailSomeUploads(), ExpireSomeAuthTokens(), Transport(tport), UserAgent("b2-test"), UserAgent("integration-test")) + errport := transport.WithFailures(tport, transport.FailureRate(.25), transport.MatchPathSubstring("/b2_get_upload_url"), transport.Response(503)) + client, err := NewClient(ctx, id, key, FailSomeUploads(), ExpireSomeAuthTokens(), Transport(errport), UserAgent("b2-test"), UserAgent("integration-test")) if err != nil { t.Fatal(err) return nil, nil diff --git a/vendor/github.com/kurin/blazer/base/base.go b/vendor/github.com/kurin/blazer/base/base.go index 431e618bd..726ea8e37 100644 --- a/vendor/github.com/kurin/blazer/base/base.go +++ b/vendor/github.com/kurin/blazer/base/base.go @@ -43,7 +43,7 @@ import ( const ( APIBase = "https://api.backblazeb2.com" - DefaultUserAgent = "blazer/0.1.1" + DefaultUserAgent = "blazer/0.2.1" ) type b2err struct { @@ -131,25 +131,32 @@ const ( func mkErr(resp *http.Response) error { data, err := ioutil.ReadAll(resp.Body) + var msgBody string if err != nil { - return err + msgBody = fmt.Sprintf("couldn't read message body: %v", err) } logResponse(resp, data) msg := &b2types.ErrorMessage{} if err := json.Unmarshal(data, msg); err != nil { - return err + if msgBody != "" { + msgBody = fmt.Sprintf("couldn't read message body: %v", err) + } + } + if msgBody == "" { + msgBody = msg.Msg } var retryAfter int retry := resp.Header.Get("Retry-After") if retry != "" { r, err := strconv.ParseInt(retry, 10, 64) if err != nil { - return err + r = 0 + blog.V(1).Infof("couldn't parse retry-after header %q: %v", retry, err) } retryAfter = int(r) } return b2err{ - msg: msg.Msg, + msg: msgBody, retry: retryAfter, code: resp.StatusCode, method: resp.Request.Header.Get("X-Blazer-Method"), @@ -222,6 +229,19 @@ type b2Options struct { userAgent string } +func (o *b2Options) addHeaders(req *http.Request) { + if o.failSomeUploads { + req.Header.Add("X-Bz-Test-Mode", "fail_some_uploads") + } + if o.expireTokens { + req.Header.Add("X-Bz-Test-Mode", "expire_some_account_authorization_tokens") + } + if o.capExceeded { + req.Header.Add("X-Bz-Test-Mode", "force_cap_exceeded") + } + req.Header.Set("User-Agent", o.getUserAgent()) +} + func (o *b2Options) getAPIBase() string { if o.apiBase != "" { return o.apiBase @@ -268,14 +288,22 @@ type httpReply struct { err error } -func makeNetRequest(req *http.Request, rt http.RoundTripper) <-chan httpReply { - ch := make(chan httpReply) - go func() { - resp, err := rt.RoundTrip(req) - ch <- httpReply{resp, err} - close(ch) - }() - return ch +func makeNetRequest(ctx context.Context, req *http.Request, rt http.RoundTripper) (*http.Response, error) { + req = req.WithContext(ctx) + resp, err := rt.RoundTrip(req) + switch err { + case nil: + return resp, nil + case context.Canceled, context.DeadlineExceeded: + return nil, err + default: + method := req.Header.Get("X-Blazer-Method") + blog.V(2).Infof(">> %s uri: %v err: %v", method, req.URL, err) + return nil, b2err{ + msg: err.Error(), + retry: 1, + } + } } type requestBody struct { @@ -351,38 +379,14 @@ func (o *b2Options) makeRequest(ctx context.Context, method, verb, uri string, b } req.Header.Set(k, v) } - req.Header.Set("User-Agent", o.getUserAgent()) req.Header.Set("X-Blazer-Request-ID", fmt.Sprintf("%d", atomic.AddInt64(&reqID, 1))) req.Header.Set("X-Blazer-Method", method) - if o.failSomeUploads { - req.Header.Add("X-Bz-Test-Mode", "fail_some_uploads") - } - if o.expireTokens { - req.Header.Add("X-Bz-Test-Mode", "expire_some_account_authorization_tokens") - } - if o.capExceeded { - req.Header.Add("X-Bz-Test-Mode", "force_cap_exceeded") - } - cancel := make(chan struct{}) - req.Cancel = cancel + o.addHeaders(req) logRequest(req, args) - ch := makeNetRequest(req, o.getTransport()) - var reply httpReply - select { - case reply = <-ch: - case <-ctx.Done(): - close(cancel) - return ctx.Err() + resp, err := makeNetRequest(ctx, req, o.getTransport()) + if err != nil { + return err } - if reply.err != nil { - // Connection errors are retryable. - blog.V(2).Infof(">> %s uri: %v err: %v", method, req.URL, reply.err) - return b2err{ - msg: reply.err.Error(), - retry: 1, - } - } - resp := reply.resp defer resp.Body.Close() if resp.StatusCode != 200 { return mkErr(resp) @@ -397,10 +401,11 @@ func (o *b2Options) makeRequest(ctx context.Context, method, verb, uri string, b } replyArgs = rbuf.Bytes() } else { - replyArgs, err = ioutil.ReadAll(resp.Body) + ra, err := ioutil.ReadAll(resp.Body) if err != nil { - return err + blog.V(1).Infof("%s: couldn't read response: %v", method, err) } + replyArgs = ra } logResponse(resp, replyArgs) return nil @@ -1038,7 +1043,7 @@ func mkRange(offset, size int64) string { // DownloadFileByName wraps b2_download_file_by_name. func (b *Bucket) DownloadFileByName(ctx context.Context, name string, offset, size int64) (*FileReader, error) { - uri := fmt.Sprintf("%s/file/%s/%s", b.b2.downloadURI, b.Name, name) + uri := fmt.Sprintf("%s/file/%s/%s", b.b2.downloadURI, b.Name, escape(name)) req, err := http.NewRequest("GET", uri, nil) if err != nil { return nil, err @@ -1046,25 +1051,16 @@ func (b *Bucket) DownloadFileByName(ctx context.Context, name string, offset, si req.Header.Set("Authorization", b.b2.authToken) req.Header.Set("X-Blazer-Request-ID", fmt.Sprintf("%d", atomic.AddInt64(&reqID, 1))) req.Header.Set("X-Blazer-Method", "b2_download_file_by_name") + b.b2.opts.addHeaders(req) rng := mkRange(offset, size) if rng != "" { req.Header.Set("Range", rng) } - cancel := make(chan struct{}) - req.Cancel = cancel logRequest(req, nil) - ch := makeNetRequest(req, b.b2.opts.getTransport()) - var reply httpReply - select { - case reply = <-ch: - case <-ctx.Done(): - close(cancel) - return nil, ctx.Err() + resp, err := makeNetRequest(ctx, req, b.b2.opts.getTransport()) + if err != nil { + return nil, err } - if reply.err != nil { - return nil, reply.err - } - resp := reply.resp logResponse(resp, nil) if resp.StatusCode != 200 && resp.StatusCode != 206 { defer resp.Body.Close() diff --git a/vendor/github.com/kurin/blazer/base/integration_test.go b/vendor/github.com/kurin/blazer/base/integration_test.go index 92af5a6d8..ad648f1c4 100644 --- a/vendor/github.com/kurin/blazer/base/integration_test.go +++ b/vendor/github.com/kurin/blazer/base/integration_test.go @@ -20,14 +20,14 @@ import ( "encoding/json" "fmt" "io" - "net" - "net/http" "os" "reflect" "strings" "testing" "time" + "github.com/kurin/blazer/x/transport" + "context" ) @@ -270,48 +270,7 @@ func TestStorage(t *testing.T) { } } -// This slow motion train wreck of a type exists to axe a net connection after -// N bytes have been written. Because of the specific bug it's built to test, -// it can't just *close* the connection, so it just sleeps forever. -type wonkyNetConn struct { - net.Conn - ctx context.Context // implode once cancelled - die *bool // only implode once - n int // bytes to allow before imploding, roughly - i int // bytes written -} - -func (w *wonkyNetConn) Write(b []byte) (int, error) { - if w.i > w.n && w.ctx.Err() != nil && *w.die { - *w.die = false - select {} - } - n, err := w.Conn.Write(b) - w.i += n - return n, err -} - -func newWonkyNetConn(ctx context.Context, die *bool, n int, netw, addr string) (net.Conn, error) { - conn, err := net.Dial(netw, addr) - if err != nil { - return nil, err - } - return &wonkyNetConn{ - Conn: conn, - ctx: ctx, - n: n, - die: die, - }, nil -} - -func makeBadDialContext(ctx context.Context) func(context.Context, string, string) (net.Conn, error) { - die := true - return func(noCtx context.Context, network, addr string) (net.Conn, error) { - return newWonkyNetConn(ctx, &die, 10000, network, addr) - } -} - -func TestBadUpload(t *testing.T) { +func TestUploadAuthAfterConnectionHang(t *testing.T) { id := os.Getenv(apiID) key := os.Getenv(apiKey) if id == "" || key == "" { @@ -319,19 +278,16 @@ func TestBadUpload(t *testing.T) { } ctx := context.Background() - octx, ocancel := context.WithCancel(ctx) - defer ocancel() + hung := make(chan struct{}) - badTransport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: makeBadDialContext(octx), - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, + // An http.RoundTripper that dies after sending ~10k bytes. + hang := func() { + close(hung) + select {} } + tport := transport.WithFailures(nil, transport.AfterNBytes(10000, hang)) - b2, err := AuthorizeAccount(ctx, id, key, Transport(badTransport)) + b2, err := AuthorizeAccount(ctx, id, key, Transport(tport)) if err != nil { t.Fatal(err) } @@ -358,13 +314,13 @@ func TestBadUpload(t *testing.T) { t.Error(err) } smallSHA1 := fmt.Sprintf("%x", hash.Sum(nil)) - ocancel() + go func() { ue.UploadFile(ctx, buf, buf.Len(), smallFileName, "application/octet-stream", smallSHA1, nil) t.Fatal("this ought not to be reachable") }() - time.Sleep(time.Second) // give this a chance to hang + <-hung // Do the whole thing again with the same upload auth, before the remote end // notices we're gone. @@ -381,7 +337,97 @@ func TestBadUpload(t *testing.T) { } } if Action(err) != AttemptNewUpload { - t.Error("Action(%v): got %v, want AttemptNewUpload", err, Action(err)) + t.Errorf("Action(%v): got %v, want AttemptNewUpload", err, Action(err)) + } +} + +func TestCancelledContextCancelsHTTPRequest(t *testing.T) { + id := os.Getenv(apiID) + key := os.Getenv(apiKey) + if id == "" || key == "" { + t.Skipf("B2_ACCOUNT_ID or B2_SECRET_KEY unset; skipping integration tests") + } + ctx := context.Background() + + tport := transport.WithFailures(nil, transport.MatchPathSubstring("b2_upload_file"), transport.FailureRate(1), transport.Stall(2*time.Second)) + + b2, err := AuthorizeAccount(ctx, id, key, Transport(tport)) + if err != nil { + t.Fatal(err) + } + bname := id + "-" + bucketName + bucket, err := b2.CreateBucket(ctx, bname, "", nil, nil) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := bucket.DeleteBucket(ctx); err != nil { + t.Error(err) + } + }() + ue, err := bucket.GetUploadURL(ctx) + if err != nil { + t.Fatal(err) + } + + smallFile := io.LimitReader(zReader{}, 1024*50) // 50k + hash := sha1.New() + buf := &bytes.Buffer{} + w := io.MultiWriter(hash, buf) + if _, err := io.Copy(w, smallFile); err != nil { + t.Error(err) + } + smallSHA1 := fmt.Sprintf("%x", hash.Sum(nil)) + cctx, cancel := context.WithCancel(ctx) + go func() { + time.Sleep(1) + cancel() + }() + if _, err := ue.UploadFile(cctx, buf, buf.Len(), smallFileName, "application/octet-stream", smallSHA1, nil); err != context.Canceled { + t.Errorf("expected canceled context, but got %v", err) + } +} + +func TestDeadlineExceededContextCancelsHTTPRequest(t *testing.T) { + id := os.Getenv(apiID) + key := os.Getenv(apiKey) + if id == "" || key == "" { + t.Skipf("B2_ACCOUNT_ID or B2_SECRET_KEY unset; skipping integration tests") + } + ctx := context.Background() + + tport := transport.WithFailures(nil, transport.MatchPathSubstring("b2_upload_file"), transport.FailureRate(1), transport.Stall(2*time.Second)) + b2, err := AuthorizeAccount(ctx, id, key, Transport(tport)) + if err != nil { + t.Fatal(err) + } + bname := id + "-" + bucketName + bucket, err := b2.CreateBucket(ctx, bname, "", nil, nil) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := bucket.DeleteBucket(ctx); err != nil { + t.Error(err) + } + }() + ue, err := bucket.GetUploadURL(ctx) + if err != nil { + t.Fatal(err) + } + + smallFile := io.LimitReader(zReader{}, 1024*50) // 50k + hash := sha1.New() + buf := &bytes.Buffer{} + w := io.MultiWriter(hash, buf) + if _, err := io.Copy(w, smallFile); err != nil { + t.Error(err) + } + smallSHA1 := fmt.Sprintf("%x", hash.Sum(nil)) + cctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + if _, err := ue.UploadFile(cctx, buf, buf.Len(), smallFileName, "application/octet-stream", smallSHA1, nil); err != context.DeadlineExceeded { + t.Errorf("expected deadline exceeded error, but got %v", err) } } @@ -533,3 +579,72 @@ func TestEscapes(t *testing.T) { } } } + +func TestUploadDownloadFilenameEscaping(t *testing.T) { + filename := "file%foo.txt" + + id := os.Getenv(apiID) + key := os.Getenv(apiKey) + + if id == "" || key == "" { + t.Skipf("B2_ACCOUNT_ID or B2_SECRET_KEY unset; skipping integration tests") + } + ctx := context.Background() + + // b2_authorize_account + b2, err := AuthorizeAccount(ctx, id, key, UserAgent("blazer-base-test")) + if err != nil { + t.Fatal(err) + } + + // b2_create_bucket + bname := id + "-" + bucketName + bucket, err := b2.CreateBucket(ctx, bname, "", nil, nil) + if err != nil { + t.Fatal(err) + } + + defer func() { + // b2_delete_bucket + if err := bucket.DeleteBucket(ctx); err != nil { + t.Error(err) + } + }() + + // b2_get_upload_url + ue, err := bucket.GetUploadURL(ctx) + if err != nil { + t.Fatal(err) + } + + // b2_upload_file + smallFile := io.LimitReader(zReader{}, 128) + hash := sha1.New() + buf := &bytes.Buffer{} + w := io.MultiWriter(hash, buf) + if _, err := io.Copy(w, smallFile); err != nil { + t.Error(err) + } + smallSHA1 := fmt.Sprintf("%x", hash.Sum(nil)) + file, err := ue.UploadFile(ctx, buf, buf.Len(), filename, "application/octet-stream", smallSHA1, nil) + if err != nil { + t.Fatal(err) + } + + defer func() { + // b2_delete_file_version + if err := file.DeleteFileVersion(ctx); err != nil { + t.Error(err) + } + }() + + // b2_download_file_by_name + fr, err := bucket.DownloadFileByName(ctx, filename, 0, 0) + if err != nil { + t.Fatal(err) + } + lbuf := &bytes.Buffer{} + if _, err := io.Copy(lbuf, fr); err != nil { + t.Fatal(err) + } +} diff --git a/vendor/github.com/kurin/blazer/base/strings.go b/vendor/github.com/kurin/blazer/base/strings.go index 88e615f3e..9ad08dc36 100644 --- a/vendor/github.com/kurin/blazer/base/strings.go +++ b/vendor/github.com/kurin/blazer/base/strings.go @@ -15,67 +15,14 @@ package base import ( - "bytes" - "errors" - "fmt" + "net/url" + "strings" ) -func noEscape(c byte) bool { - switch c { - case '.', '_', '-', '/', '~', '!', '$', '\'', '(', ')', '*', ';', '=', ':', '@': - return true - } - return false -} - func escape(s string) string { - // cribbed from url.go, kinda - b := &bytes.Buffer{} - for i := 0; i < len(s); i++ { - switch c := s[i]; { - case c == '/': - b.WriteByte(c) - case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9': - b.WriteByte(c) - case noEscape(c): - b.WriteByte(c) - default: - fmt.Fprintf(b, "%%%X", c) - } - } - return b.String() + return strings.Replace(url.QueryEscape(s), "%2F", "/", -1) } func unescape(s string) (string, error) { - b := &bytes.Buffer{} - for i := 0; i < len(s); i++ { - c := s[i] - switch c { - case '/': - b.WriteString("/") - case '+': - b.WriteString(" ") - case '%': - if len(s)-i < 3 { - return "", errors.New("unescape: bad encoding") - } - b.WriteByte(unhex(s[i+1])<<4 | unhex(s[i+2])) - i += 2 - default: - b.WriteByte(c) - } - } - return b.String(), nil -} - -func unhex(c byte) byte { - switch { - case '0' <= c && c <= '9': - return c - '0' - case 'a' <= c && c <= 'f': - return c - 'a' + 10 - case 'A' <= c && c <= 'F': - return c - 'A' + 10 - } - return 0 + return url.QueryUnescape(s) } diff --git a/vendor/github.com/kurin/blazer/base/strings_test.go b/vendor/github.com/kurin/blazer/base/strings_test.go new file mode 100644 index 000000000..c37629fa1 --- /dev/null +++ b/vendor/github.com/kurin/blazer/base/strings_test.go @@ -0,0 +1,52 @@ +package base + +import ( + "fmt" + "testing" +) + +func TestEncodeDecode(t *testing.T) { + // crashes identified by go-fuzz + origs := []string{ + "&\x020000", + "&\x020000\x9c", + "&\x020\x9c0", + "&\x0230j", + "&\x02\x98000", + "&\x02\x983\xc8j00", + "00\x000", + "00\x0000", + "00\x0000000000000", + "\x11\x030", + } + + for _, orig := range origs { + escaped := escape(orig) + unescaped, err := unescape(escaped) + if err != nil { + t.Errorf("%s: orig: %#v, escaped: %#v, unescaped: %#v\n", err.Error(), orig, escaped, unescaped) + continue + } + + if unescaped != orig { + t.Errorf("expected: %#v, got: %#v", orig, unescaped) + } + } +} + +// hook for go-fuzz: https://github.com/dvyukov/go-fuzz +func Fuzz(data []byte) int { + orig := string(data) + escaped := escape(orig) + + unescaped, err := unescape(escaped) + if err != nil { + return 0 + } + + if unescaped != orig { + panic(fmt.Sprintf("unescaped: \"%#v\", != orig: \"%#v\"", unescaped, orig)) + } + + return 1 +} diff --git a/vendor/github.com/kurin/blazer/x/transport/transport.go b/vendor/github.com/kurin/blazer/x/transport/transport.go new file mode 100644 index 000000000..03a878f63 --- /dev/null +++ b/vendor/github.com/kurin/blazer/x/transport/transport.go @@ -0,0 +1,207 @@ +// Copyright 2017, Google +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package transport provides http.RoundTrippers that may be useful to clients +// of Blazer. +package transport + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "math/rand" + "net/http" + "strings" + "sync/atomic" + "time" +) + +// WithFailures returns an http.RoundTripper that wraps an existing +// RoundTripper, causing failures according to the options given. If rt is +// nil, the http.DefaultTransport is wrapped. +func WithFailures(rt http.RoundTripper, opts ...FailureOption) http.RoundTripper { + if rt == nil { + rt = http.DefaultTransport + } + o := &options{ + rt: rt, + } + for _, opt := range opts { + opt(o) + } + return o +} + +type options struct { + pathSubstrings []string + failureRate float64 + status int + stall time.Duration + rt http.RoundTripper + msg string + trg *triggerReaderGroup +} + +func (o *options) doRequest(req *http.Request) (*http.Response, error) { + if o.trg != nil && req.Body != nil { + req.Body = o.trg.new(req.Body) + } + resp, err := o.rt.RoundTrip(req) + if resp != nil && o.trg != nil { + resp.Body = o.trg.new(resp.Body) + } + return resp, err +} + +func (o *options) RoundTrip(req *http.Request) (*http.Response, error) { + // TODO: fix triggering conditions + if rand.Float64() > o.failureRate { + return o.doRequest(req) + } + + var match bool + if len(o.pathSubstrings) == 0 { + match = true + } + for _, ss := range o.pathSubstrings { + if strings.Contains(req.URL.Path, ss) { + match = true + break + } + } + if !match { + return o.doRequest(req) + } + + if o.status > 0 { + resp := &http.Response{ + Status: fmt.Sprintf("%d %s", o.status, http.StatusText(o.status)), + StatusCode: o.status, + Body: ioutil.NopCloser(strings.NewReader(o.msg)), + Request: req, + } + return resp, nil + } + + if o.stall > 0 { + ctx := req.Context() + select { + case <-time.After(o.stall): + case <-ctx.Done(): + } + } + return o.doRequest(req) +} + +// A FailureOption specifies the kind of failure that the RoundTripper should +// display. +type FailureOption func(*options) + +// MatchPathSubstring restricts the RoundTripper to URLs whose paths contain +// the given string. The default behavior is to match all paths. +func MatchPathSubstring(s string) FailureOption { + return func(o *options) { + o.pathSubstrings = append(o.pathSubstrings, s) + } +} + +// FailureRate causes the RoundTripper to fail a certain percentage of the +// time. rate should be a number between 0 and 1, where 0 will never fail and +// 1 will always fail. The default is never to fail. +func FailureRate(rate float64) FailureOption { + return func(o *options) { + o.failureRate = rate + } +} + +// Response simulates a given status code. The returned http.Response will +// have its Status, StatusCode, and Body (with any predefined message) set. +func Response(status int) FailureOption { + return func(o *options) { + o.status = status + } +} + +// Stall simulates a network connection failure by stalling for the given +// duration. +func Stall(dur time.Duration) FailureOption { + return func(o *options) { + o.stall = dur + } +} + +// If a specific Response is requested, the body will have the given message +// set. +func Body(msg string) FailureOption { + return func(o *options) { + o.msg = msg + } +} + +// Trigger will raise the RoundTripper's failure rate to 100% when the given +// context is closed. +func Trigger(ctx context.Context) FailureOption { + return func(o *options) { + go func() { + <-ctx.Done() + o.failureRate = 1 + }() + } +} + +// AfterNBytes will call effect once (roughly) n bytes have gone over the wire. +// Both sent and received bytes are counted against the total. Only bytes in +// the body of an HTTP request are currently counted; this may change in the +// future. effect will only be called once, and it will block (allowing +// callers to simulate connection hangs). +func AfterNBytes(n int, effect func()) FailureOption { + return func(o *options) { + o.trg = &triggerReaderGroup{ + bytes: int64(n), + trigger: effect, + } + } +} + +type triggerReaderGroup struct { + bytes int64 + trigger func() + triggered int64 +} + +func (rg *triggerReaderGroup) new(rc io.ReadCloser) io.ReadCloser { + return &triggerReader{ + ReadCloser: rc, + bytes: &rg.bytes, + trigger: rg.trigger, + triggered: &rg.triggered, + } +} + +type triggerReader struct { + io.ReadCloser + bytes *int64 + trigger func() + triggered *int64 +} + +func (r *triggerReader) Read(p []byte) (int, error) { + n, err := r.ReadCloser.Read(p) + if atomic.AddInt64(r.bytes, -int64(n)) < 0 && atomic.CompareAndSwapInt64(r.triggered, 0, 1) { + // Can't use sync.Once because it blocks for *all* callers until Do returns. + r.trigger() + } + return n, err +}