2
2
mirror of https://github.com/octoleo/restic.git synced 2024-11-23 05:12:10 +00:00

Merge pull request #1564 from restic/sftp-password-prompt

sftp: Prompt for password, don't terminate on SIGINT
This commit is contained in:
Alexander Neumann 2018-01-20 09:48:17 +01:00
commit 200415e0a1
8 changed files with 113 additions and 32 deletions

View File

@ -22,19 +22,9 @@ var stderr = os.Stderr
func init() { func init() {
cleanupHandlers.ch = make(chan os.Signal) cleanupHandlers.ch = make(chan os.Signal)
go CleanupHandler(cleanupHandlers.ch) go CleanupHandler(cleanupHandlers.ch)
InstallSignalHandler()
}
// InstallSignalHandler listens for SIGINT, and triggers the cleanup handlers.
func InstallSignalHandler() {
signal.Notify(cleanupHandlers.ch, syscall.SIGINT) signal.Notify(cleanupHandlers.ch, syscall.SIGINT)
} }
// SuspendSignalHandler removes the signal handler for SIGINT.
func SuspendSignalHandler() {
signal.Reset(syscall.SIGINT)
}
// AddCleanupHandler adds the function f to the list of cleanup handlers so // AddCleanupHandler adds the function f to the list of cleanup handlers so
// that it is executed when all the cleanup handlers are run, e.g. when SIGINT // that it is executed when all the cleanup handlers are run, e.g. when SIGINT
// is received. // is received.

View File

@ -555,7 +555,7 @@ func open(s string, gopts GlobalOptions, opts options.Options) (restic.Backend,
// wrap the backend in a LimitBackend so that the throughput is limited // wrap the backend in a LimitBackend so that the throughput is limited
be = limiter.LimitBackend(be, limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb)) be = limiter.LimitBackend(be, limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb))
case "sftp": case "sftp":
be, err = sftp.Open(cfg.(sftp.Config), SuspendSignalHandler, InstallSignalHandler) be, err = sftp.Open(cfg.(sftp.Config))
// wrap the backend in a LimitBackend so that the throughput is limited // wrap the backend in a LimitBackend so that the throughput is limited
be = limiter.LimitBackend(be, limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb)) be = limiter.LimitBackend(be, limiter.NewStaticLimiter(gopts.LimitUploadKb, gopts.LimitDownloadKb))
case "s3": case "s3":
@ -614,7 +614,7 @@ func create(s string, opts options.Options) (restic.Backend, error) {
case "local": case "local":
return local.Create(cfg.(local.Config)) return local.Create(cfg.(local.Config))
case "sftp": case "sftp":
return sftp.Create(cfg.(sftp.Config), SuspendSignalHandler, InstallSignalHandler) return sftp.Create(cfg.(sftp.Config))
case "s3": case "s3":
return s3.Create(cfg.(s3.Config), rt) return s3.Create(cfg.(s3.Config), rt)
case "gs": case "gs":

View File

@ -0,0 +1,73 @@
// +build !windows
package sftp
import (
"os"
"os/exec"
"os/signal"
"syscall"
"unsafe"
"github.com/restic/restic/internal/errors"
)
func tcsetpgrp(fd int, pid int) error {
_, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(fd),
uintptr(syscall.TIOCSPGRP), uintptr(unsafe.Pointer(&pid)))
if errno == 0 {
return nil
}
return errno
}
// startForeground runs cmd in the foreground, by temporarily switching to the
// new process group created for cmd. The returned function `bg` switches back
// to the previous process group.
func startForeground(cmd *exec.Cmd) (bg func() error, err error) {
// open the TTY, we need the file descriptor
tty, err := os.OpenFile("/dev/tty", os.O_RDWR, 0)
if err != nil {
return nil, errors.Wrap(err, "open TTY")
}
signal.Ignore(syscall.SIGTTIN)
signal.Ignore(syscall.SIGTTOU)
// run the command in its own process group
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
}
// start the process
err = cmd.Start()
if err != nil {
_ = tty.Close()
return nil, errors.Wrap(err, "cmd.Start")
}
// move the command's process group into the foreground
prev := syscall.Getpgrp()
err = tcsetpgrp(int(tty.Fd()), cmd.Process.Pid)
if err != nil {
_ = tty.Close()
return nil, err
}
bg = func() error {
signal.Reset(syscall.SIGTTIN)
signal.Reset(syscall.SIGTTOU)
// reset the foreground process group
err = tcsetpgrp(int(tty.Fd()), prev)
if err != nil {
_ = tty.Close()
return err
}
return tty.Close()
}
return bg, nil
}

View File

@ -0,0 +1,21 @@
package sftp
import (
"os/exec"
"github.com/restic/restic/internal/errors"
)
// startForeground runs cmd in the foreground, by temporarily switching to the
// new process group created for cmd. The returned function `bg` switches back
// to the previous process group.
func startForeground(cmd *exec.Cmd) (bg func() error, err error) {
// just start the process and hope for the best
err = cmd.Start()
if err != nil {
return nil, errors.Wrap(err, "cmd.Start")
}
bg = func() error { return nil }
return bg, nil
}

View File

@ -46,7 +46,7 @@ func TestLayout(t *testing.T) {
Command: fmt.Sprintf("%q -e", sftpServer), Command: fmt.Sprintf("%q -e", sftpServer),
Path: repo, Path: repo,
Layout: test.layout, Layout: test.layout,
}, nil, nil) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -36,7 +36,7 @@ var _ restic.Backend = &SFTP{}
const defaultLayout = "default" const defaultLayout = "default"
func startClient(preExec, postExec func(), program string, args ...string) (*SFTP, error) { func startClient(program string, args ...string) (*SFTP, error) {
debug.Log("start client %v %v", program, args) debug.Log("start client %v %v", program, args)
// Connect to a remote host and request the sftp subsystem via the 'ssh' // Connect to a remote host and request the sftp subsystem via the 'ssh'
// command. This assumes that passwordless login is correctly configured. // command. This assumes that passwordless login is correctly configured.
@ -65,19 +65,11 @@ func startClient(preExec, postExec func(), program string, args ...string) (*SFT
return nil, errors.Wrap(err, "cmd.StdoutPipe") return nil, errors.Wrap(err, "cmd.StdoutPipe")
} }
if preExec != nil { bg, err := startForeground(cmd)
preExec() if err != nil {
}
// start the process
if err := cmd.Start(); err != nil {
return nil, errors.Wrap(err, "cmd.Start") return nil, errors.Wrap(err, "cmd.Start")
} }
if postExec != nil {
postExec()
}
// wait in a different goroutine // wait in a different goroutine
ch := make(chan error, 1) ch := make(chan error, 1)
go func() { go func() {
@ -92,6 +84,11 @@ func startClient(preExec, postExec func(), program string, args ...string) (*SFT
return nil, errors.Errorf("unable to start the sftp session, error: %v", err) return nil, errors.Errorf("unable to start the sftp session, error: %v", err)
} }
err = bg()
if err != nil {
return nil, errors.Wrap(err, "bg")
}
return &SFTP{c: client, cmd: cmd, result: ch}, nil return &SFTP{c: client, cmd: cmd, result: ch}, nil
} }
@ -111,7 +108,7 @@ func (r *SFTP) clientError() error {
// Open opens an sftp backend as described by the config by running // Open opens an sftp backend as described by the config by running
// "ssh" with the appropriate arguments (or cfg.Command, if set). The function // "ssh" with the appropriate arguments (or cfg.Command, if set). The function
// preExec is run just before, postExec just after starting a program. // preExec is run just before, postExec just after starting a program.
func Open(cfg Config, preExec, postExec func()) (*SFTP, error) { func Open(cfg Config) (*SFTP, error) {
debug.Log("open backend with config %#v", cfg) debug.Log("open backend with config %#v", cfg)
cmd, args, err := buildSSHCommand(cfg) cmd, args, err := buildSSHCommand(cfg)
@ -119,7 +116,7 @@ func Open(cfg Config, preExec, postExec func()) (*SFTP, error) {
return nil, err return nil, err
} }
sftp, err := startClient(preExec, postExec, cmd, args...) sftp, err := startClient(cmd, args...)
if err != nil { if err != nil {
debug.Log("unable to start program: %v", err) debug.Log("unable to start program: %v", err)
return nil, err return nil, err
@ -204,13 +201,13 @@ func buildSSHCommand(cfg Config) (cmd string, args []string, err error) {
// Create creates an sftp backend as described by the config by running "ssh" // Create creates an sftp backend as described by the config by running "ssh"
// with the appropriate arguments (or cfg.Command, if set). The function // with the appropriate arguments (or cfg.Command, if set). The function
// preExec is run just before, postExec just after starting a program. // preExec is run just before, postExec just after starting a program.
func Create(cfg Config, preExec, postExec func()) (*SFTP, error) { func Create(cfg Config) (*SFTP, error) {
cmd, args, err := buildSSHCommand(cfg) cmd, args, err := buildSSHCommand(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sftp, err := startClient(preExec, postExec, cmd, args...) sftp, err := startClient(cmd, args...)
if err != nil { if err != nil {
debug.Log("unable to start program: %v", err) debug.Log("unable to start program: %v", err)
return nil, err return nil, err
@ -238,7 +235,7 @@ func Create(cfg Config, preExec, postExec func()) (*SFTP, error) {
} }
// open backend // open backend
return Open(cfg, preExec, postExec) return Open(cfg)
} }
// Location returns this backend's location (the directory name). // Location returns this backend's location (the directory name).

View File

@ -50,13 +50,13 @@ func newTestSuite(t testing.TB) *test.Suite {
// CreateFn is a function that creates a temporary repository for the tests. // CreateFn is a function that creates a temporary repository for the tests.
Create: func(config interface{}) (restic.Backend, error) { Create: func(config interface{}) (restic.Backend, error) {
cfg := config.(sftp.Config) cfg := config.(sftp.Config)
return sftp.Create(cfg, nil, nil) return sftp.Create(cfg)
}, },
// OpenFn is a function that opens a previously created temporary repository. // OpenFn is a function that opens a previously created temporary repository.
Open: func(config interface{}) (restic.Backend, error) { Open: func(config interface{}) (restic.Backend, error) {
cfg := config.(sftp.Config) cfg := config.(sftp.Config)
return sftp.Open(cfg, nil, nil) return sftp.Open(cfg)
}, },
// CleanupFn removes data created during the tests. // CleanupFn removes data created during the tests.

View File

@ -101,7 +101,7 @@ func TestNodeFromFileInfo(t *testing.T) {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
fi, found := stat(t, test.filename) fi, found := stat(t, test.filename)
if !found && test.canSkip { if !found && test.canSkip {
t.Skipf("%v not found in filesystem") t.Skipf("%v not found in filesystem", test.filename)
return return
} }