2
2
mirror of https://github.com/octoleo/restic.git synced 2024-06-17 00:02:49 +00:00

backend/rclone: Make sure rclone terminates

This commit is contained in:
Alexander Neumann 2018-03-15 21:22:14 +01:00
parent fc0295016a
commit 4d5c7a8749
2 changed files with 59 additions and 20 deletions

View File

@ -10,6 +10,7 @@ import (
"net/url" "net/url"
"os" "os"
"os/exec" "os/exec"
"sync"
"time" "time"
"github.com/restic/restic/internal/backend" "github.com/restic/restic/internal/backend"
@ -27,18 +28,25 @@ type Backend struct {
cmd *exec.Cmd cmd *exec.Cmd
waitCh <-chan struct{} waitCh <-chan struct{}
waitResult error waitResult error
wg *sync.WaitGroup
conn *StdioConn
} }
// run starts command with args and initializes the StdioConn. // run starts command with args and initializes the StdioConn.
func run(command string, args ...string) (*StdioConn, *exec.Cmd, func() error, error) { func run(command string, args ...string) (*StdioConn, *exec.Cmd, *sync.WaitGroup, func() error, error) {
cmd := exec.Command(command, args...) cmd := exec.Command(command, args...)
p, err := cmd.StderrPipe() p, err := cmd.StderrPipe()
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
var wg sync.WaitGroup
// start goroutine to add a prefix to all messages printed by to stderr by rclone // start goroutine to add a prefix to all messages printed by to stderr by rclone
wg.Add(1)
go func() { go func() {
defer wg.Done()
sc := bufio.NewScanner(p) sc := bufio.NewScanner(p)
for sc.Scan() { for sc.Scan() {
fmt.Fprintf(os.Stderr, "rclone: %v\n", sc.Text()) fmt.Fprintf(os.Stderr, "rclone: %v\n", sc.Text())
@ -47,12 +55,12 @@ func run(command string, args ...string) (*StdioConn, *exec.Cmd, func() error, e
r, stdin, err := os.Pipe() r, stdin, err := os.Pipe()
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
stdout, w, err := os.Pipe() stdout, w, err := os.Pipe()
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
cmd.Stdin = r cmd.Stdin = r
@ -60,7 +68,7 @@ func run(command string, args ...string) (*StdioConn, *exec.Cmd, func() error, e
bg, err := backend.StartForeground(cmd) bg, err := backend.StartForeground(cmd)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
c := &StdioConn{ c := &StdioConn{
@ -69,7 +77,7 @@ func run(command string, args ...string) (*StdioConn, *exec.Cmd, func() error, e
cmd: cmd, cmd: cmd,
} }
return c, cmd, bg, nil return c, cmd, &wg, bg, nil
} }
// New initializes a Backend and starts the process. // New initializes a Backend and starts the process.
@ -107,15 +115,20 @@ func New(cfg Config) (*Backend, error) {
arg0, args := args[0], args[1:] arg0, args := args[0], args[1:]
debug.Log("running command: %v %v", arg0, args) debug.Log("running command: %v %v", arg0, args)
conn, cmd, bg, err := run(arg0, args...) conn, cmd, wg, bg, err := run(arg0, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dialCount := 0
tr := &http2.Transport{ tr := &http2.Transport{
AllowHTTP: true, // this is not really HTTP, just stdin/stdout AllowHTTP: true, // this is not really HTTP, just stdin/stdout
DialTLS: func(network, address string, cfg *tls.Config) (net.Conn, error) { DialTLS: func(network, address string, cfg *tls.Config) (net.Conn, error) {
debug.Log("new connection requested, %v %v", network, address) debug.Log("new connection requested, %v %v", network, address)
if dialCount > 0 {
panic("dial count > 0")
}
dialCount++
return conn, nil return conn, nil
}, },
} }
@ -125,9 +138,13 @@ func New(cfg Config) (*Backend, error) {
tr: tr, tr: tr,
cmd: cmd, cmd: cmd,
waitCh: waitCh, waitCh: waitCh,
conn: conn,
wg: wg,
} }
wg.Add(1)
go func() { go func() {
defer wg.Done()
debug.Log("waiting for error result") debug.Log("waiting for error result")
err := cmd.Wait() err := cmd.Wait()
debug.Log("Wait returned %v", err) debug.Log("Wait returned %v", err)
@ -138,7 +155,9 @@ func New(cfg Config) (*Backend, error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
wg.Add(1)
go func() { go func() {
defer wg.Done()
debug.Log("monitoring command to cancel first HTTP request context") debug.Log("monitoring command to cancel first HTTP request context")
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -160,6 +179,7 @@ func New(cfg Config) (*Backend, error) {
return nil, err return nil, err
} }
req.Header.Set("Accept", rest.ContentTypeV2) req.Header.Set("Accept", rest.ContentTypeV2)
req.Cancel = ctx.Done()
res, err := ctxhttp.Do(ctx, client, req) res, err := ctxhttp.Do(ctx, client, req)
if err != nil { if err != nil {
@ -229,11 +249,25 @@ func Create(cfg Config) (*Backend, error) {
return be, nil return be, nil
} }
const waitForExit = 5 * time.Second
// Close terminates the backend. // Close terminates the backend.
func (be *Backend) Close() error { func (be *Backend) Close() error {
debug.Log("exiting rclone") debug.Log("exiting rclone")
be.tr.CloseIdleConnections() be.tr.CloseIdleConnections()
<-be.waitCh
select {
case <-be.waitCh:
debug.Log("rclone exited")
case <-time.After(waitForExit):
debug.Log("timeout, closing file descriptors")
err := be.conn.Close()
if err != nil {
return err
}
}
be.wg.Wait()
debug.Log("wait for rclone returned: %v", be.waitResult) debug.Log("wait for rclone returned: %v", be.waitResult)
return be.waitResult return be.waitResult
} }

View File

@ -4,6 +4,7 @@ import (
"net" "net"
"os" "os"
"os/exec" "os/exec"
"sync"
"github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/debug"
) )
@ -14,6 +15,7 @@ type StdioConn struct {
stdout *os.File stdout *os.File
bytesWritten, bytesRead int bytesWritten, bytesRead int
cmd *exec.Cmd cmd *exec.Cmd
close sync.Once
} }
func (s *StdioConn) Read(p []byte) (int, error) { func (s *StdioConn) Read(p []byte) (int, error) {
@ -29,21 +31,24 @@ func (s *StdioConn) Write(p []byte) (int, error) {
} }
// Close closes both streams. // Close closes both streams.
func (s *StdioConn) Close() error { func (s *StdioConn) Close() (err error) {
debug.Log("close server instance") s.close.Do(func() {
var errs []error debug.Log("close stdio connection")
var errs []error
for _, f := range []func() error{s.stdin.Close, s.stdout.Close} { for _, f := range []func() error{s.stdin.Close, s.stdout.Close} {
err := f() err := f()
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
}
} }
}
if len(errs) > 0 { if len(errs) > 0 {
return errs[0] err = errs[0]
} }
return nil })
return err
} }
// LocalAddr returns nil. // LocalAddr returns nil.