diff --git a/internal/backend/b2/b2.go b/internal/backend/b2/b2.go index b115662e9..7e570c6eb 100644 --- a/internal/backend/b2/b2.go +++ b/internal/backend/b2/b2.go @@ -137,31 +137,6 @@ func (be *b2Backend) Location() string { return be.cfg.Bucket } -// wrapReader wraps an io.ReadCloser to run an additional function on Close. -type wrapReader struct { - io.ReadCloser - eofSeen bool - f func() -} - -func (wr *wrapReader) Read(p []byte) (int, error) { - if wr.eofSeen { - return 0, io.EOF - } - - n, err := wr.ReadCloser.Read(p) - if err == io.EOF { - wr.eofSeen = true - } - return n, err -} - -func (wr *wrapReader) Close() error { - err := wr.ReadCloser.Close() - wr.f() - return err -} - // IsNotExist returns true if the error is caused by a non-existing file. func (be *b2Backend) IsNotExist(err error) bool { return b2.IsNotExist(errors.Cause(err)) @@ -192,14 +167,7 @@ func (be *b2Backend) Load(ctx context.Context, h restic.Handle, length int, offs if offset == 0 && length == 0 { rd := obj.NewReader(ctx) - wrapper := &wrapReader{ - ReadCloser: rd, - f: func() { - cancel() - be.sem.ReleaseToken() - }, - } - return wrapper, nil + return be.sem.ReleaseTokenOnClose(rd, cancel), nil } // pass a negative length to NewRangeReader so that the remainder of the @@ -209,14 +177,7 @@ func (be *b2Backend) Load(ctx context.Context, h restic.Handle, length int, offs } rd := obj.NewRangeReader(ctx, offset, int64(length)) - wrapper := &wrapReader{ - ReadCloser: rd, - f: func() { - cancel() - be.sem.ReleaseToken() - }, - } - return wrapper, nil + return be.sem.ReleaseTokenOnClose(rd, cancel), nil } // Save stores data in the backend at the handle. diff --git a/internal/backend/semaphore.go b/internal/backend/semaphore.go index e83191c46..2146db2f3 100644 --- a/internal/backend/semaphore.go +++ b/internal/backend/semaphore.go @@ -1,6 +1,10 @@ package backend -import "github.com/restic/restic/internal/errors" +import ( + "context" + "github.com/restic/restic/internal/errors" + "io" +) // Semaphore limits access to a restricted resource. type Semaphore struct { @@ -26,3 +30,39 @@ func (s *Semaphore) GetToken() { func (s *Semaphore) ReleaseToken() { <-s.ch } + +// ReleaseTokenOnClose wraps an io.ReadCloser to return a token on Close. Before returning the token, +// cancel, if provided, will be run to free up context resources. +func (s *Semaphore) ReleaseTokenOnClose(rc io.ReadCloser, cancel context.CancelFunc) io.ReadCloser { + return &wrapReader{rc, false, func() { + if cancel != nil { + cancel() + } + s.ReleaseToken() + }} +} + +// wrapReader wraps an io.ReadCloser to run an additional function on Close. +type wrapReader struct { + io.ReadCloser + eofSeen bool + f func() +} + +func (wr *wrapReader) Read(p []byte) (int, error) { + if wr.eofSeen { + return 0, io.EOF + } + + n, err := wr.ReadCloser.Read(p) + if err == io.EOF { + wr.eofSeen = true + } + return n, err +} + +func (wr *wrapReader) Close() error { + err := wr.ReadCloser.Close() + wr.f() + return err +} diff --git a/internal/backend/swift/swift.go b/internal/backend/swift/swift.go index 8c69caffa..3497fec1d 100644 --- a/internal/backend/swift/swift.go +++ b/internal/backend/swift/swift.go @@ -129,11 +129,6 @@ func (be *beSwift) Load(ctx context.Context, h restic.Handle, length int, offset objName := be.Filename(h) - be.sem.GetToken() - defer func() { - be.sem.ReleaseToken() - }() - headers := swift.Headers{} if offset > 0 { headers["Range"] = fmt.Sprintf("bytes=%d-", offset) @@ -147,13 +142,15 @@ func (be *beSwift) Load(ctx context.Context, h restic.Handle, length int, offset debug.Log("Load(%v) send range %v", h, headers["Range"]) } + be.sem.GetToken() obj, _, err := be.conn.ObjectOpen(be.container, objName, false, headers) if err != nil { debug.Log(" err %v", err) + be.sem.ReleaseToken() return nil, errors.Wrap(err, "conn.ObjectOpen") } - return obj, nil + return be.sem.ReleaseTokenOnClose(obj, nil), nil } // Save stores data in the backend at the handle. @@ -243,6 +240,9 @@ func (be *beSwift) List(ctx context.Context, t restic.FileType) <-chan string { go func() { defer close(ch) + // NB: unfortunately we can't protect this with be.sem.GetToken() here. + // Doing so would enable a deadlock situation (PR: gh-1399), as ObjectsWalk() + // starts its own goroutine and returns results via a channel. err := be.conn.ObjectsWalk(be.container, &swift.ObjectsOpts{Prefix: prefix}, func(opts *swift.ObjectsOpts) (interface{}, error) { newObjects, err := be.conn.ObjectNames(be.container, opts)