2
2
mirror of https://github.com/octoleo/restic.git synced 2025-01-22 14:48:24 +00:00

Update sftp library

This commit is contained in:
Alexander Neumann 2016-08-28 12:15:37 +02:00
parent ed09887d9e
commit a7e64afc0d
19 changed files with 822 additions and 480 deletions

4
vendor/manifest vendored
View File

@ -40,8 +40,8 @@
{
"importpath": "github.com/pkg/sftp",
"repository": "https://github.com/pkg/sftp",
"revision": "e84cc8c755ca39b7b64f510fe1fffc1b51f210a5",
"branch": "HEAD"
"revision": "a71e8f580e3b622ebff585309160b1cc549ef4d2",
"branch": "master"
},
{
"importpath": "github.com/restic/chunker",

View File

@ -17,9 +17,7 @@ The Walker interface for directory traversal is heavily inspired by Keith Rarick
roadmap
-------
* Currently all traffic with the server is serialized, this can be improved by allowing overlapping requests/responses.
* There is way too much duplication in the Client methods. If there was an unmarshal(interface{}) method this would reduce a heap of the duplication.
* Implement integration tests by talking directly to a real opensftp-server process. This shouldn't be too difficult to implement with a small refactoring to the sftp.NewClient method. These tests should be gated on an -sftp.integration test flag. _in progress_
contributing
------------

View File

@ -2,19 +2,15 @@ package sftp
import (
"bytes"
"encoding"
"encoding/binary"
"errors"
"fmt"
"io"
"os"
"path"
"sync"
"sync/atomic"
"time"
"github.com/kr/fs"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
)
@ -22,7 +18,7 @@ import (
func MaxPacket(size int) func(*Client) error {
return func(c *Client) error {
if size < 1<<15 {
return fmt.Errorf("size must be greater or equal to 32k")
return errors.Errorf("size must be greater or equal to 32k")
}
c.maxPacket = size
return nil
@ -56,11 +52,14 @@ func NewClient(conn *ssh.Client, opts ...func(*Client) error) (*Client, error) {
// the system's ssh client program (e.g. via exec.Command).
func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...func(*Client) error) (*Client, error) {
sftp := &Client{
w: wr,
r: rd,
maxPacket: 1 << 15,
inflight: make(map[uint32]chan<- result),
recvClosed: make(chan struct{}),
clientConn: clientConn{
conn: conn{
Reader: rd,
WriteCloser: wr,
},
inflight: make(map[uint32]chan<- result),
},
maxPacket: 1 << 15,
}
if err := sftp.applyOptions(opts...); err != nil {
wr.Close()
@ -74,7 +73,8 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...func(*Client) error)
wr.Close()
return nil, err
}
go sftp.recv()
sftp.clientConn.wg.Add(1)
go sftp.loop()
return sftp, nil
}
@ -84,22 +84,10 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...func(*Client) error)
//
// Client implements the github.com/kr/fs.FileSystem interface.
type Client struct {
w io.WriteCloser
r io.Reader
clientConn
maxPacket int // max packet size read or written.
nextid uint32
mu sync.Mutex // ensures only on request is in flight to the server at once
inflight map[uint32]chan<- result // outstanding requests
recvClosed chan struct{} // remote end has closed the connection
}
// Close closes the SFTP session.
func (c *Client) Close() error {
err := c.w.Close()
<-c.recvClosed
return err
}
// Create creates the named file mode 0666 (before umask), truncating it if
@ -112,7 +100,7 @@ func (c *Client) Create(path string) (*File, error) {
const sftpProtocolVersion = 3 // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02
func (c *Client) sendInit() error {
return sendPacket(c.w, sshFxInitPacket{
return c.clientConn.conn.sendPacket(sshFxInitPacket{
Version: sftpProtocolVersion, // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02
})
}
@ -123,7 +111,7 @@ func (c *Client) nextID() uint32 {
}
func (c *Client) recvVersion() error {
typ, data, err := recvPacket(c.r)
typ, data, err := c.recvPacket()
if err != nil {
return err
}
@ -139,46 +127,6 @@ func (c *Client) recvVersion() error {
return nil
}
// broadcastErr sends an error to all goroutines waiting for a response.
func (c *Client) broadcastErr(err error) {
c.mu.Lock()
listeners := make([]chan<- result, 0, len(c.inflight))
for _, ch := range c.inflight {
listeners = append(listeners, ch)
}
c.mu.Unlock()
for _, ch := range listeners {
ch <- result{err: err}
}
}
// recv continuously reads from the server and forwards responses to the
// appropriate channel.
func (c *Client) recv() {
defer close(c.recvClosed)
for {
typ, data, err := recvPacket(c.r)
if err != nil {
// Return the error to all listeners.
c.broadcastErr(err)
return
}
sid, _ := unmarshalUint32(data)
c.mu.Lock()
ch, ok := c.inflight[sid]
delete(c.inflight, sid)
c.mu.Unlock()
if !ok {
// This is an unexpected occurrence. Send the error
// back to all listeners so that they terminate
// gracefully.
c.broadcastErr(fmt.Errorf("sid: %v not fond", sid))
return
}
ch <- result{typ: typ, data: data}
}
}
// Walk returns a new Walker rooted at root.
func (c *Client) Walk(root string) *fs.Walker {
return fs.WalkFS(root, c)
@ -196,7 +144,7 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
var done = false
for !done {
id := c.nextID()
typ, data, err1 := c.sendRequest(sshFxpReaddirPacket{
typ, data, err1 := c.sendPacket(sshFxpReaddirPacket{
ID: id,
Handle: handle,
})
@ -239,7 +187,7 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
func (c *Client) opendir(path string) (string, error) {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpOpendirPacket{
typ, data, err := c.sendPacket(sshFxpOpendirPacket{
ID: id,
Path: path,
})
@ -265,7 +213,7 @@ func (c *Client) opendir(path string) (string, error) {
// If 'p' is a symbolic link, the returned FileInfo structure describes the referent file.
func (c *Client) Stat(p string) (os.FileInfo, error) {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpStatPacket{
typ, data, err := c.sendPacket(sshFxpStatPacket{
ID: id,
Path: p,
})
@ -291,7 +239,7 @@ func (c *Client) Stat(p string) (os.FileInfo, error) {
// If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link.
func (c *Client) Lstat(p string) (os.FileInfo, error) {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpLstatPacket{
typ, data, err := c.sendPacket(sshFxpLstatPacket{
ID: id,
Path: p,
})
@ -316,7 +264,7 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) {
// ReadLink reads the target of a symbolic link.
func (c *Client) ReadLink(p string) (string, error) {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpReadlinkPacket{
typ, data, err := c.sendPacket(sshFxpReadlinkPacket{
ID: id,
Path: p,
})
@ -345,7 +293,7 @@ func (c *Client) ReadLink(p string) (string, error) {
// Symlink creates a symbolic link at 'newname', pointing at target 'oldname'
func (c *Client) Symlink(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpSymlinkPacket{
typ, data, err := c.sendPacket(sshFxpSymlinkPacket{
ID: id,
Linkpath: newname,
Targetpath: oldname,
@ -364,7 +312,7 @@ func (c *Client) Symlink(oldname, newname string) error {
// setstat is a convience wrapper to allow for changing of various parts of the file descriptor.
func (c *Client) setstat(path string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpSetstatPacket{
typ, data, err := c.sendPacket(sshFxpSetstatPacket{
ID: id,
Path: path,
Flags: flags,
@ -430,7 +378,7 @@ func (c *Client) OpenFile(path string, f int) (*File, error) {
func (c *Client) open(path string, pflags uint32) (*File, error) {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpOpenPacket{
typ, data, err := c.sendPacket(sshFxpOpenPacket{
ID: id,
Path: path,
Pflags: pflags,
@ -458,7 +406,7 @@ func (c *Client) open(path string, pflags uint32) (*File, error) {
// immediately after this request has been sent.
func (c *Client) close(handle string) error {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpClosePacket{
typ, data, err := c.sendPacket(sshFxpClosePacket{
ID: id,
Handle: handle,
})
@ -475,7 +423,7 @@ func (c *Client) close(handle string) error {
func (c *Client) fstat(handle string) (*FileStat, error) {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpFstatPacket{
typ, data, err := c.sendPacket(sshFxpFstatPacket{
ID: id,
Handle: handle,
})
@ -504,7 +452,7 @@ func (c *Client) fstat(handle string) (*FileStat, error) {
func (c *Client) StatVFS(path string) (*StatVFS, error) {
// send the StatVFS packet to the server
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpStatvfsPacket{
typ, data, err := c.sendPacket(sshFxpStatvfsPacket{
ID: id,
Path: path,
})
@ -560,7 +508,7 @@ func (c *Client) Remove(path string) error {
func (c *Client) removeFile(path string) error {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpRemovePacket{
typ, data, err := c.sendPacket(sshFxpRemovePacket{
ID: id,
Filename: path,
})
@ -577,7 +525,7 @@ func (c *Client) removeFile(path string) error {
func (c *Client) removeDirectory(path string) error {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpRmdirPacket{
typ, data, err := c.sendPacket(sshFxpRmdirPacket{
ID: id,
Path: path,
})
@ -595,7 +543,7 @@ func (c *Client) removeDirectory(path string) error {
// Rename renames a file.
func (c *Client) Rename(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpRenamePacket{
typ, data, err := c.sendPacket(sshFxpRenamePacket{
ID: id,
Oldpath: oldname,
Newpath: newname,
@ -613,7 +561,7 @@ func (c *Client) Rename(oldname, newname string) error {
func (c *Client) realpath(path string) (string, error) {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpRealpathPacket{
typ, data, err := c.sendPacket(sshFxpRealpathPacket{
ID: id,
Path: path,
})
@ -645,41 +593,12 @@ func (c *Client) Getwd() (string, error) {
return c.realpath(".")
}
// result captures the result of receiving the a packet from the server
type result struct {
typ byte
data []byte
err error
}
type idmarshaler interface {
id() uint32
encoding.BinaryMarshaler
}
func (c *Client) sendRequest(p idmarshaler) (byte, []byte, error) {
ch := make(chan result, 1)
c.dispatchRequest(ch, p)
s := <-ch
return s.typ, s.data, s.err
}
func (c *Client) dispatchRequest(ch chan<- result, p idmarshaler) {
c.mu.Lock()
c.inflight[p.id()] = ch
if err := sendPacket(c.w, p); err != nil {
delete(c.inflight, p.id())
ch <- result{err: err}
}
c.mu.Unlock()
}
// Mkdir creates the specified directory. An error will be returned if a file or
// directory with the specified path already exists, or if the directory's
// parent folder does not exist (the method cannot create complete paths).
func (c *Client) Mkdir(path string) error {
id := c.nextID()
typ, data, err := c.sendRequest(sshFxpMkdirPacket{
typ, data, err := c.sendPacket(sshFxpMkdirPacket{
ID: id,
Path: path,
})
@ -784,7 +703,7 @@ func (f *File) Read(b []byte) (int, error) {
reqID, data := unmarshalUint32(res.data)
req, ok := reqs[reqID]
if !ok {
firstErr = offsetErr{offset: 0, err: fmt.Errorf("sid: %v not found", reqID)}
firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)}
break
}
delete(reqs, reqID)
@ -885,7 +804,7 @@ func (f *File) WriteTo(w io.Writer) (int64, error) {
reqID, data := unmarshalUint32(res.data)
req, ok := reqs[reqID]
if !ok {
firstErr = offsetErr{offset: 0, err: fmt.Errorf("sid: %v not found", reqID)}
firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)}
break
}
delete(reqs, reqID)
@ -1166,7 +1085,10 @@ func unmarshalStatus(id uint32, data []byte) error {
return &unexpectedIDErr{id, sid}
}
code, data := unmarshalUint32(data)
msg, data := unmarshalString(data)
msg, data, err := unmarshalStringSafe(data)
if err != nil {
return err
}
lang, _, _ := unmarshalStringSafe(data)
return &StatusError{
Code: code,

View File

@ -39,8 +39,4 @@ func TestClientStatVFS(t *testing.T) {
if vfs.Favail != uint64(s.Ffree) {
t.Fatal("f_namemax does not match")
}
if vfs.Bavail != s.Bavail {
t.Fatal("f_bavail does not match")
}
}

View File

@ -29,18 +29,14 @@ func TestClientStatVFS(t *testing.T) {
// check some stats
if vfs.Frsize != uint64(s.Frsize) {
t.Fatal("fr_size does not match")
t.Fatal("fr_size does not match, expected: %v, got: %v", s.Frsize, vfs.Frsize)
}
if vfs.Bsize != uint64(s.Bsize) {
t.Fatal("f_bsize does not match")
t.Fatal("f_bsize does not match, expected: %v, got: %v", s.Bsize, vfs.Bsize)
}
if vfs.Namemax != uint64(s.Namelen) {
t.Fatal("f_namemax does not match")
}
if vfs.Bavail != s.Bavail {
t.Fatal("f_bavail does not match")
t.Fatal("f_namemax does not match, expected: %v, got: %v", s.Namelen, vfs.Namemax)
}
}

View File

@ -9,6 +9,7 @@ import (
"io"
"io/ioutil"
"math/rand"
"net"
"os"
"os/exec"
"os/user"
@ -84,37 +85,61 @@ func (w delayedWriter) Close() error {
return nil
}
// netPipe provides a pair of io.ReadWriteClosers connected to each other.
// The functions is identical to os.Pipe with the exception that netPipe
// provides the Read/Close guarentees that os.File derrived pipes do not.
func netPipe(t testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) {
type result struct {
net.Conn
error
}
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
ch := make(chan result, 1)
go func() {
conn, err := l.Accept()
ch <- result{conn, err}
err = l.Close()
if err != nil {
t.Error(err)
}
}()
c1, err := net.Dial("tcp", l.Addr().String())
if err != nil {
l.Close() // might cause another in the listening goroutine, but too bad
t.Fatal(err)
}
r := <-ch
if r.error != nil {
t.Fatal(err)
}
return c1, r.Conn
}
func testClientGoSvr(t testing.TB, readonly bool, delay time.Duration) (*Client, *exec.Cmd) {
txPipeRd, txPipeWr, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
rxPipeRd, rxPipeWr, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
c1, c2 := netPipe(t)
options := []ServerOption{WithDebug(os.Stderr)}
if readonly {
options = append(options, ReadOnly())
}
server, err := NewServer(
txPipeRd,
rxPipeWr,
options...,
)
server, err := NewServer(c1, options...)
if err != nil {
t.Fatal(err)
}
go server.Serve()
var ctx io.WriteCloser = txPipeWr
var ctx io.WriteCloser = c2
if delay > NO_DELAY {
ctx = newDelayedWriter(ctx, delay)
}
client, err := NewClientPipe(rxPipeRd, ctx)
client, err := NewClientPipe(c2, ctx)
if err != nil {
t.Fatal(err)
}
@ -465,6 +490,66 @@ func TestClientFileStat(t *testing.T) {
}
}
func TestClientStatLink(t *testing.T) {
sftp, cmd := testClient(t, READONLY, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
f, err := ioutil.TempFile("", "sftptest")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
realName := f.Name()
linkName := f.Name() + ".softlink"
// create a symlink that points at sftptest
if err := os.Symlink(realName, linkName); err != nil {
t.Fatal(err)
}
defer os.Remove(linkName)
// compare Lstat of links
wantLstat, err := os.Lstat(linkName)
if err != nil {
t.Fatal(err)
}
wantStat, err := os.Stat(linkName)
if err != nil {
t.Fatal(err)
}
gotLstat, err := sftp.Lstat(linkName)
if err != nil {
t.Fatal(err)
}
gotStat, err := sftp.Stat(linkName)
if err != nil {
t.Fatal(err)
}
// check that stat is not lstat from os package
if sameFile(wantLstat, wantStat) {
t.Fatalf("Lstat / Stat(%q): both %#v %#v", f.Name(), wantLstat, wantStat)
}
// compare Lstat of links
if !sameFile(wantLstat, gotLstat) {
t.Fatalf("Lstat(%q): want %#v, got %#v", f.Name(), wantLstat, gotLstat)
}
// compare Stat of links
if !sameFile(wantStat, gotStat) {
t.Fatalf("Stat(%q): want %#v, got %#v", f.Name(), wantStat, gotStat)
}
// check that stat is not lstat
if sameFile(gotLstat, gotStat) {
t.Fatalf("Lstat / Stat(%q): both %#v %#v", f.Name(), gotLstat, gotStat)
}
}
func TestClientRemove(t *testing.T) {
sftp, cmd := testClient(t, READWRITE, NO_DELAY)
defer cmd.Wait()
@ -1076,10 +1161,6 @@ func TestClientWrite(t *testing.T) {
// taken from github.com/kr/fs/walk_test.go
type PathTest struct {
path, result string
}
type Node struct {
name string
entries []*Node // nil if the entry is a file

View File

@ -4,6 +4,7 @@ import (
"errors"
"io"
"os"
"reflect"
"testing"
"github.com/kr/fs"
@ -87,13 +88,58 @@ func TestFlags(t *testing.T) {
}
}
func TestMissingLangTag(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Fail()
func TestUnmarshalStatus(t *testing.T) {
requestID := uint32(1)
id := marshalUint32([]byte{}, requestID)
idCode := marshalUint32(id, ssh_FX_FAILURE)
idCodeMsg := marshalString(idCode, "err msg")
idCodeMsgLang := marshalString(idCodeMsg, "lang tag")
var tests = []struct {
desc string
reqID uint32
status []byte
want error
}{
{
desc: "well-formed status",
reqID: 1,
status: idCodeMsgLang,
want: &StatusError{
Code: ssh_FX_FAILURE,
msg: "err msg",
lang: "lang tag",
},
},
{
desc: "missing error message and language tag",
reqID: 1,
status: idCode,
want: errShortPacket,
},
{
desc: "missing language tag",
reqID: 1,
status: idCodeMsg,
want: &StatusError{
Code: ssh_FX_FAILURE,
msg: "err msg",
},
},
{
desc: "request identifier mismatch",
reqID: 2,
status: idCodeMsgLang,
want: &unexpectedIDErr{2, requestID},
},
}
for _, tt := range tests {
got := unmarshalStatus(tt.reqID, tt.status)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("unmarshalStatus(%v, %v), test %q\n- want: %#v\n- got: %#v",
requestID, tt.status, tt.desc, tt.want, got)
}
}()
buf := marshalUint32([]byte{}, 0)
buf = marshalStatus(buf, StatusError{})
_ = unmarshalStatus(0, buf[:len(buf)-4])
}
}

122
vendor/src/github.com/pkg/sftp/conn.go vendored Normal file
View File

@ -0,0 +1,122 @@
package sftp
import (
"encoding"
"io"
"sync"
"github.com/pkg/errors"
)
// conn implements a bidirectional channel on which client and server
// connections are multiplexed.
type conn struct {
io.Reader
io.WriteCloser
sync.Mutex // used to serialise writes to sendPacket
}
func (c *conn) recvPacket() (uint8, []byte, error) {
return recvPacket(c)
}
func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
c.Lock()
defer c.Unlock()
return sendPacket(c, m)
}
type clientConn struct {
conn
wg sync.WaitGroup
sync.Mutex // protects inflight
inflight map[uint32]chan<- result // outstanding requests
}
// Close closes the SFTP session.
func (c *clientConn) Close() error {
defer c.wg.Wait()
return c.conn.Close()
}
func (c *clientConn) loop() {
defer c.wg.Done()
err := c.recv()
if err != nil {
c.broadcastErr(err)
}
}
// recv continuously reads from the server and forwards responses to the
// appropriate channel.
func (c *clientConn) recv() error {
defer c.conn.Close()
for {
typ, data, err := c.recvPacket()
if err != nil {
return err
}
sid, _ := unmarshalUint32(data)
c.Lock()
ch, ok := c.inflight[sid]
delete(c.inflight, sid)
c.Unlock()
if !ok {
// This is an unexpected occurrence. Send the error
// back to all listeners so that they terminate
// gracefully.
return errors.Errorf("sid: %v not fond", sid)
}
ch <- result{typ: typ, data: data}
}
}
// result captures the result of receiving the a packet from the server
type result struct {
typ byte
data []byte
err error
}
type idmarshaler interface {
id() uint32
encoding.BinaryMarshaler
}
func (c *clientConn) sendPacket(p idmarshaler) (byte, []byte, error) {
ch := make(chan result, 1)
c.dispatchRequest(ch, p)
s := <-ch
return s.typ, s.data, s.err
}
func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) {
c.Lock()
c.inflight[p.id()] = ch
if err := c.conn.sendPacket(p); err != nil {
delete(c.inflight, p.id())
ch <- result{err: err}
}
c.Unlock()
}
// broadcastErr sends an error to all goroutines waiting for a response.
func (c *clientConn) broadcastErr(err error) {
c.Lock()
listeners := make([]chan<- result, 0, len(c.inflight))
for _, ch := range c.inflight {
listeners = append(listeners, ch)
}
c.Unlock()
for _, ch := range listeners {
ch <- result{err: err}
}
}
type serverConn struct {
conn
}
func (s *serverConn) sendError(p id, err error) error {
return s.sendPacket(statusFromError(p, err))
}

View File

@ -118,11 +118,20 @@ func main() {
}
}(requests)
serverOptions := []sftp.ServerOption{
sftp.WithDebug(debugStream),
}
if readOnly {
serverOptions = append(serverOptions, sftp.ReadOnly())
fmt.Fprintf(debugStream, "Read-only server\n")
} else {
fmt.Fprintf(debugStream, "Read write server\n")
}
server, err := sftp.NewServer(
channel,
channel,
sftp.WithDebug(debugStream),
sftp.ReadOnly(),
serverOptions...,
)
if err != nil {
log.Fatal(err)

View File

@ -1,16 +1,20 @@
package sftp
import (
"bytes"
"encoding"
"errors"
"encoding/binary"
"fmt"
"io"
"os"
"reflect"
"github.com/pkg/errors"
)
var (
errShortPacket = errors.New("packet too short")
errShortPacket = errors.New("packet too short")
errUnknownExtendedPacket = errors.New("unknown extended packet")
)
const (
@ -114,7 +118,7 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) {
func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
bb, err := m.MarshalBinary()
if err != nil {
return fmt.Errorf("marshal2(%#v): binary marshaller failed", err)
return errors.Errorf("binary marshaller failed: %v", err)
}
if debugDumpTxPacketBytes {
debug("send packet: %s %d bytes %x", fxp(bb[0]), len(bb), bb[1:])
@ -125,17 +129,13 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
hdr := []byte{byte(l >> 24), byte(l >> 16), byte(l >> 8), byte(l)}
_, err = w.Write(hdr)
if err != nil {
return err
return errors.Errorf("failed to send packet header: %v", err)
}
_, err = w.Write(bb)
return err
}
func (svr *Server) sendPacket(m encoding.BinaryMarshaler) error {
// any responder can call sendPacket(); actual socket access must be serialized
svr.outMutex.Lock()
defer svr.outMutex.Unlock()
return sendPacket(svr.out, m)
if err != nil {
return errors.Errorf("failed to send packet body: %v", err)
}
return nil
}
func recvPacket(r io.Reader) (uint8, []byte, error) {
@ -259,10 +259,7 @@ func unmarshalIDString(b []byte, id *uint32, str *string) error {
return err
}
*str, b, err = unmarshalStringSafe(b)
if err != nil {
return err
}
return nil
return err
}
type sshFxpReaddirPacket struct {
@ -318,7 +315,7 @@ type sshFxpStatPacket struct {
func (p sshFxpStatPacket) id() uint32 { return p.ID }
func (p sshFxpStatPacket) MarshalBinary() ([]byte, error) {
return marshalIDString(ssh_FXP_LSTAT, p.ID, p.Path)
return marshalIDString(ssh_FXP_STAT, p.ID, p.Path)
}
func (p *sshFxpStatPacket) UnmarshalBinary(b []byte) error {
@ -838,3 +835,67 @@ func (p *StatVFS) TotalSpace() uint64 {
func (p *StatVFS) FreeSpace() uint64 {
return p.Frsize * p.Bfree
}
// Convert to ssh_FXP_EXTENDED_REPLY packet binary format
func (p *StatVFS) MarshalBinary() ([]byte, error) {
var buf bytes.Buffer
buf.Write([]byte{ssh_FXP_EXTENDED_REPLY})
err := binary.Write(&buf, binary.BigEndian, p)
return buf.Bytes(), err
}
type sshFxpExtendedPacket struct {
ID uint32
ExtendedRequest string
SpecificPacket interface {
serverRespondablePacket
readonly() bool
}
}
func (p sshFxpExtendedPacket) id() uint32 { return p.ID }
func (p sshFxpExtendedPacket) readonly() bool { return p.SpecificPacket.readonly() }
func (p sshFxpExtendedPacket) respond(svr *Server) error {
return p.SpecificPacket.respond(svr)
}
func (p *sshFxpExtendedPacket) UnmarshalBinary(b []byte) error {
var err error
bOrig := b
if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
return err
} else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil {
return err
}
// specific unmarshalling
switch p.ExtendedRequest {
case "statvfs@openssh.com":
p.SpecificPacket = &sshFxpExtendedPacketStatVFS{}
default:
return errUnknownExtendedPacket
}
return p.SpecificPacket.UnmarshalBinary(bOrig)
}
type sshFxpExtendedPacketStatVFS struct {
ID uint32
ExtendedRequest string
Path string
}
func (p sshFxpExtendedPacketStatVFS) id() uint32 { return p.ID }
func (p sshFxpExtendedPacketStatVFS) readonly() bool { return true }
func (p *sshFxpExtendedPacketStatVFS) UnmarshalBinary(b []byte) error {
var err error
if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
return err
} else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil {
return err
} else if p.Path, b, err = unmarshalStringSafe(b); err != nil {
return err
}
return nil
}

View File

@ -13,6 +13,8 @@ import (
"sync"
"syscall"
"time"
"github.com/pkg/errors"
)
const (
@ -24,18 +26,14 @@ const (
// This implementation currently supports most of sftp server protocol version 3,
// as specified at http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02
type Server struct {
in io.Reader
out io.WriteCloser
outMutex *sync.Mutex
serverConn
debugStream io.Writer
readOnly bool
lastID uint32
pktChan chan rxPacket
openFiles map[string]*os.File
openFilesLock *sync.RWMutex
openFilesLock sync.RWMutex
handleCount int
maxTxPacket uint32
workerCount int
}
func (svr *Server) nextHandle(f *os.File) string {
@ -69,7 +67,6 @@ type serverRespondablePacket interface {
encoding.BinaryUnmarshaler
id() uint32
respond(svr *Server) error
readonly() bool
}
// NewServer creates a new Server instance around the provided streams, serving
@ -77,17 +74,18 @@ type serverRespondablePacket interface {
// functions may be specified to further configure the Server.
//
// A subsequent call to Serve() is required to begin serving files over SFTP.
func NewServer(in io.Reader, out io.WriteCloser, options ...ServerOption) (*Server, error) {
func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error) {
s := &Server{
in: in,
out: out,
outMutex: &sync.Mutex{},
debugStream: ioutil.Discard,
pktChan: make(chan rxPacket, sftpServerWorkerCount),
openFiles: map[string]*os.File{},
openFilesLock: &sync.RWMutex{},
maxTxPacket: 1 << 15,
workerCount: sftpServerWorkerCount,
serverConn: serverConn{
conn: conn{
Reader: rwc,
WriteCloser: rwc,
},
},
debugStream: ioutil.Discard,
pktChan: make(chan rxPacket, sftpServerWorkerCount),
openFiles: make(map[string]*os.File),
maxTxPacket: 1 << 15,
}
for _, o := range options {
@ -123,123 +121,261 @@ type rxPacket struct {
pktBytes []byte
}
// Unmarshal a single logical packet from the secure channel
func (svr *Server) rxPackets() error {
defer close(svr.pktChan)
for {
pktType, pktBytes, err := recvPacket(svr.in)
switch err {
case nil:
svr.pktChan <- rxPacket{fxp(pktType), pktBytes}
case io.EOF:
return nil
// Up to N parallel servers
func (svr *Server) sftpServerWorker() error {
for p := range svr.pktChan {
var pkt interface {
encoding.BinaryUnmarshaler
id() uint32
}
var readonly = true
switch p.pktType {
case ssh_FXP_INIT:
pkt = &sshFxInitPacket{}
case ssh_FXP_LSTAT:
pkt = &sshFxpLstatPacket{}
case ssh_FXP_OPEN:
pkt = &sshFxpOpenPacket{}
// readonly handled specially below
case ssh_FXP_CLOSE:
pkt = &sshFxpClosePacket{}
case ssh_FXP_READ:
pkt = &sshFxpReadPacket{}
case ssh_FXP_WRITE:
pkt = &sshFxpWritePacket{}
readonly = false
case ssh_FXP_FSTAT:
pkt = &sshFxpFstatPacket{}
case ssh_FXP_SETSTAT:
pkt = &sshFxpSetstatPacket{}
readonly = false
case ssh_FXP_FSETSTAT:
pkt = &sshFxpFsetstatPacket{}
readonly = false
case ssh_FXP_OPENDIR:
pkt = &sshFxpOpendirPacket{}
case ssh_FXP_READDIR:
pkt = &sshFxpReaddirPacket{}
case ssh_FXP_REMOVE:
pkt = &sshFxpRemovePacket{}
readonly = false
case ssh_FXP_MKDIR:
pkt = &sshFxpMkdirPacket{}
readonly = false
case ssh_FXP_RMDIR:
pkt = &sshFxpRmdirPacket{}
readonly = false
case ssh_FXP_REALPATH:
pkt = &sshFxpRealpathPacket{}
case ssh_FXP_STAT:
pkt = &sshFxpStatPacket{}
case ssh_FXP_RENAME:
pkt = &sshFxpRenamePacket{}
readonly = false
case ssh_FXP_READLINK:
pkt = &sshFxpReadlinkPacket{}
case ssh_FXP_SYMLINK:
pkt = &sshFxpSymlinkPacket{}
readonly = false
case ssh_FXP_EXTENDED:
pkt = &sshFxpExtendedPacket{}
default:
fmt.Fprintf(svr.debugStream, "recvPacket error: %v\n", err)
return errors.Errorf("unhandled packet type: %s", p.pktType)
}
if err := pkt.UnmarshalBinary(p.pktBytes); err != nil {
return err
}
}
}
// Up to N parallel servers
func (svr *Server) sftpServerWorker(doneChan chan error) {
for pkt := range svr.pktChan {
dPkt, err := svr.decodePacket(pkt.pktType, pkt.pktBytes)
if err != nil {
fmt.Fprintf(svr.debugStream, "decodePacket error: %v\n", err)
doneChan <- err
return
// handle FXP_OPENDIR specially
switch pkt := pkt.(type) {
case *sshFxpOpenPacket:
readonly = pkt.readonly()
case *sshFxpExtendedPacket:
readonly = pkt.SpecificPacket.readonly()
}
// If server is operating read-only and a write operation is requested,
// return permission denied
if !dPkt.readonly() && svr.readOnly {
_ = svr.sendPacket(statusFromError(dPkt.id(), syscall.EPERM))
if !readonly && svr.readOnly {
if err := svr.sendError(pkt, syscall.EPERM); err != nil {
return errors.Wrap(err, "failed to send read only packet response")
}
continue
}
_ = dPkt.respond(svr)
if err := handlePacket(svr, pkt); err != nil {
return err
}
}
return nil
}
func handlePacket(s *Server, p interface{}) error {
switch p := p.(type) {
case *sshFxInitPacket:
return s.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil})
case *sshFxpStatPacket:
// stat the requested file
info, err := os.Stat(p.Path)
if err != nil {
return s.sendError(p, err)
}
return s.sendPacket(sshFxpStatResponse{
ID: p.ID,
info: info,
})
case *sshFxpLstatPacket:
// stat the requested file
info, err := os.Lstat(p.Path)
if err != nil {
return s.sendError(p, err)
}
return s.sendPacket(sshFxpStatResponse{
ID: p.ID,
info: info,
})
case *sshFxpFstatPacket:
f, ok := s.getHandle(p.Handle)
if !ok {
return s.sendError(p, syscall.EBADF)
}
info, err := f.Stat()
if err != nil {
return s.sendError(p, err)
}
return s.sendPacket(sshFxpStatResponse{
ID: p.ID,
info: info,
})
case *sshFxpMkdirPacket:
// TODO FIXME: ignore flags field
err := os.Mkdir(p.Path, 0755)
return s.sendError(p, err)
case *sshFxpRmdirPacket:
err := os.Remove(p.Path)
return s.sendError(p, err)
case *sshFxpRemovePacket:
err := os.Remove(p.Filename)
return s.sendError(p, err)
case *sshFxpRenamePacket:
err := os.Rename(p.Oldpath, p.Newpath)
return s.sendError(p, err)
case *sshFxpSymlinkPacket:
err := os.Symlink(p.Targetpath, p.Linkpath)
return s.sendError(p, err)
case *sshFxpClosePacket:
return s.sendError(p, s.closeHandle(p.Handle))
case *sshFxpReadlinkPacket:
f, err := os.Readlink(p.Path)
if err != nil {
return s.sendError(p, err)
}
return s.sendPacket(sshFxpNamePacket{
ID: p.ID,
NameAttrs: []sshFxpNameAttr{{
Name: f,
LongName: f,
Attrs: emptyFileStat,
}},
})
case *sshFxpRealpathPacket:
f, err := filepath.Abs(p.Path)
if err != nil {
return s.sendError(p, err)
}
f = filepath.Clean(f)
return s.sendPacket(sshFxpNamePacket{
ID: p.ID,
NameAttrs: []sshFxpNameAttr{{
Name: f,
LongName: f,
Attrs: emptyFileStat,
}},
})
case *sshFxpOpendirPacket:
return sshFxpOpenPacket{
ID: p.ID,
Path: p.Path,
Pflags: ssh_FXF_READ,
}.respond(s)
case *sshFxpReadPacket:
f, ok := s.getHandle(p.Handle)
if !ok {
return s.sendError(p, syscall.EBADF)
}
data := make([]byte, clamp(p.Len, s.maxTxPacket))
n, err := f.ReadAt(data, int64(p.Offset))
if err != nil && (err != io.EOF || n == 0) {
return s.sendError(p, err)
}
return s.sendPacket(sshFxpDataPacket{
ID: p.ID,
Length: uint32(n),
Data: data[:n],
})
case *sshFxpWritePacket:
f, ok := s.getHandle(p.Handle)
if !ok {
return s.sendError(p, syscall.EBADF)
}
_, err := f.WriteAt(p.Data, int64(p.Offset))
return s.sendError(p, err)
case serverRespondablePacket:
err := p.respond(s)
return errors.Wrap(err, "pkt.respond failed")
default:
return errors.Errorf("unexpected packet type %T", p)
}
doneChan <- nil
}
// Serve serves SFTP connections until the streams stop or the SFTP subsystem
// is stopped.
func (svr *Server) Serve() error {
go svr.rxPackets()
doneChan := make(chan error)
for i := 0; i < svr.workerCount; i++ {
go svr.sftpServerWorker(doneChan)
var wg sync.WaitGroup
wg.Add(sftpServerWorkerCount)
for i := 0; i < sftpServerWorkerCount; i++ {
go func() {
defer wg.Done()
if err := svr.sftpServerWorker(); err != nil {
svr.conn.Close() // shuts down recvPacket
}
}()
}
for i := 0; i < svr.workerCount; i++ {
if err := <-doneChan; err != nil {
// abort early and shut down the session on un-decodable packets
var err error
var pktType uint8
var pktBytes []byte
for {
pktType, pktBytes, err = svr.recvPacket()
if err != nil {
break
}
svr.pktChan <- rxPacket{fxp(pktType), pktBytes}
}
close(svr.pktChan) // shuts down sftpServerWorkers
wg.Wait() // wait for all workers to exit
// close any still-open files
for handle, file := range svr.openFiles {
fmt.Fprintf(svr.debugStream, "sftp server file with handle '%v' left open: %v\n", handle, file.Name())
fmt.Fprintf(svr.debugStream, "sftp server file with handle %q left open: %v\n", handle, file.Name())
file.Close()
}
return svr.out.Close()
return err // error from recvPacket
}
func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondablePacket, error) {
var pkt serverRespondablePacket
switch pktType {
case ssh_FXP_INIT:
pkt = &sshFxInitPacket{}
case ssh_FXP_LSTAT:
pkt = &sshFxpLstatPacket{}
case ssh_FXP_OPEN:
pkt = &sshFxpOpenPacket{}
case ssh_FXP_CLOSE:
pkt = &sshFxpClosePacket{}
case ssh_FXP_READ:
pkt = &sshFxpReadPacket{}
case ssh_FXP_WRITE:
pkt = &sshFxpWritePacket{}
case ssh_FXP_FSTAT:
pkt = &sshFxpFstatPacket{}
case ssh_FXP_SETSTAT:
pkt = &sshFxpSetstatPacket{}
case ssh_FXP_FSETSTAT:
pkt = &sshFxpFsetstatPacket{}
case ssh_FXP_OPENDIR:
pkt = &sshFxpOpendirPacket{}
case ssh_FXP_READDIR:
pkt = &sshFxpReaddirPacket{}
case ssh_FXP_REMOVE:
pkt = &sshFxpRemovePacket{}
case ssh_FXP_MKDIR:
pkt = &sshFxpMkdirPacket{}
case ssh_FXP_RMDIR:
pkt = &sshFxpRmdirPacket{}
case ssh_FXP_REALPATH:
pkt = &sshFxpRealpathPacket{}
case ssh_FXP_STAT:
pkt = &sshFxpStatPacket{}
case ssh_FXP_RENAME:
pkt = &sshFxpRenamePacket{}
case ssh_FXP_READLINK:
pkt = &sshFxpReadlinkPacket{}
case ssh_FXP_SYMLINK:
pkt = &sshFxpSymlinkPacket{}
default:
return nil, fmt.Errorf("unhandled packet type: %s", pktType)
}
err := pkt.UnmarshalBinary(pktBytes)
return pkt, err
}
func (p sshFxInitPacket) respond(svr *Server) error {
return svr.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil})
type id interface {
id() uint32
}
// The init packet has no ID, so we just return a zero-value ID
func (p sshFxInitPacket) id() uint32 { return 0 }
func (p sshFxInitPacket) readonly() bool { return true }
func (p sshFxInitPacket) id() uint32 { return 0 }
type sshFxpStatResponse struct {
ID uint32
@ -253,141 +389,8 @@ func (p sshFxpStatResponse) MarshalBinary() ([]byte, error) {
return b, nil
}
func (p sshFxpLstatPacket) readonly() bool { return true }
func (p sshFxpLstatPacket) respond(svr *Server) error {
// stat the requested file
info, err := os.Lstat(p.Path)
if err != nil {
return svr.sendPacket(statusFromError(p.ID, err))
}
return svr.sendPacket(sshFxpStatResponse{
ID: p.ID,
info: info,
})
}
func (p sshFxpStatPacket) readonly() bool { return true }
func (p sshFxpStatPacket) respond(svr *Server) error {
// stat the requested file
info, err := os.Stat(p.Path)
if err != nil {
return svr.sendPacket(statusFromError(p.ID, err))
}
return svr.sendPacket(sshFxpStatResponse{
ID: p.ID,
info: info,
})
}
func (p sshFxpFstatPacket) readonly() bool { return true }
func (p sshFxpFstatPacket) respond(svr *Server) error {
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendPacket(statusFromError(p.ID, syscall.EBADF))
}
info, err := f.Stat()
if err != nil {
return svr.sendPacket(statusFromError(p.ID, err))
}
return svr.sendPacket(sshFxpStatResponse{
ID: p.ID,
info: info,
})
}
func (p sshFxpMkdirPacket) readonly() bool { return false }
func (p sshFxpMkdirPacket) respond(svr *Server) error {
// TODO FIXME: ignore flags field
err := os.Mkdir(p.Path, 0755)
return svr.sendPacket(statusFromError(p.ID, err))
}
func (p sshFxpRmdirPacket) readonly() bool { return false }
func (p sshFxpRmdirPacket) respond(svr *Server) error {
err := os.Remove(p.Path)
return svr.sendPacket(statusFromError(p.ID, err))
}
func (p sshFxpRemovePacket) readonly() bool { return false }
func (p sshFxpRemovePacket) respond(svr *Server) error {
err := os.Remove(p.Filename)
return svr.sendPacket(statusFromError(p.ID, err))
}
func (p sshFxpRenamePacket) readonly() bool { return false }
func (p sshFxpRenamePacket) respond(svr *Server) error {
err := os.Rename(p.Oldpath, p.Newpath)
return svr.sendPacket(statusFromError(p.ID, err))
}
func (p sshFxpSymlinkPacket) readonly() bool { return false }
func (p sshFxpSymlinkPacket) respond(svr *Server) error {
err := os.Symlink(p.Targetpath, p.Linkpath)
return svr.sendPacket(statusFromError(p.ID, err))
}
var emptyFileStat = []interface{}{uint32(0)}
func (p sshFxpReadlinkPacket) readonly() bool { return true }
func (p sshFxpReadlinkPacket) respond(svr *Server) error {
f, err := os.Readlink(p.Path)
if err != nil {
return svr.sendPacket(statusFromError(p.ID, err))
}
return svr.sendPacket(sshFxpNamePacket{
ID: p.ID,
NameAttrs: []sshFxpNameAttr{{
Name: f,
LongName: f,
Attrs: emptyFileStat,
}},
})
}
func (p sshFxpRealpathPacket) readonly() bool { return true }
func (p sshFxpRealpathPacket) respond(svr *Server) error {
f, err := filepath.Abs(p.Path)
if err != nil {
return svr.sendPacket(statusFromError(p.ID, err))
}
f = filepath.Clean(f)
return svr.sendPacket(sshFxpNamePacket{
ID: p.ID,
NameAttrs: []sshFxpNameAttr{{
Name: f,
LongName: f,
Attrs: emptyFileStat,
}},
})
}
func (p sshFxpOpendirPacket) readonly() bool { return true }
func (p sshFxpOpendirPacket) respond(svr *Server) error {
return sshFxpOpenPacket{
ID: p.ID,
Path: p.Path,
Pflags: ssh_FXF_READ,
}.respond(svr)
}
func (p sshFxpOpenPacket) readonly() bool {
return !p.hasPflags(ssh_FXF_WRITE)
}
@ -398,7 +401,6 @@ func (p sshFxpOpenPacket) hasPflags(flags ...uint32) bool {
return false
}
}
return true
}
@ -412,7 +414,7 @@ func (p sshFxpOpenPacket) respond(svr *Server) error {
osFlags |= os.O_RDONLY
} else {
// how are they opening?
return svr.sendPacket(statusFromError(p.ID, syscall.EINVAL))
return svr.sendError(p, syscall.EINVAL)
}
if p.hasPflags(ssh_FXF_APPEND) {
@ -430,69 +432,23 @@ func (p sshFxpOpenPacket) respond(svr *Server) error {
f, err := os.OpenFile(p.Path, osFlags, 0644)
if err != nil {
return svr.sendPacket(statusFromError(p.ID, err))
return svr.sendError(p, err)
}
handle := svr.nextHandle(f)
return svr.sendPacket(sshFxpHandlePacket{p.ID, handle})
}
func (p sshFxpClosePacket) readonly() bool { return true }
func (p sshFxpClosePacket) respond(svr *Server) error {
return svr.sendPacket(statusFromError(p.ID, svr.closeHandle(p.Handle)))
}
func (p sshFxpReadPacket) readonly() bool { return true }
func (p sshFxpReadPacket) respond(svr *Server) error {
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendPacket(statusFromError(p.ID, syscall.EBADF))
}
if p.Len > svr.maxTxPacket {
p.Len = svr.maxTxPacket
}
ret := sshFxpDataPacket{
ID: p.ID,
Length: p.Len,
Data: make([]byte, p.Len),
}
n, err := f.ReadAt(ret.Data, int64(p.Offset))
if err != nil && (err != io.EOF || n == 0) {
return svr.sendPacket(statusFromError(p.ID, err))
}
ret.Length = uint32(n)
return svr.sendPacket(ret)
}
func (p sshFxpWritePacket) readonly() bool { return false }
func (p sshFxpWritePacket) respond(svr *Server) error {
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendPacket(statusFromError(p.ID, syscall.EBADF))
}
_, err := f.WriteAt(p.Data, int64(p.Offset))
return svr.sendPacket(statusFromError(p.ID, err))
}
func (p sshFxpReaddirPacket) readonly() bool { return true }
func (p sshFxpReaddirPacket) respond(svr *Server) error {
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendPacket(statusFromError(p.ID, syscall.EBADF))
return svr.sendError(p, syscall.EBADF)
}
dirname := f.Name()
dirents, err := f.Readdir(128)
if err != nil {
return svr.sendPacket(statusFromError(p.ID, err))
return svr.sendError(p, err)
}
ret := sshFxpNamePacket{ID: p.ID}
@ -506,8 +462,6 @@ func (p sshFxpReaddirPacket) respond(svr *Server) error {
return svr.sendPacket(ret)
}
func (p sshFxpSetstatPacket) readonly() bool { return false }
func (p sshFxpSetstatPacket) respond(svr *Server) error {
// additional unmarshalling is required for each possibility here
b := p.Attrs.([]byte)
@ -547,15 +501,13 @@ func (p sshFxpSetstatPacket) respond(svr *Server) error {
}
}
return svr.sendPacket(statusFromError(p.ID, err))
return svr.sendError(p, err)
}
func (p sshFxpFsetstatPacket) readonly() bool { return false }
func (p sshFxpFsetstatPacket) respond(svr *Server) error {
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendPacket(statusFromError(p.ID, syscall.EBADF))
return svr.sendError(p, syscall.EBADF)
}
// additional unmarshalling is required for each possibility here
@ -596,7 +548,7 @@ func (p sshFxpFsetstatPacket) respond(svr *Server) error {
}
}
return svr.sendPacket(statusFromError(p.ID, err))
return svr.sendError(p, err)
}
// translateErrno translates a syscall error number to a SFTP error code.
@ -613,9 +565,9 @@ func translateErrno(errno syscall.Errno) uint32 {
return ssh_FX_FAILURE
}
func statusFromError(id uint32, err error) sshFxpStatusPacket {
func statusFromError(p id, err error) sshFxpStatusPacket {
ret := sshFxpStatusPacket{
ID: id,
ID: p.id(),
StatusError: StatusError{
// ssh_FX_OK = 0
// ssh_FX_EOF = 1
@ -646,3 +598,10 @@ func statusFromError(id uint32, err error) sshFxpStatusPacket {
}
return ret
}
func clamp(v, max uint32) uint32 {
if v > max {
return max
}
return v
}

View File

@ -295,7 +295,6 @@ func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error {
}
sftpServer, err := NewServer(
chsvr.ch,
chsvr.ch,
WithDebug(sftpServerDebugStream),
)

View File

@ -6,6 +6,7 @@ package main
import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
@ -16,10 +17,12 @@ func main() {
var (
readOnly bool
debugStderr bool
debugLevel string
)
flag.BoolVar(&readOnly, "R", false, "read-only server")
flag.BoolVar(&debugStderr, "e", false, "debug to stderr")
flag.StringVar(&debugLevel, "l", "none", "debug level (ignored)")
flag.Parse()
debugStream := ioutil.Discard
@ -28,8 +31,12 @@ func main() {
}
svr, _ := sftp.NewServer(
os.Stdin,
os.Stdout,
struct {
io.Reader
io.WriteCloser
}{os.Stdin,
os.Stdout,
},
sftp.WithDebug(debugStream),
sftp.ReadOnly(),
)

View File

@ -0,0 +1,21 @@
package sftp
import (
"syscall"
)
func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) {
return &StatVFS{
Bsize: uint64(stat.Bsize),
Frsize: uint64(stat.Bsize), // fragment size is a linux thing; use block size here
Blocks: stat.Blocks,
Bfree: stat.Bfree,
Bavail: stat.Bavail,
Files: stat.Files,
Ffree: stat.Ffree,
Favail: stat.Ffree, // not sure how to calculate Favail
Fsid: uint64(uint64(stat.Fsid.Val[1])<<32 | uint64(stat.Fsid.Val[0])), // endianness?
Flag: uint64(stat.Flags), // assuming POSIX?
Namemax: 1024, // man 2 statfs shows: #define MAXPATHLEN 1024
}, nil
}

View File

@ -0,0 +1,25 @@
// +build darwin linux,!gccgo
// fill in statvfs structure with OS specific values
// Statfs_t is different per-kernel, and only exists on some unixes (not Solaris for instance)
package sftp
import (
"syscall"
)
func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error {
stat := &syscall.Statfs_t{}
if err := syscall.Statfs(p.Path, stat); err != nil {
return svr.sendPacket(statusFromError(p, err))
}
retPkt, err := statvfsFromStatfst(stat)
if err != nil {
return svr.sendPacket(statusFromError(p, err))
}
retPkt.ID = p.ID
return svr.sendPacket(retPkt)
}

View File

@ -0,0 +1,23 @@
// +build !gccgo,linux
package sftp
import (
"syscall"
)
func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) {
return &StatVFS{
Bsize: uint64(stat.Bsize),
Frsize: uint64(stat.Frsize),
Blocks: stat.Blocks,
Bfree: stat.Bfree,
Bavail: stat.Bavail,
Files: stat.Files,
Ffree: stat.Ffree,
Favail: stat.Ffree, // not sure how to calculate Favail
Fsid: uint64(uint64(stat.Fsid.X__val[1])<<32 | uint64(stat.Fsid.X__val[0])), // endianness?
Flag: uint64(stat.Flags), // assuming POSIX?
Namemax: uint64(stat.Namelen),
}, nil
}

View File

@ -0,0 +1,11 @@
// +build !darwin,!linux gccgo
package sftp
import (
"syscall"
)
func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error {
return syscall.ENOTSUP
}

View File

@ -0,0 +1,64 @@
package sftp
import (
"io"
"testing"
)
func clientServerPair(t *testing.T) (*Client, *Server) {
cr, sw := io.Pipe()
sr, cw := io.Pipe()
server, err := NewServer(struct {
io.Reader
io.WriteCloser
}{sr, sw})
if err != nil {
t.Fatal(err)
}
go server.Serve()
client, err := NewClientPipe(cr, cw)
if err != nil {
t.Fatalf("%+v\n", err)
}
return client, server
}
type sshFxpTestBadExtendedPacket struct {
ID uint32
Extension string
Data string
}
func (p sshFxpTestBadExtendedPacket) id() uint32 { return p.ID }
func (p sshFxpTestBadExtendedPacket) MarshalBinary() ([]byte, error) {
l := 1 + 4 + 4 + // type(byte) + uint32 + uint32
len(p.Extension) +
len(p.Data)
b := make([]byte, 0, l)
b = append(b, ssh_FXP_EXTENDED)
b = marshalUint32(b, p.ID)
b = marshalString(b, p.Extension)
b = marshalString(b, p.Data)
return b, nil
}
// test that errors are sent back when we request an invalid extended packet operation
func TestInvalidExtendedPacket(t *testing.T) {
client, _ := clientServerPair(t)
defer client.Close()
badPacket := sshFxpTestBadExtendedPacket{client.nextID(), "thisDoesn'tExist", "foobar"}
_, _, err := client.clientConn.sendPacket(badPacket)
if err == nil {
t.Fatal("expected error from bad packet")
}
// try to stat a file; the client should have shut down.
filePath := "/etc/passwd"
_, err = client.Stat(filePath)
if err == nil {
t.Fatal("expected error from closed connection")
}
}

View File

@ -4,6 +4,8 @@ package sftp
import (
"fmt"
"github.com/pkg/errors"
)
const (
@ -182,7 +184,7 @@ func (u *unexpectedPacketErr) Error() string {
}
func unimplementedPacketErr(u uint8) error {
return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u))
return errors.Errorf("sftp: unimplemented packet type: got %v", fxp(u))
}
type unexpectedIDErr struct{ want, got uint32 }
@ -192,11 +194,11 @@ func (u *unexpectedIDErr) Error() string {
}
func unimplementedSeekWhence(whence int) error {
return fmt.Errorf("sftp: unimplemented seek whence %v", whence)
return errors.Errorf("sftp: unimplemented seek whence %v", whence)
}
func unexpectedCount(want, got uint32) error {
return fmt.Errorf("sftp: unexpected count: want %v, got %v", want, got)
return errors.Errorf("sftp: unexpected count: want %v, got %v", want, got)
}
type unexpectedVersionErr struct{ want, got uint32 }