2
2
mirror of https://github.com/octoleo/restic.git synced 2024-06-03 09:30:50 +00:00
restic/vendor/github.com/kurin/blazer/x/transport/transport.go
2017-10-21 10:30:39 +02:00

208 lines
5.1 KiB
Go

// Copyright 2017, Google
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package transport provides http.RoundTrippers that may be useful to clients
// of Blazer.
package transport
import (
"context"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"strings"
"sync/atomic"
"time"
)
// WithFailures wraps rt in an http.RoundTripper that injects failures as
// described by opts. When rt is nil, http.DefaultTransport is wrapped instead.
// Options are applied in the order given, so a later option overrides an
// earlier one that sets the same setting.
func WithFailures(rt http.RoundTripper, opts ...FailureOption) http.RoundTripper {
	cfg := &options{rt: http.DefaultTransport}
	if rt != nil {
		cfg.rt = rt
	}
	for _, apply := range opts {
		apply(cfg)
	}
	return cfg
}
// options is the http.RoundTripper returned by WithFailures. Its fields are
// populated by FailureOption functions before it is used.
type options struct {
	// pathSubstrings limits failures to requests whose URL path contains at
	// least one of these strings; an empty slice matches every path.
	pathSubstrings []string
	// failureRate is the probability, in [0, 1], that a request is selected
	// for failure.
	failureRate float64
	// status, when positive, is the status code of a synthesized response
	// returned in place of sending a selected request.
	status int
	// stall, when positive (and status is not set), delays a selected request
	// before it is forwarded.
	stall time.Duration
	// rt is the wrapped RoundTripper that performs the real requests.
	rt http.RoundTripper
	// msg is the body of the synthesized response (see status).
	msg string
	// trg, when non-nil, counts body bytes read in both directions.
	trg *triggerReaderGroup
}

// doRequest forwards req to the wrapped RoundTripper. When byte counting is
// enabled (trg != nil), the request and response bodies are wrapped so that
// bytes read through them are counted by trg.
func (o *options) doRequest(req *http.Request) (*http.Response, error) {
	if o.trg != nil && req.Body != nil {
		// An http.RoundTripper must not modify the caller's request, so the
		// body is swapped on a shallow copy instead.
		clone := new(http.Request)
		*clone = *req
		clone.Body = o.trg.new(req.Body)
		req = clone
	}
	resp, err := o.rt.RoundTrip(req)
	if resp != nil && resp.Body != nil && o.trg != nil {
		resp.Body = o.trg.new(resp.Body)
	}
	return resp, err
}

// RoundTrip implements http.RoundTripper.
//
// A request is selected for failure with probability failureRate. It is then
// checked against pathSubstrings. A selected request that matches either
// receives a synthesized response with the configured status (without
// contacting the wrapped RoundTripper), or is delayed by the configured stall
// (or until its context is done) and then forwarded. All other requests are
// forwarded unchanged.
func (o *options) RoundTrip(req *http.Request) (*http.Response, error) {
	// TODO: fix triggering conditions
	// rand.Float64 returns values in [0, 1). Using >= guarantees that a rate
	// of 0 never selects a request and a rate of 1 always does.
	// NOTE(review): Trigger may update failureRate from another goroutine
	// without synchronization, so this read can race with that write.
	if rand.Float64() >= o.failureRate {
		return o.doRequest(req)
	}

	match := len(o.pathSubstrings) == 0
	for _, ss := range o.pathSubstrings {
		if strings.Contains(req.URL.Path, ss) {
			match = true
			break
		}
	}
	if !match {
		return o.doRequest(req)
	}

	if o.status > 0 {
		// The RoundTripper contract requires the request body to be closed
		// even when the request is never sent. The close error is irrelevant
		// to the synthesized failure, so it is deliberately dropped.
		if req.Body != nil {
			_ = req.Body.Close()
		}
		return &http.Response{
			Status:        fmt.Sprintf("%d %s", o.status, http.StatusText(o.status)),
			StatusCode:    o.status,
			Proto:         "HTTP/1.1",
			ProtoMajor:    1,
			ProtoMinor:    1,
			Header:        make(http.Header),
			Body:          ioutil.NopCloser(strings.NewReader(o.msg)),
			ContentLength: int64(len(o.msg)),
			Request:       req,
		}, nil
	}

	if o.stall > 0 {
		timer := time.NewTimer(o.stall)
		select {
		case <-timer.C:
		case <-req.Context().Done():
		}
		// Release the timer promptly if the context ended the wait early.
		timer.Stop()
	}
	return o.doRequest(req)
}
// FailureOption configures the failure behavior of the RoundTripper returned
// by WithFailures. Values are produced by functions such as FailureRate,
// Response, Stall and MatchPathSubstring.
type FailureOption func(*options)
// MatchPathSubstring limits failures to requests whose URL path contains s.
// The option may be given several times; a request qualifies when its path
// contains any one of the supplied strings. Without this option, every path
// qualifies.
func MatchPathSubstring(s string) FailureOption {
	addSubstring := func(opts *options) {
		opts.pathSubstrings = append(opts.pathSubstrings, s)
	}
	return addSubstring
}
// FailureRate sets the fraction of requests that the RoundTripper selects for
// failure. rate should lie between 0 and 1: 0 never fails and 1 always fails.
// The default is 0.
//
// A selected request is disrupted only when its path passes any
// MatchPathSubstring filters and a Response or Stall is configured. Otherwise
// it is forwarded normally.
func FailureRate(rate float64) FailureOption {
	setRate := func(opts *options) {
		opts.failureRate = rate
	}
	return setRate
}
// Response makes failing requests receive a synthesized response with the
// given status code, without contacting the wrapped RoundTripper. The
// response's Status, StatusCode and Body are set; the Body carries the message
// configured with Body, if any. A positive status takes precedence over Stall.
func Response(status int) FailureOption {
	setStatus := func(opts *options) {
		opts.status = status
	}
	return setStatus
}
// Stall simulates a slow or hung connection. A failing request is delayed by
// dur, or until its context is done if that happens first, and is then
// forwarded to the wrapped RoundTripper. Stall has no effect when Response is
// also configured.
func Stall(dur time.Duration) FailureOption {
	setStall := func(opts *options) {
		opts.stall = dur
	}
	return setStall
}
// Body sets the message used as the body of the synthesized response when
// Response is configured. Without Body, that response has an empty body.
func Body(msg string) FailureOption {
	setMessage := func(opts *options) {
		opts.msg = msg
	}
	return setMessage
}
// Trigger raises the RoundTripper's failure rate to 100% once ctx is done.
//
// Each application of the option starts a goroutine that waits on ctx.Done()
// and exits after raising the rate, so callers should eventually cancel ctx.
// If ctx can never be done (ctx.Done() returns nil, as for
// context.Background), no goroutine is started and the rate is left
// unchanged.
//
// NOTE(review): the goroutine writes the failure rate without
// synchronization while RoundTrip may read it concurrently. That is a data
// race the race detector will report; fixing it requires changing how
// RoundTrip reads the rate.
func Trigger(ctx context.Context) FailureOption {
	return func(o *options) {
		done := ctx.Done()
		if done == nil {
			// Receiving from a nil channel blocks forever; starting the
			// goroutine would only leak it.
			return
		}
		go func() {
			<-done
			o.failureRate = 1
		}()
	}
}
// AfterNBytes calls effect once roughly n bytes have crossed the wire. Bytes
// read from request bodies and from response bodies both count toward the
// total; headers are not counted (this may change in the future). effect runs
// at most once per application of the option. It runs synchronously inside
// the body Read that crossed the threshold, so a blocking effect stalls that
// transfer, which lets callers simulate a hung connection.
func AfterNBytes(n int, effect func()) FailureOption {
	budget := int64(n)
	return func(opts *options) {
		// A fresh group per application keeps byte counts independent between
		// RoundTrippers built from the same option value.
		opts.trg = &triggerReaderGroup{
			trigger: effect,
			bytes:   budget,
		}
	}
}
// triggerReaderGroup holds the state shared by every triggerReader it
// creates: a remaining byte budget and a fired flag, both updated atomically.
//
// The int64 fields come first. sync/atomic's 64-bit operations require 8-byte
// alignment on 386, ARM and 32-bit MIPS, and only the first words of an
// allocated struct are guaranteed that alignment. Previously triggered
// followed a func field, which placed it at offset 12 on 32-bit platforms.
type triggerReaderGroup struct {
	// bytes is the remaining budget; readers subtract what they read.
	bytes int64
	// triggered is set from 0 to 1 by the reader that fires trigger.
	triggered int64
	// trigger is called at most once, after the budget drops below zero.
	trigger func()
}

// new wraps rc in a triggerReader that counts against this group's shared
// budget and fired flag.
func (rg *triggerReaderGroup) new(rc io.ReadCloser) io.ReadCloser {
	return &triggerReader{
		ReadCloser: rc,
		bytes:      &rg.bytes,
		trigger:    rg.trigger,
		triggered:  &rg.triggered,
	}
}
// triggerReader wraps an io.ReadCloser. It subtracts every byte it reads from
// a shared budget and fires a shared trigger the first time the budget goes
// negative.
type triggerReader struct {
	io.ReadCloser // underlying body; Close is promoted unchanged

	bytes     *int64 // shared remaining budget, updated atomically
	trigger   func() // effect to run once the budget is exceeded
	triggered *int64 // shared 0/1 flag ensuring trigger runs only once
}

// Read reads from the wrapped body and charges the bytes read against the
// shared budget. When the budget drops below zero, exactly one reader across
// the group (whichever wins the compare-and-swap) calls trigger, synchronously
// and before returning. The underlying n and err are returned unchanged.
func (r *triggerReader) Read(p []byte) (int, error) {
	n, err := r.ReadCloser.Read(p)
	remaining := atomic.AddInt64(r.bytes, -int64(n))
	if remaining >= 0 {
		return n, err
	}
	// sync.Once is avoided on purpose: its Do would block every concurrent
	// caller until trigger returns, whereas only the winner should block here.
	if atomic.CompareAndSwapInt64(r.triggered, 0, 1) {
		r.trigger()
	}
	return n, err
}