diff --git a/internal/backend/rclone/backend.go b/internal/backend/rclone/backend.go index 40f434b3e..67bd1eab7 100644 --- a/internal/backend/rclone/backend.go +++ b/internal/backend/rclone/backend.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "os/exec" + "sync" "time" "github.com/restic/restic/internal/backend" @@ -27,18 +28,25 @@ type Backend struct { cmd *exec.Cmd waitCh <-chan struct{} waitResult error + wg *sync.WaitGroup + conn *StdioConn } // run starts command with args and initializes the StdioConn. -func run(command string, args ...string) (*StdioConn, *exec.Cmd, func() error, error) { +func run(command string, args ...string) (*StdioConn, *exec.Cmd, *sync.WaitGroup, func() error, error) { cmd := exec.Command(command, args...) + p, err := cmd.StderrPipe() if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } + var wg sync.WaitGroup + // start goroutine to add a prefix to all messages printed by to stderr by rclone + wg.Add(1) go func() { + defer wg.Done() sc := bufio.NewScanner(p) for sc.Scan() { fmt.Fprintf(os.Stderr, "rclone: %v\n", sc.Text()) @@ -47,12 +55,12 @@ func run(command string, args ...string) (*StdioConn, *exec.Cmd, func() error, e r, stdin, err := os.Pipe() if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } stdout, w, err := os.Pipe() if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } cmd.Stdin = r @@ -60,7 +68,7 @@ func run(command string, args ...string) (*StdioConn, *exec.Cmd, func() error, e bg, err := backend.StartForeground(cmd) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } c := &StdioConn{ @@ -69,7 +77,7 @@ func run(command string, args ...string) (*StdioConn, *exec.Cmd, func() error, e cmd: cmd, } - return c, cmd, bg, nil + return c, cmd, &wg, bg, nil } // New initializes a Backend and starts the process. @@ -107,15 +115,20 @@ func New(cfg Config) (*Backend, error) { arg0, args := args[0], args[1:] debug.Log("running command: %v %v", arg0, args) - conn, cmd, bg, err := run(arg0, args...) + conn, cmd, wg, bg, err := run(arg0, args...) if err != nil { return nil, err } + dialCount := 0 tr := &http2.Transport{ AllowHTTP: true, // this is not really HTTP, just stdin/stdout DialTLS: func(network, address string, cfg *tls.Config) (net.Conn, error) { debug.Log("new connection requested, %v %v", network, address) + if dialCount > 0 { + panic("dial count > 0") + } + dialCount++ return conn, nil }, } @@ -125,9 +138,13 @@ func New(cfg Config) (*Backend, error) { tr: tr, cmd: cmd, waitCh: waitCh, + conn: conn, + wg: wg, } + wg.Add(1) go func() { + defer wg.Done() debug.Log("waiting for error result") err := cmd.Wait() debug.Log("Wait returned %v", err) @@ -138,7 +155,9 @@ func New(cfg Config) (*Backend, error) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + wg.Add(1) go func() { + defer wg.Done() debug.Log("monitoring command to cancel first HTTP request context") select { case <-ctx.Done(): @@ -160,6 +179,7 @@ func New(cfg Config) (*Backend, error) { return nil, err } req.Header.Set("Accept", rest.ContentTypeV2) + req.Cancel = ctx.Done() res, err := ctxhttp.Do(ctx, client, req) if err != nil { @@ -229,11 +249,25 @@ func Create(cfg Config) (*Backend, error) { return be, nil } +const waitForExit = 5 * time.Second + // Close terminates the backend. func (be *Backend) Close() error { debug.Log("exiting rclone") be.tr.CloseIdleConnections() - <-be.waitCh + + select { + case <-be.waitCh: + debug.Log("rclone exited") + case <-time.After(waitForExit): + debug.Log("timeout, closing file descriptors") + err := be.conn.Close() + if err != nil { + return err + } + } + + be.wg.Wait() debug.Log("wait for rclone returned: %v", be.waitResult) return be.waitResult } diff --git a/internal/backend/rclone/stdio_conn.go b/internal/backend/rclone/stdio_conn.go index 4472300ec..bb4928176 100644 --- a/internal/backend/rclone/stdio_conn.go +++ b/internal/backend/rclone/stdio_conn.go @@ -4,6 +4,7 @@ import ( "net" "os" "os/exec" + "sync" "github.com/restic/restic/internal/debug" ) @@ -14,6 +15,7 @@ type StdioConn struct { stdout *os.File bytesWritten, bytesRead int cmd *exec.Cmd + close sync.Once } func (s *StdioConn) Read(p []byte) (int, error) { @@ -29,21 +31,24 @@ func (s *StdioConn) Write(p []byte) (int, error) { } // Close closes both streams. -func (s *StdioConn) Close() error { - debug.Log("close server instance") - var errs []error +func (s *StdioConn) Close() (err error) { + s.close.Do(func() { + debug.Log("close stdio connection") + var errs []error - for _, f := range []func() error{s.stdin.Close, s.stdout.Close} { - err := f() - if err != nil { - errs = append(errs, err) + for _, f := range []func() error{s.stdin.Close, s.stdout.Close} { + err := f() + if err != nil { + errs = append(errs, err) + } } - } - if len(errs) > 0 { - return errs[0] - } - return nil + if len(errs) > 0 { + err = errs[0] + } + }) + + return err } // LocalAddr returns nil.