Update vendored library github.com/pkg/sftp

This commit is contained in:
Alexander Neumann 2018-03-30 12:37:16 +02:00
parent d9ba9279e0
commit 19035e977b
27 changed files with 494 additions and 240 deletions

4
Gopkg.lock generated
View File

@ -130,8 +130,8 @@
[[projects]]
name = "github.com/pkg/sftp"
packages = ["."]
revision = "f6a9258a0f570c3a76681b897b6ded57cb0dfa88"
version = "1.2.0"
revision = "49488377fa2f14143ba3067cf7555f60f6c7b550"
version = "1.5.0"
[[projects]]
name = "github.com/pkg/xattr"

View File

@ -4,8 +4,8 @@ go_import_path: github.com/pkg/sftp
# current and previous stable releases, plus tip
# remember to exclude previous and tip for macs below
go:
- 1.8.x
- 1.9.x
- 1.10.x
- tip
os:
@ -15,7 +15,7 @@ os:
matrix:
exclude:
- os: osx
go: 1.8.x
go: 1.9.x
- os: osx
go: tip

View File

@ -118,7 +118,7 @@ func getFileStat(flags uint32, b []byte) (*FileStat, []byte) {
if flags&ssh_FILEXFER_ATTR_EXTENDED == ssh_FILEXFER_ATTR_EXTENDED {
var count uint32
count, b = unmarshalUint32(b)
ext := make([]StatExtended, count, count)
ext := make([]StatExtended, count)
for i := uint32(0); i < count; i++ {
var typ string
var data string

77
vendor/github.com/pkg/sftp/client.go generated vendored
View File

@ -24,29 +24,57 @@ var InternalInconsistency = errors.New("internal inconsistency")
// A ClientOption is a function which applies configuration to a Client.
type ClientOption func(*Client) error
// This is based on Openssh's max accepted size of 1<<18 - overhead
const maxMaxPacket = (1 << 18) - 1024
// MaxPacket sets the maximum size of the payload. The size param must be
// between 32768 (1<<15) and 261120 ((1 << 18) - 1024). The minimum size is
// given by the RFC, while the maximum size is a de-facto standard based on
// Openssh's SFTP server which won't accept packets much larger than that.
// MaxPacketChecked sets the maximum size of the payload, measured in bytes.
// This option only accepts sizes servers should support, ie. <= 32768 bytes.
//
// Note if you aren't using Openssh's sftp server and get the error "failed to
// send packet header: EOF" when copying a large file try lowering this number.
func MaxPacket(size int) ClientOption {
// If you get the error "failed to send packet header: EOF" when copying a
// large file, try lowering this number.
//
// The default packet size is 32768 bytes.
func MaxPacketChecked(size int) ClientOption {
return func(c *Client) error {
if size < 1<<15 {
return errors.Errorf("size must be greater or equal to 32k")
if size < 1 {
return errors.Errorf("size must be greater or equal to 1")
}
if size > maxMaxPacket {
return errors.Errorf("max packet size is too large (see docs)")
if size > 32768 {
return errors.Errorf("sizes larger than 32KB might not work with all servers")
}
c.maxPacket = size
return nil
}
}
// MaxPacketUnchecked sets the maximum size of the payload, measured in bytes.
// It accepts sizes larger than the 32768 bytes all servers should support.
// Only use a setting higher than 32768 if your application always connects to
// the same server or after sufficiently broad testing.
//
// If you get the error "failed to send packet header: EOF" when copying a
// large file, try lowering this number.
//
// The default packet size is 32768 bytes.
func MaxPacketUnchecked(size int) ClientOption {
return func(c *Client) error {
if size < 1 {
return errors.Errorf("size must be greater or equal to 1")
}
c.maxPacket = size
return nil
}
}
// MaxPacket sets the maximum size of the payload, measured in bytes.
// This option only accepts sizes servers should support, ie. <= 32768 bytes.
// This is a synonym for MaxPacketChecked that provides backward compatibility.
//
// If you get the error "failed to send packet header: EOF" when copying a
// large file, try lowering this number.
//
// The default packet size is 32768 bytes.
func MaxPacket(size int) ClientOption {
return MaxPacketChecked(size)
}
// NewClient creates a new SFTP client on conn, using zero or more option
// functions.
func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
@ -112,9 +140,10 @@ type Client struct {
nextid uint32
}
// Create creates the named file mode 0666 (before umask), truncating it if
// it already exists. If successful, methods on the returned File can be
// used for I/O; the associated file descriptor has mode O_RDWR.
// Create creates the named file mode 0666 (before umask), truncating it if it
// already exists. If successful, methods on the returned File can be used for
// I/O; the associated file descriptor has mode O_RDWR. If you need more
// control over the flags/mode used to open the file see client.OpenFile.
func (c *Client) Create(path string) (*File, error) {
return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
}
@ -697,7 +726,7 @@ func (f *File) Read(b []byte) (int, error) {
offset := f.offset
// maxConcurrentRequests buffer to deal with broadcastErr() floods
// also must have a buffer of max value of (desiredInFlight - inFlight)
ch := make(chan result, maxConcurrentRequests)
ch := make(chan result, maxConcurrentRequests+1)
type inflightRead struct {
b []byte
offset uint64
@ -793,7 +822,7 @@ func (f *File) WriteTo(w io.Writer) (int64, error) {
writeOffset := offset
fileSize := uint64(fi.Size())
// see comment on same line in Read() above
ch := make(chan result, maxConcurrentRequests)
ch := make(chan result, maxConcurrentRequests+1)
type inflightRead struct {
b []byte
offset uint64
@ -936,7 +965,7 @@ func (f *File) Write(b []byte) (int, error) {
desiredInFlight := 1
offset := f.offset
// see comment on same line in Read() above
ch := make(chan result, maxConcurrentRequests)
ch := make(chan result, maxConcurrentRequests+1)
var firstErr error
written := len(b)
for len(b) > 0 || inFlight > 0 {
@ -997,7 +1026,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
desiredInFlight := 1
offset := f.offset
// see comment on same line in Read() above
ch := make(chan result, maxConcurrentRequests)
ch := make(chan result, maxConcurrentRequests+1)
var firstErr error
read := int64(0)
b := make([]byte, f.c.maxPacket)
@ -1061,11 +1090,11 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
// the file is undefined. Seeking relative to the end calls Stat.
func (f *File) Seek(offset int64, whence int) (int64, error) {
switch whence {
case os.SEEK_SET:
case io.SeekStart:
f.offset = uint64(offset)
case os.SEEK_CUR:
case io.SeekCurrent:
f.offset = uint64(int64(f.offset) + offset)
case os.SEEK_END:
case io.SeekEnd:
fi, err := f.Stat()
if err != nil {
return int64(f.offset), err

View File

@ -62,7 +62,7 @@ func newDelayedWriter(w io.WriteCloser, delay time.Duration) io.WriteCloser {
closed := make(chan struct{})
go func() {
for writeMsg := range ch {
time.Sleep(writeMsg.t.Add(delay).Sub(time.Now()))
time.Sleep(time.Until(writeMsg.t.Add(delay)))
n, err := w.Write(writeMsg.b)
if err != nil {
panic("write error")
@ -313,7 +313,7 @@ func (s seek) Generate(r *rand.Rand, _ int) reflect.Value {
}
func (s seek) set(t *testing.T, r io.ReadSeeker) {
if _, err := r.Seek(s.offset, os.SEEK_SET); err != nil {
if _, err := r.Seek(s.offset, io.SeekStart); err != nil {
t.Fatalf("error while seeking with %+v: %v", s, err)
}
}
@ -326,16 +326,16 @@ func (s seek) current(t *testing.T, r io.ReadSeeker) {
skip = -skip
}
if _, err := r.Seek(mid, os.SEEK_SET); err != nil {
if _, err := r.Seek(mid, io.SeekStart); err != nil {
t.Fatalf("error seeking to midpoint with %+v: %v", s, err)
}
if _, err := r.Seek(skip, os.SEEK_CUR); err != nil {
if _, err := r.Seek(skip, io.SeekCurrent); err != nil {
t.Fatalf("error seeking from %d with %+v: %v", mid, s, err)
}
}
func (s seek) end(t *testing.T, r io.ReadSeeker) {
if _, err := r.Seek(-s.offset, os.SEEK_END); err != nil {
if _, err := r.Seek(-s.offset, io.SeekEnd); err != nil {
t.Fatalf("error seeking from end with %+v: %v", s, err)
}
}
@ -1761,6 +1761,68 @@ func TestServerRoughDisconnect2(t *testing.T) {
}
}
// sftp/issue/234 - abrupt shutdown during ReadFrom hangs client
func TestServerRoughDisconnect3(t *testing.T) {
if *testServerImpl {
t.Skipf("skipping with -testserver")
}
sftp, cmd := testClient(t, READWRITE, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
rf, err := sftp.OpenFile("/dev/null", os.O_RDWR)
if err != nil {
t.Fatal(err)
}
defer rf.Close()
lf, err := os.Open("/dev/zero")
if err != nil {
t.Fatal(err)
}
defer lf.Close()
go func() {
time.Sleep(10 * time.Millisecond)
cmd.Process.Kill()
}()
io.Copy(rf, lf)
}
// sftp/issue/234 - also affected Write
func TestServerRoughDisconnect4(t *testing.T) {
if *testServerImpl {
t.Skipf("skipping with -testserver")
}
sftp, cmd := testClient(t, READWRITE, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
rf, err := sftp.OpenFile("/dev/null", os.O_RDWR)
if err != nil {
t.Fatal(err)
}
defer rf.Close()
lf, err := os.Open("/dev/zero")
if err != nil {
t.Fatal(err)
}
defer lf.Close()
go func() {
time.Sleep(10 * time.Millisecond)
cmd.Process.Kill()
}()
b := make([]byte, 32768*200)
lf.Read(b)
for {
_, err = rf.Write(b)
if err != nil {
break
}
}
io.Copy(rf, lf)
}
// sftp/issue/26 writing to a read only file caused client to loop.
func TestClientWriteToROFile(t *testing.T) {
sftp, cmd := testClient(t, READWRITE, NO_DELAY)

View File

@ -145,3 +145,52 @@ func TestUnmarshalStatus(t *testing.T) {
}
}
}
type packetSizeTest struct {
size int
valid bool
}
var maxPacketCheckedTests = []packetSizeTest{
{size: 0, valid: false},
{size: 1, valid: true},
{size: 32768, valid: true},
{size: 32769, valid: false},
}
var maxPacketUncheckedTests = []packetSizeTest{
{size: 0, valid: false},
{size: 1, valid: true},
{size: 32768, valid: true},
{size: 32769, valid: true},
}
func TestMaxPacketChecked(t *testing.T) {
for _, tt := range maxPacketCheckedTests {
testMaxPacketOption(t, MaxPacketChecked(tt.size), tt)
}
}
func TestMaxPacketUnchecked(t *testing.T) {
for _, tt := range maxPacketUncheckedTests {
testMaxPacketOption(t, MaxPacketUnchecked(tt.size), tt)
}
}
func TestMaxPacket(t *testing.T) {
for _, tt := range maxPacketCheckedTests {
testMaxPacketOption(t, MaxPacket(tt.size), tt)
}
}
func testMaxPacketOption(t *testing.T, o ClientOption, tt packetSizeTest) {
var c Client
err := o(&c)
if (err == nil) != tt.valid {
t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.valid, err == nil)
}
if c.maxPacket != tt.size && tt.valid {
t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.size, c.maxPacket)
}
}

2
vendor/github.com/pkg/sftp/conn.go generated vendored
View File

@ -93,7 +93,7 @@ type idmarshaler interface {
}
func (c *clientConn) sendPacket(p idmarshaler) (byte, []byte, error) {
ch := make(chan result, 1)
ch := make(chan result, 2)
c.dispatchRequest(ch, p)
s := <-ch
return s.typ, s.data, s.err

View File

@ -1,7 +1,9 @@
package sftp_test
import (
"bufio"
"fmt"
"io"
"log"
"os"
"os/exec"
@ -107,13 +109,13 @@ func ExampleClient_Mkdir_parents() {
sshFxFailure := uint32(4)
mkdirParents := func(client *sftp.Client, dir string) (err error) {
var parents string
if path.IsAbs(dir) {
// Otherwise, an absolute path given below would be turned in to a relative one
// by splitting on "/"
parents = "/"
}
for _, name := range strings.Split(dir, "/") {
if name == "" {
// Paths with double-/ in them should just move along
@ -145,3 +147,18 @@ func ExampleClient_Mkdir_parents() {
log.Fatal(err)
}
}
func ExampleFile_ReadFrom_bufio() {
// Using Bufio to buffer writes going to an sftp.File won't buffer as it
// skips buffering if the underlying writer support ReadFrom. The
// workaround is to wrap your writer in a struct that only implements
// io.Writer.
//
// For background see github.com/pkg/sftp/issues/125
var data_source io.Reader
var f *sftp.File
type writerOnly struct{ io.Writer }
bw := bufio.NewWriter(writerOnly{f}) // no ReadFrom()
bw.ReadFrom(data_source)
}

View File

@ -2,6 +2,7 @@ package sftp
import (
"encoding"
"sort"
"sync"
)
@ -37,6 +38,22 @@ func newPktMgr(sender packetSender) *packetManager {
return s
}
type responsePackets []responsePacket
func (r responsePackets) Sort() {
sort.Slice(r, func(i, j int) bool {
return r[i].id() < r[j].id()
})
}
type requestPacketIDs []uint32
func (r requestPacketIDs) Sort() {
sort.Slice(r, func(i, j int) bool {
return r[i] < r[j]
})
}
// register incoming packets to be handled
// send id of 0 for packets without id
func (s *packetManager) incomingPacket(pkt requestPacket) {

View File

@ -1,21 +0,0 @@
// +build go1.8
package sftp
import "sort"
type responsePackets []responsePacket
func (r responsePackets) Sort() {
sort.Slice(r, func(i, j int) bool {
return r[i].id() < r[j].id()
})
}
type requestPacketIDs []uint32
func (r requestPacketIDs) Sort() {
sort.Slice(r, func(i, j int) bool {
return r[i] < r[j]
})
}

View File

@ -1,21 +0,0 @@
// +build !go1.8
package sftp
import "sort"
// for sorting/ordering outgoing
type responsePackets []responsePacket
func (r responsePackets) Len() int { return len(r) }
func (r responsePackets) Swap(i, j int) { r[i], r[j] = r[j], r[i] }
func (r responsePackets) Less(i, j int) bool { return r[i].id() < r[j].id() }
func (r responsePackets) Sort() { sort.Sort(r) }
// for sorting/ordering incoming
type requestPacketIDs []uint32
func (r requestPacketIDs) Len() int { return len(r) }
func (r requestPacketIDs) Swap(i, j int) { r[i], r[j] = r[j], r[i] }
func (r requestPacketIDs) Less(i, j int) bool { return r[i] < r[j] }
func (r requestPacketIDs) Sort() { sort.Sort(r) }

View File

@ -3,9 +3,7 @@ package sftp
import (
"encoding"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
@ -102,53 +100,3 @@ func (p sshFxpWritePacket) String() string {
func (p sshFxpClosePacket) String() string {
return fmt.Sprintf("ClPct:%d", p.ID)
}
// Test what happens when the pool processes a close packet on a file that it
// is still reading from.
func TestCloseOutOfOrder(t *testing.T) {
packets := []requestPacket{
&sshFxpRemovePacket{ID: 0, Filename: "foo"},
&sshFxpOpenPacket{ID: 1},
&sshFxpWritePacket{ID: 2, Handle: "foo"},
&sshFxpWritePacket{ID: 3, Handle: "foo"},
&sshFxpWritePacket{ID: 4, Handle: "foo"},
&sshFxpWritePacket{ID: 5, Handle: "foo"},
&sshFxpClosePacket{ID: 6, Handle: "foo"},
&sshFxpRemovePacket{ID: 7, Filename: "foo"},
}
recvChan := make(chan requestPacket, len(packets)+1)
sender := newTestSender()
pktMgr := newPktMgr(sender)
wg := sync.WaitGroup{}
wg.Add(len(packets))
runWorker := func(ch requestChan) {
go func() {
for pkt := range ch {
if _, ok := pkt.(*sshFxpWritePacket); ok {
// sleep to cause writes to come after close/remove
time.Sleep(time.Millisecond)
}
pktMgr.working.Done()
recvChan <- pkt
wg.Done()
}
}()
}
pktChan := pktMgr.workerChan(runWorker)
for _, p := range packets {
pktChan <- p
}
wg.Wait()
close(recvChan)
received := []requestPacket{}
for p := range recvChan {
received = append(received, p)
}
if received[len(received)-2].id() != packets[len(packets)-2].id() {
t.Fatal("Packets processed out of order1:", received, packets)
}
if received[len(received)-1].id() != packets[len(packets)-1].id() {
t.Fatal("Packets processed out of order2:", received, packets)
}
}

View File

@ -58,6 +58,7 @@ func (p sshFxpFsetstatPacket) getHandle() string { return p.Handle }
func (p sshFxpReadPacket) getHandle() string { return p.Handle }
func (p sshFxpWritePacket) getHandle() string { return p.Handle }
func (p sshFxpReaddirPacket) getHandle() string { return p.Handle }
func (p sshFxpClosePacket) getHandle() string { return p.Handle }
// notReadOnly
func (p sshFxpWritePacket) notReadOnly() {}
@ -70,9 +71,6 @@ func (p sshFxpRenamePacket) notReadOnly() {}
func (p sshFxpSymlinkPacket) notReadOnly() {}
func (p sshFxpExtendedPacketPosixRename) notReadOnly() {}
// this has a handle, but is only used for close
func (p sshFxpClosePacket) getHandle() string { return p.Handle }
// some packets with ID are missing id()
func (p sshFxpDataPacket) id() uint32 { return p.ID }
func (p sshFxpStatusPacket) id() uint32 { return p.ID }

18
vendor/github.com/pkg/sftp/packet.go generated vendored
View File

@ -255,7 +255,7 @@ func unmarshalIDString(b []byte, id *uint32, str *string) error {
if err != nil {
return err
}
*str, b, err = unmarshalStringSafe(b)
*str, _, err = unmarshalStringSafe(b)
return err
}
@ -406,7 +406,7 @@ func (p *sshFxpSymlinkPacket) UnmarshalBinary(b []byte) error {
return err
} else if p.Targetpath, b, err = unmarshalStringSafe(b); err != nil {
return err
} else if p.Linkpath, b, err = unmarshalStringSafe(b); err != nil {
} else if p.Linkpath, _, err = unmarshalStringSafe(b); err != nil {
return err
}
return nil
@ -510,7 +510,7 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
return err
} else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil {
return err
} else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil {
} else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil {
return err
}
return nil
@ -547,7 +547,7 @@ func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error {
return err
} else if p.Offset, b, err = unmarshalUint64Safe(b); err != nil {
return err
} else if p.Len, b, err = unmarshalUint32Safe(b); err != nil {
} else if p.Len, _, err = unmarshalUint32Safe(b); err != nil {
return err
}
return nil
@ -580,7 +580,7 @@ func (p *sshFxpRenamePacket) UnmarshalBinary(b []byte) error {
return err
} else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil {
return err
} else if p.Newpath, b, err = unmarshalStringSafe(b); err != nil {
} else if p.Newpath, _, err = unmarshalStringSafe(b); err != nil {
return err
}
return nil
@ -681,7 +681,7 @@ func (p *sshFxpMkdirPacket) UnmarshalBinary(b []byte) error {
return err
} else if p.Path, b, err = unmarshalStringSafe(b); err != nil {
return err
} else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil {
} else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil {
return err
}
return nil
@ -886,7 +886,7 @@ func (p *sshFxpExtendedPacket) UnmarshalBinary(b []byte) error {
bOrig := b
if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
return err
} else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil {
} else if p.ExtendedRequest, _, err = unmarshalStringSafe(b); err != nil {
return err
}
@ -917,7 +917,7 @@ func (p *sshFxpExtendedPacketStatVFS) UnmarshalBinary(b []byte) error {
return err
} else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil {
return err
} else if p.Path, b, err = unmarshalStringSafe(b); err != nil {
} else if p.Path, _, err = unmarshalStringSafe(b); err != nil {
return err
}
return nil
@ -940,7 +940,7 @@ func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error {
return err
} else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil {
return err
} else if p.Newpath, b, err = unmarshalStringSafe(b); err != nil {
} else if p.Newpath, _, err = unmarshalStringSafe(b); err != nil {
return err
}
return nil

42
vendor/github.com/pkg/sftp/request-errors.go generated vendored Normal file
View File

@ -0,0 +1,42 @@
package sftp
// Error types that match the SFTP's SSH_FXP_STATUS codes. Gives you more
// direct control of the errors being sent vs. letting the library work them
// out from the standard os/io errors.
type fxerr uint32
const (
ErrSshFxOk = fxerr(ssh_FX_OK)
ErrSshFxEof = fxerr(ssh_FX_EOF)
ErrSshFxNoSuchFile = fxerr(ssh_FX_NO_SUCH_FILE)
ErrSshFxPermissionDenied = fxerr(ssh_FX_PERMISSION_DENIED)
ErrSshFxFailure = fxerr(ssh_FX_FAILURE)
ErrSshFxBadMessage = fxerr(ssh_FX_BAD_MESSAGE)
ErrSshFxNoConnection = fxerr(ssh_FX_NO_CONNECTION)
ErrSshFxConnectionLost = fxerr(ssh_FX_CONNECTION_LOST)
ErrSshFxOpUnsupported = fxerr(ssh_FX_OP_UNSUPPORTED)
)
func (e fxerr) Error() string {
switch e {
case ErrSshFxOk:
return "OK"
case ErrSshFxEof:
return "EOF"
case ErrSshFxNoSuchFile:
return "No Such File"
case ErrSshFxPermissionDenied:
return "Permission Denied"
case ErrSshFxBadMessage:
return "Bad Message"
case ErrSshFxNoConnection:
return "No Connection"
case ErrSshFxConnectionLost:
return "Connection Lost"
case ErrSshFxOpUnsupported:
return "Operation Unsupported"
default:
return "Failure"
}
}

View File

@ -24,18 +24,10 @@ func InMemHandler() Handlers {
return Handlers{root, root, root, root}
}
// So I can test Handlers returning errors
var (
readErr error = nil
writeErr error = nil
cmdErr error = nil
listErr error = nil
)
// Handlers
func (fs *root) Fileread(r *Request) (io.ReaderAt, error) {
if readErr != nil {
return nil, readErr
if fs.mockErr != nil {
return nil, fs.mockErr
}
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
@ -53,8 +45,8 @@ func (fs *root) Fileread(r *Request) (io.ReaderAt, error) {
}
func (fs *root) Filewrite(r *Request) (io.WriterAt, error) {
if writeErr != nil {
return nil, writeErr
if fs.mockErr != nil {
return nil, fs.mockErr
}
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
@ -74,8 +66,8 @@ func (fs *root) Filewrite(r *Request) (io.WriterAt, error) {
}
func (fs *root) Filecmd(r *Request) error {
if cmdErr != nil {
return cmdErr
if fs.mockErr != nil {
return fs.mockErr
}
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
@ -133,8 +125,8 @@ func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
}
func (fs *root) Filelist(r *Request) (ListerAt, error) {
if listErr != nil {
return nil, listErr
if fs.mockErr != nil {
return nil, fs.mockErr
}
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
@ -147,7 +139,7 @@ func (fs *root) Filelist(r *Request) (ListerAt, error) {
ordered_names = append(ordered_names, fn)
}
}
sort.Sort(sort.StringSlice(ordered_names))
sort.Strings(ordered_names)
list := make([]os.FileInfo, len(ordered_names))
for i, fn := range ordered_names {
list[i] = fs.files[fn]
@ -180,6 +172,13 @@ type root struct {
*memFile
files map[string]*memFile
filesLock sync.Mutex
mockErr error
}
// Set a mocked error that the next handler call will return.
// Set to nil to reset for no error.
func (fs *root) returnErr(err error) {
fs.mockErr = err
}
func (fs *root) fetch(path string) (*memFile, error) {

View File

@ -8,22 +8,29 @@ import (
// Interfaces are differentiated based on required returned values.
// All input arguments are to be pulled from Request (the only arg).
// FileReader should return an io.Reader for the filepath
// FileReader should return an io.ReaderAt for the filepath
// Note in cases of an error, the error text will be sent to the client.
type FileReader interface {
Fileread(*Request) (io.ReaderAt, error)
}
// FileWriter should return an io.Writer for the filepath
// FileWriter should return an io.WriterAt for the filepath.
//
// The request server code will call Close() on the returned io.WriterAt
// ojbect if an io.Closer type assertion succeeds.
// Note in cases of an error, the error text will be sent to the client.
type FileWriter interface {
Filewrite(*Request) (io.WriterAt, error)
}
// FileCmder should return an error (rename, remove, setstate, etc.)
// FileCmder should return an error
// Note in cases of an error, the error text will be sent to the client.
type FileCmder interface {
Filecmd(*Request) error
}
// FileLister should return file info interface and errors (readdir, stat)
// FileLister should return an object that fulfils the ListerAt interface
// Note in cases of an error, the error text will be sent to the client.
type FileLister interface {
Filelist(*Request) (ListerAt, error)
}
@ -33,6 +40,7 @@ type FileLister interface {
// error if at end of list. This is testable by comparing how many you
// copied to how many could be copied (eg. n < len(ls) below).
// The copy() builtin is best for the copying.
// Note in cases of an error, the error text will be sent to the client.
type ListerAt interface {
ListAt([]os.FileInfo, int64) (int, error)
}

View File

@ -1,6 +1,7 @@
package sftp
import (
"context"
"encoding"
"io"
"path"
@ -14,8 +15,6 @@ import (
var maxTxPacket uint32 = 1 << 15
type handleHandler func(string) string
// Handlers contains the 4 SFTP server request handlers.
type Handlers struct {
FileGet FileReader
@ -82,7 +81,8 @@ func (rs *RequestServer) getRequest(handle, method string) (*Request, bool) {
if !ok || r.Method == method { // re-check needed b/c lock race
return r, ok
}
r = &Request{Method: method, Filepath: r.Filepath, state: r.state}
r = r.copy()
r.Method = method
rs.openRequests[handle] = r
return r, ok
}
@ -102,12 +102,14 @@ func (rs *RequestServer) Close() error { return rs.conn.Close() }
// Serve requests for user session
func (rs *RequestServer) Serve() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var wg sync.WaitGroup
runWorker := func(ch requestChan) {
wg.Add(1)
go func() {
defer wg.Done()
if err := rs.packetWorker(ch); err != nil {
if err := rs.packetWorker(ctx, ch); err != nil {
rs.conn.Close() // shuts down recvPacket
}
}()
@ -137,10 +139,19 @@ func (rs *RequestServer) Serve() error {
close(pktChan) // shuts down sftpServerWorkers
wg.Wait() // wait for all workers to exit
// make sure all open requests are properly closed
// (eg. possible on dropped connections, client crashes, etc.)
for handle, req := range rs.openRequests {
delete(rs.openRequests, handle)
req.close()
}
return err
}
func (rs *RequestServer) packetWorker(pktChan chan requestPacket) error {
func (rs *RequestServer) packetWorker(
ctx context.Context, pktChan chan requestPacket,
) error {
for pkt := range pktChan {
var rpkt responsePacket
switch pkt := pkt.(type) {
@ -152,15 +163,15 @@ func (rs *RequestServer) packetWorker(pktChan chan requestPacket) error {
case *sshFxpRealpathPacket:
rpkt = cleanPacketPath(pkt)
case *sshFxpOpendirPacket:
request := requestFromPacket(pkt)
request := requestFromPacket(ctx, pkt)
handle := rs.nextRequest(request)
rpkt = sshFxpHandlePacket{pkt.id(), handle}
case *sshFxpOpenPacket:
request := requestFromPacket(pkt)
request := requestFromPacket(ctx, pkt)
handle := rs.nextRequest(request)
rpkt = sshFxpHandlePacket{pkt.id(), handle}
if pkt.hasPflags(ssh_FXF_CREAT) {
if p := request.call(rs.Handlers, pkt); !isOk(p) {
if p := request.call(rs.Handlers, pkt); !statusOk(p) {
rpkt = p // if error in write, return it
}
}
@ -173,8 +184,9 @@ func (rs *RequestServer) packetWorker(pktChan chan requestPacket) error {
rpkt = request.call(rs.Handlers, pkt)
}
case hasPath:
request := requestFromPacket(pkt)
request := requestFromPacket(ctx, pkt)
rpkt = request.call(rs.Handlers, pkt)
request.close()
default:
return errors.Errorf("unexpected packet type %T", pkt)
}
@ -188,7 +200,7 @@ func (rs *RequestServer) packetWorker(pktChan chan requestPacket) error {
}
// True is responsePacket is an OK status packet
func isOk(rpkt responsePacket) bool {
func statusOk(rpkt responsePacket) bool {
p, ok := rpkt.(sshFxpStatusPacket)
return ok && p.StatusError.Code == ssh_FX_OK
}
@ -224,7 +236,3 @@ func (rs *RequestServer) sendPacket(m encoding.BinaryMarshaler) error {
}
return nil
}
func (rs *RequestServer) sendError(p ider, err error) error {
return rs.sendPacket(statusFromError(p, err))
}

View File

@ -1,6 +1,7 @@
package sftp
import (
"context"
"fmt"
"io"
"net"
@ -78,16 +79,25 @@ func TestRequestCache(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
foo := NewRequest("", "foo")
foo.ctx, foo.cancelCtx = context.WithCancel(context.Background())
bar := NewRequest("", "bar")
fh := p.svr.nextRequest(foo)
bh := p.svr.nextRequest(bar)
assert.Len(t, p.svr.openRequests, 2)
_foo, ok := p.svr.getRequest(fh, "")
assert.Equal(t, foo, _foo)
assert.Equal(t, foo.Method, _foo.Method)
assert.Equal(t, foo.Filepath, _foo.Filepath)
assert.Equal(t, foo.Target, _foo.Target)
assert.Equal(t, foo.Flags, _foo.Flags)
assert.Equal(t, foo.Attrs, _foo.Attrs)
assert.Equal(t, foo.state, _foo.state)
assert.NotNil(t, _foo.ctx)
assert.Equal(t, _foo.Context().Err(), nil, "context is still valid")
assert.True(t, ok)
_, ok = p.svr.getRequest("zed", "")
assert.False(t, ok)
p.svr.closeRequest(fh)
assert.Equal(t, _foo.Context().Err(), context.Canceled, "context is now canceled")
p.svr.closeRequest(bh)
assert.Len(t, p.svr.openRequests, 0)
}
@ -140,10 +150,11 @@ func TestRequestWriteEmpty(t *testing.T) {
assert.Equal(t, f.content, []byte(""))
}
// lets test with an error
writeErr = os.ErrInvalid
r.returnErr(os.ErrInvalid)
n, err = putTestFile(p.cli, "/bar", "")
assert.Error(t, err)
writeErr = nil
r.returnErr(nil)
assert.Equal(t, 0, n)
}
func TestRequestFilename(t *testing.T) {
@ -155,7 +166,7 @@ func TestRequestFilename(t *testing.T) {
f, err := r.fetch("/foo")
assert.NoError(t, err)
assert.Equal(t, f.Name(), "foo")
f, err = r.fetch("/bar")
_, err = r.fetch("/bar")
assert.Error(t, err)
}
@ -258,6 +269,7 @@ func TestRequestStat(t *testing.T) {
assert.Equal(t, fi.Size(), int64(5))
assert.Equal(t, fi.Mode(), os.FileMode(0644))
assert.NoError(t, testOsSys(fi.Sys()))
assert.NoError(t, err)
}
// NOTE: Setstat is a noop in the request server tests, but we want to test

View File

@ -1,6 +1,7 @@
package sftp
import (
"context"
"io"
"os"
"path"
@ -24,11 +25,14 @@ type Request struct {
Attrs []byte // convert to sub-struct
Target string // for renames and sym-links
// reader/writer/readdir from handlers
stateLock sync.RWMutex
state state
state state
// context lasts duration of request
ctx context.Context
cancelCtx context.CancelFunc
}
type state struct {
*sync.RWMutex
writerAt io.WriterAt
readerAt io.ReaderAt
listerAt ListerAt
@ -36,9 +40,11 @@ type state struct {
}
// New Request initialized based on packet data
func requestFromPacket(pkt hasPath) *Request {
func requestFromPacket(ctx context.Context, pkt hasPath) *Request {
method := requestMethod(pkt)
request := NewRequest(method, pkt.getPath())
request.ctx, request.cancelCtx = context.WithCancel(ctx)
switch p := pkt.(type) {
case *sshFxpOpenPacket:
request.Flags = p.Pflags
@ -55,60 +61,100 @@ func requestFromPacket(pkt hasPath) *Request {
// NewRequest creates a new Request object.
func NewRequest(method, path string) *Request {
return &Request{Method: method, Filepath: cleanPath(path)}
return &Request{Method: method, Filepath: cleanPath(path),
state: state{RWMutex: new(sync.RWMutex)}}
}
// shallow copy of existing request
func (r *Request) copy() *Request {
r.state.Lock()
defer r.state.Unlock()
r2 := new(Request)
*r2 = *r
return r2
}
// Context returns the request's context. To change the context,
// use WithContext.
//
// The returned context is always non-nil; it defaults to the
// background context.
//
// For incoming server requests, the context is canceled when the
// request is complete or the client's connection closes.
func (r *Request) Context() context.Context {
if r.ctx != nil {
return r.ctx
}
return context.Background()
}
// WithContext returns a copy of r with its context changed to ctx.
// The provided ctx must be non-nil.
func (r *Request) WithContext(ctx context.Context) *Request {
if ctx == nil {
panic("nil context")
}
r2 := r.copy()
r2.ctx = ctx
r2.cancelCtx = nil
return r2
}
// Returns current offset for file list
func (r *Request) lsNext() int64 {
r.stateLock.RLock()
defer r.stateLock.RUnlock()
r.state.RLock()
defer r.state.RUnlock()
return r.state.lsoffset
}
// Increases next offset
func (r *Request) lsInc(offset int64) {
r.stateLock.Lock()
defer r.stateLock.Unlock()
r.state.Lock()
defer r.state.Unlock()
r.state.lsoffset = r.state.lsoffset + offset
}
// manage file read/write state
func (r *Request) setWriterState(wa io.WriterAt) {
r.stateLock.Lock()
defer r.stateLock.Unlock()
r.state.Lock()
defer r.state.Unlock()
r.state.writerAt = wa
}
func (r *Request) setReaderState(ra io.ReaderAt) {
r.stateLock.Lock()
defer r.stateLock.Unlock()
r.state.Lock()
defer r.state.Unlock()
r.state.readerAt = ra
}
func (r *Request) setListerState(la ListerAt) {
r.stateLock.Lock()
defer r.stateLock.Unlock()
r.state.Lock()
defer r.state.Unlock()
r.state.listerAt = la
}
func (r *Request) getWriter() io.WriterAt {
r.stateLock.RLock()
defer r.stateLock.RUnlock()
r.state.RLock()
defer r.state.RUnlock()
return r.state.writerAt
}
func (r *Request) getReader() io.ReaderAt {
r.stateLock.RLock()
defer r.stateLock.RUnlock()
r.state.RLock()
defer r.state.RUnlock()
return r.state.readerAt
}
func (r *Request) getLister() ListerAt {
r.stateLock.RLock()
defer r.stateLock.RUnlock()
r.state.RLock()
defer r.state.RUnlock()
return r.state.listerAt
}
// Close reader/writer if possible
func (r *Request) close() error {
if r.cancelCtx != nil {
r.cancelCtx()
}
rd := r.getReader()
if c, ok := rd.(io.Closer); ok {
return c.Close()

View File

@ -1,6 +1,8 @@
package sftp
import (
"sync"
"github.com/stretchr/testify/assert"
"bytes"
@ -60,6 +62,7 @@ func testRequest(method string) *Request {
Method: method,
Attrs: []byte("foo"),
Target: "foo",
state: state{RWMutex: new(sync.RWMutex)},
}
return request
}
@ -96,15 +99,19 @@ func (h Handlers) getOutString() string {
var errTest = errors.New("test error")
func (h *Handlers) returnError() {
func (h *Handlers) returnError(err error) {
handler := h.FilePut.(*testHandler)
handler.err = errTest
handler.err = err
}
func statusOk(t *testing.T, p interface{}) {
if pkt, ok := p.(*sshFxpStatusPacket); ok {
assert.Equal(t, pkt.StatusError.Code, uint32(ssh_FX_OK))
}
func getStatusMsg(p interface{}) string {
pkt := p.(sshFxpStatusPacket)
return pkt.StatusError.msg
}
func checkOkStatus(t *testing.T, p interface{}) {
pkt := p.(sshFxpStatusPacket)
assert.Equal(t, pkt.StatusError.Code, uint32(ssh_FX_OK),
"sshFxpStatusPacket not OK\n", pkt.StatusError.msg)
}
// fake/test packet
@ -135,15 +142,25 @@ func TestRequestGet(t *testing.T) {
}
}
func TestRequestCustomError(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Stat")
pkt := fakePacket{myid: 1}
cmdErr := errors.New("stat not supported")
handlers.returnError(cmdErr)
rpkt := request.call(handlers, pkt)
assert.Equal(t, rpkt, statusFromError(rpkt, cmdErr))
}
func TestRequestPut(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Put")
pkt := &sshFxpWritePacket{0, "a", 0, 5, []byte("file-")}
rpkt := request.call(handlers, pkt)
statusOk(t, rpkt)
checkOkStatus(t, rpkt)
pkt = &sshFxpWritePacket{1, "a", 5, 5, []byte("data.")}
rpkt = request.call(handlers, pkt)
statusOk(t, rpkt)
checkOkStatus(t, rpkt)
assert.Equal(t, "file-data.", handlers.getOutString())
}
@ -152,9 +169,9 @@ func TestRequestCmdr(t *testing.T) {
request := testRequest("Mkdir")
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt)
statusOk(t, rpkt)
checkOkStatus(t, rpkt)
handlers.returnError()
handlers.returnError(errTest)
rpkt = request.call(handlers, pkt)
assert.Equal(t, rpkt, statusFromError(rpkt, errTest))
}

42
vendor/github.com/pkg/sftp/server.go generated vendored
View File

@ -462,7 +462,7 @@ func (p sshFxpSetstatPacket) respond(svr *Server) error {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, _, err = unmarshalUint32Safe(b); err != nil {
} else {
err = os.Chown(p.Path, int(uid), int(gid))
}
@ -509,7 +509,7 @@ func (p sshFxpFsetstatPacket) respond(svr *Server) error {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, _, err = unmarshalUint32Safe(b); err != nil {
} else {
err = f.Chown(int(uid), int(gid))
}
@ -548,23 +548,33 @@ func statusFromError(p ider, err error) sshFxpStatusPacket {
Code: ssh_FX_OK,
},
}
if err != nil {
debug("statusFromError: error is %T %#v", err, err)
ret.StatusError.Code = ssh_FX_FAILURE
ret.StatusError.msg = err.Error()
if err == io.EOF {
ret.StatusError.Code = ssh_FX_EOF
} else if err == os.ErrNotExist {
ret.StatusError.Code = ssh_FX_NO_SUCH_FILE
} else if errno, ok := err.(syscall.Errno); ok {
if err == nil {
return ret
}
debug("statusFromError: error is %T %#v", err, err)
ret.StatusError.Code = ssh_FX_FAILURE
ret.StatusError.msg = err.Error()
switch e := err.(type) {
case syscall.Errno:
ret.StatusError.Code = translateErrno(e)
case *os.PathError:
debug("statusFromError,pathError: error is %T %#v", e.Err, e.Err)
if errno, ok := e.Err.(syscall.Errno); ok {
ret.StatusError.Code = translateErrno(errno)
} else if pathError, ok := err.(*os.PathError); ok {
debug("statusFromError: error is %T %#v", pathError.Err, pathError.Err)
if errno, ok := pathError.Err.(syscall.Errno); ok {
ret.StatusError.Code = translateErrno(errno)
}
}
case fxerr:
ret.StatusError.Code = uint32(e)
default:
switch e {
case io.EOF:
ret.StatusError.Code = ssh_FX_EOF
case os.ErrNotExist:
ret.StatusError.Code = ssh_FX_NO_SUCH_FILE
}
}
return ret
}

View File

@ -426,7 +426,7 @@ func runSftpClient(t *testing.T, script string, path string, host string, port i
return "", err
}
err = cmd.Wait()
return string(stdout.Bytes()), err
return stdout.String(), err
}
func TestServerCompareSubsystems(t *testing.T) {
@ -468,15 +468,17 @@ ls -l /usr/bin/
if goLine != opLine {
goWords := spaceRegex.Split(goLine, -1)
opWords := spaceRegex.Split(opLine, -1)
// allow words[2] and [3] to be different as these are users & groups
// also allow words[1] to differ as the link count for directories like
// proc is unstable during testing as processes are created/destroyed.
// some fields are allowed to be different..
// words[2] and [3] as these are users & groups
// words[1] as the link count for directories like proc is unstable
// during testing as processes are created/destroyed.
// words[7] as timestamp on dirs can very for things like /tmp
for j, goWord := range goWords {
if j > len(opWords) {
bad = true
}
opWord := opWords[j]
if goWord != opWord && j != 1 && j != 2 && j != 3 {
if goWord != opWord && j != 1 && j != 2 && j != 3 && j != 7 {
bad = true
}
}

View File

@ -1,4 +1,4 @@
// +build darwin linux,!gccgo
// +build darwin linux
// fill in statvfs structure with OS specific values
// Statfs_t is different per-kernel, and only exists on some unixes (not Solaris for instance)

View File

@ -1,4 +1,4 @@
// +build !gccgo,linux
// +build linux
package sftp

View File

@ -1,4 +1,4 @@
// +build !darwin,!linux gccgo
// +build !darwin,!linux
package sftp

View File

@ -1,12 +1,16 @@
package sftp
import (
"testing"
"io"
"os"
"regexp"
"time"
"io"
"sync"
"syscall"
"testing"
"time"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
const (
@ -61,7 +65,7 @@ func TestRunLsWithLicensesFile(t *testing.T) {
where `id' is the request identifier, and `attrs' is the returned
file attributes as described in Section ``File Attributes''.
*/
*/
func runLsTestHelper(t *testing.T, result, expectedType, path string) {
// using regular expressions to make tests work on all systems
// a virtual file system (like afero) would be needed to mock valid filesystem checks
@ -241,3 +245,31 @@ func TestConcurrentRequests(t *testing.T) {
}
wg.Wait()
}
// Test error conversion
func TestStatusFromError(t *testing.T) {
type test struct {
err error
pkt sshFxpStatusPacket
}
tpkt := func(id, code uint32) sshFxpStatusPacket {
return sshFxpStatusPacket{
ID: id,
StatusError: StatusError{Code: code},
}
}
test_cases := []test{
test{syscall.ENOENT, tpkt(1, ssh_FX_NO_SUCH_FILE)},
test{&os.PathError{Err: syscall.ENOENT},
tpkt(2, ssh_FX_NO_SUCH_FILE)},
test{&os.PathError{Err: errors.New("foo")}, tpkt(3, ssh_FX_FAILURE)},
test{ErrSshFxEof, tpkt(4, ssh_FX_EOF)},
test{ErrSshFxOpUnsupported, tpkt(5, ssh_FX_OP_UNSUPPORTED)},
test{io.EOF, tpkt(6, ssh_FX_EOF)},
test{os.ErrNotExist, tpkt(7, ssh_FX_NO_SUCH_FILE)},
}
for _, tc := range test_cases {
tc.pkt.StatusError.msg = tc.err.Error()
assert.Equal(t, tc.pkt, statusFromError(tc.pkt, tc.err))
}
}