2
2
mirror of https://github.com/octoleo/restic.git synced 2024-11-02 11:46:36 +00:00

sftp: Integrate command

This commit is contained in:
Alexander Neumann 2017-04-03 21:05:42 +02:00
parent d3b6f75848
commit c26dd6b76f
5 changed files with 83 additions and 37 deletions

View File

@ -11,6 +11,7 @@ import (
// Config collects all information required to connect to an sftp server. // Config collects all information required to connect to an sftp server.
type Config struct { type Config struct {
User, Host, Dir string User, Host, Dir string
Command string `option:"command"`
} }
// ParseConfig parses the string s and extracts the sftp config. The // ParseConfig parses the string s and extracts the sftp config. The

View File

@ -37,6 +37,7 @@ type SFTP struct {
var _ restic.Backend = &SFTP{} var _ restic.Backend = &SFTP{}
func startClient(program string, args ...string) (*SFTP, error) { func startClient(program string, args ...string) (*SFTP, error) {
debug.Log("start client %v %v", program, args)
// Connect to a remote host and request the sftp subsystem via the 'ssh' // Connect to a remote host and request the sftp subsystem via the 'ssh'
// command. This assumes that passwordless login is correctly configured. // command. This assumes that passwordless login is correctly configured.
cmd := exec.Command(program, args...) cmd := exec.Command(program, args...)
@ -114,11 +115,11 @@ func (r *SFTP) clientError() error {
return nil return nil
} }
// Open opens an sftp backend. When the command is started via // open opens an sftp backend. When the command is started via
// exec.Command, it is expected to speak sftp on stdin/stdout. The backend // exec.Command, it is expected to speak sftp on stdin/stdout. The backend
// is expected at the given path. `dir` must be delimited by forward slashes // is expected at the given path. `dir` must be delimited by forward slashes
// ("/"), which is required by sftp. // ("/"), which is required by sftp.
func Open(dir string, program string, args ...string) (*SFTP, error) { func open(dir string, program string, args ...string) (*SFTP, error) {
debug.Log("open backend with program %v, %v at %v", program, args, dir) debug.Log("open backend with program %v, %v at %v", program, args, dir)
sftp, err := startClient(program, args...) sftp, err := startClient(program, args...)
if err != nil { if err != nil {
@ -155,15 +156,25 @@ func buildSSHCommand(cfg Config) []string {
// OpenWithConfig opens an sftp backend as described by the config by running // OpenWithConfig opens an sftp backend as described by the config by running
// "ssh" with the appropriate arguments. // "ssh" with the appropriate arguments.
func OpenWithConfig(cfg Config) (*SFTP, error) { func OpenWithConfig(cfg Config) (*SFTP, error) {
debug.Log("open with config %v", cfg) debug.Log("config %#v", cfg)
return Open(cfg.Dir, "ssh", buildSSHCommand(cfg)...)
if cfg.Command == "" {
return open(cfg.Dir, "ssh", buildSSHCommand(cfg)...)
}
cmd, args, err := SplitShellArgs(cfg.Command)
if err != nil {
return nil, err
}
return open(cfg.Dir, cmd, args...)
} }
// Create creates all the necessary files and directories for a new sftp // create creates all the necessary files and directories for a new sftp
// backend at dir. Afterwards a new config blob should be created. `dir` must // backend at dir. Afterwards a new config blob should be created. `dir` must
// be delimited by forward slashes ("/"), which is required by sftp. // be delimited by forward slashes ("/"), which is required by sftp.
func Create(dir string, program string, args ...string) (*SFTP, error) { func create(dir string, program string, args ...string) (*SFTP, error) {
debug.Log("%v %v", program, args) debug.Log("create() %v %v", program, args)
sftp, err := startClient(program, args...) sftp, err := startClient(program, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -178,6 +189,7 @@ func Create(dir string, program string, args ...string) (*SFTP, error) {
// create paths for data, refs and temp blobs // create paths for data, refs and temp blobs
for _, d := range paths(dir) { for _, d := range paths(dir) {
err = sftp.mkdirAll(d, backend.Modes.Dir) err = sftp.mkdirAll(d, backend.Modes.Dir)
debug.Log("mkdirAll %v -> %v", d, err)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -189,14 +201,23 @@ func Create(dir string, program string, args ...string) (*SFTP, error) {
} }
// open backend // open backend
return Open(dir, program, args...) return open(dir, program, args...)
} }
// CreateWithConfig creates an sftp backend as described by the config by running // CreateWithConfig creates an sftp backend as described by the config by running
// "ssh" with the appropriate arguments. // "ssh" with the appropriate arguments.
func CreateWithConfig(cfg Config) (*SFTP, error) { func CreateWithConfig(cfg Config) (*SFTP, error) {
debug.Log("config %v", cfg) debug.Log("config %#v", cfg)
return Create(cfg.Dir, "ssh", buildSSHCommand(cfg)...) if cfg.Command == "" {
return create(cfg.Dir, "ssh", buildSSHCommand(cfg)...)
}
cmd, args, err := SplitShellArgs(cfg.Command)
if err != nil {
return nil, err
}
return create(cfg.Dir, cmd, args...)
} }
// Location returns this backend's location (the directory name). // Location returns this backend's location (the directory name).

View File

@ -1,6 +1,7 @@
package sftp_test package sftp_test
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
@ -50,7 +51,9 @@ func init() {
return return
} }
args := []string{"-e"} cfg := sftp.Config{
Command: fmt.Sprintf("%q -e", sftpserver),
}
test.CreateFn = func() (restic.Backend, error) { test.CreateFn = func() (restic.Backend, error) {
err := createTempdir() err := createTempdir()
@ -58,7 +61,9 @@ func init() {
return nil, err return nil, err
} }
return sftp.Create(tempBackendDir, sftpserver, args...) cfg.Dir = tempBackendDir
return sftp.CreateWithConfig(cfg)
} }
test.OpenFn = func() (restic.Backend, error) { test.OpenFn = func() (restic.Backend, error) {
@ -66,7 +71,10 @@ func init() {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return sftp.Open(tempBackendDir, sftpserver, args...)
cfg.Dir = tempBackendDir
return sftp.OpenWithConfig(cfg)
} }
test.CleanupFn = func() error { test.CleanupFn = func() error {

View File

@ -43,7 +43,7 @@ func (s *shellSplitter) isSplitChar(c rune) bool {
} }
// SplitShellArgs returns the list of arguments from a shell command string. // SplitShellArgs returns the list of arguments from a shell command string.
func SplitShellArgs(data string) (list []string, err error) { func SplitShellArgs(data string) (cmd string, args []string, err error) {
s := &shellSplitter{} s := &shellSplitter{}
// derived from strings.SplitFunc // derived from strings.SplitFunc
@ -51,7 +51,7 @@ func SplitShellArgs(data string) (list []string, err error) {
for i, rune := range data { for i, rune := range data {
if s.isSplitChar(rune) { if s.isSplitChar(rune) {
if fieldStart >= 0 { if fieldStart >= 0 {
list = append(list, data[fieldStart:i]) args = append(args, data[fieldStart:i])
fieldStart = -1 fieldStart = -1
} }
} else if fieldStart == -1 { } else if fieldStart == -1 {
@ -59,15 +59,21 @@ func SplitShellArgs(data string) (list []string, err error) {
} }
} }
if fieldStart >= 0 { // Last field might end at EOF. if fieldStart >= 0 { // Last field might end at EOF.
list = append(list, data[fieldStart:]) args = append(args, data[fieldStart:])
} }
switch s.quote { switch s.quote {
case '\'': case '\'':
return nil, errors.New("single-quoted string not terminated") return "", nil, errors.New("single-quoted string not terminated")
case '"': case '"':
return nil, errors.New("double-quoted string not terminated") return "", nil, errors.New("double-quoted string not terminated")
} }
return list, nil if len(args) == 0 {
return "", nil, errors.New("command string is empty")
}
cmd, args = args[0], args[1:]
return cmd, args, nil
} }

View File

@ -8,56 +8,62 @@ import (
func TestShellSplitter(t *testing.T) { func TestShellSplitter(t *testing.T) {
var tests = []struct { var tests = []struct {
data string data string
want []string cmd string
args []string
}{ }{
{ {
`foo`, `foo`,
[]string{"foo"}, "foo", []string{},
}, },
{ {
`'foo'`, `'foo'`,
[]string{"foo"}, "foo", []string{},
}, },
{ {
`foo bar baz`, `foo bar baz`,
[]string{"foo", "bar", "baz"}, "foo", []string{"bar", "baz"},
}, },
{ {
`foo 'bar' baz`, `foo 'bar' baz`,
[]string{"foo", "bar", "baz"}, "foo", []string{"bar", "baz"},
}, },
{ {
`foo 'bar box' baz`, `'bar box' baz`,
[]string{"foo", "bar box", "baz"}, "bar box", []string{"baz"},
}, },
{ {
`"bar 'box'" baz`, `"bar 'box'" baz`,
[]string{"bar 'box'", "baz"}, "bar 'box'", []string{"baz"},
}, },
{ {
`'bar "box"' baz`, `'bar "box"' baz`,
[]string{`bar "box"`, "baz"}, `bar "box"`, []string{"baz"},
}, },
{ {
`\"bar box baz`, `\"bar box baz`,
[]string{`"bar`, "box", "baz"}, `"bar`, []string{"box", "baz"},
}, },
{ {
`"bar/foo/x" "box baz"`, `"bar/foo/x" "box baz"`,
[]string{"bar/foo/x", "box baz"}, "bar/foo/x", []string{"box baz"},
}, },
} }
for _, test := range tests { for _, test := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
res, err := SplitShellArgs(test.data) cmd, args, err := SplitShellArgs(test.data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !reflect.DeepEqual(res, test.want) { if cmd != test.cmd {
t.Fatalf("wrong data returned, want:\n %#v\ngot:\n %#v", t.Fatalf("wrong cmd returned, want:\n %#v\ngot:\n %#v",
test.want, res) test.cmd, cmd)
}
if !reflect.DeepEqual(args, test.args) {
t.Fatalf("wrong args returned, want:\n %#v\ngot:\n %#v",
test.args, args)
} }
}) })
} }
@ -88,7 +94,7 @@ func TestShellSplitterInvalid(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
res, err := SplitShellArgs(test.data) cmd, args, err := SplitShellArgs(test.data)
if err == nil { if err == nil {
t.Fatalf("expected error not found: %v", test.err) t.Fatalf("expected error not found: %v", test.err)
} }
@ -97,8 +103,12 @@ func TestShellSplitterInvalid(t *testing.T) {
t.Fatalf("expected error not found, want:\n %q\ngot:\n %q", test.err, err.Error()) t.Fatalf("expected error not found, want:\n %q\ngot:\n %q", test.err, err.Error())
} }
if len(res) > 0 { if cmd != "" {
t.Fatalf("splitter returned fields from invalid data: %v", res) t.Fatalf("splitter returned cmd from invalid data: %v", cmd)
}
if len(args) > 0 {
t.Fatalf("splitter returned fields from invalid data: %v", args)
} }
}) })
} }