2
2
mirror of https://github.com/octoleo/restic.git synced 2024-12-13 06:37:54 +00:00
restic/vendor/github.com/pkg/sftp/request-server.go
2017-07-23 14:25:38 +02:00

232 lines
5.4 KiB
Go

package sftp
import (
"encoding"
"io"
"os"
"path/filepath"
"strconv"
"sync"
"syscall"
"github.com/pkg/errors"
)
var maxTxPacket uint32 = 1 << 15
type handleHandler func(string) string
// Handlers contains the 4 SFTP server request handlers.
type Handlers struct {
FileGet FileReader
FilePut FileWriter
FileCmd FileCmder
FileInfo FileInfoer
}
// RequestServer abstracts the sftp protocol with an http request-like protocol
type RequestServer struct {
*serverConn
Handlers Handlers
pktMgr packetManager
openRequests map[string]Request
openRequestLock sync.RWMutex
handleCount int
}
// NewRequestServer creates/allocates/returns new RequestServer.
// Normally there there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer {
svrConn := &serverConn{
conn: conn{
Reader: rwc,
WriteCloser: rwc,
},
}
return &RequestServer{
serverConn: svrConn,
Handlers: h,
pktMgr: newPktMgr(svrConn),
openRequests: make(map[string]Request),
}
}
func (rs *RequestServer) nextRequest(r Request) string {
rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()
rs.handleCount++
handle := strconv.Itoa(rs.handleCount)
rs.openRequests[handle] = r
return handle
}
func (rs *RequestServer) getRequest(handle string) (Request, bool) {
rs.openRequestLock.RLock()
defer rs.openRequestLock.RUnlock()
r, ok := rs.openRequests[handle]
return r, ok
}
func (rs *RequestServer) closeRequest(handle string) {
rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()
if r, ok := rs.openRequests[handle]; ok {
r.close()
delete(rs.openRequests, handle)
}
}
// Close the read/write/closer to trigger exiting the main server loop
func (rs *RequestServer) Close() error { return rs.conn.Close() }
// Serve requests for user session
func (rs *RequestServer) Serve() error {
var wg sync.WaitGroup
runWorker := func(ch requestChan) {
wg.Add(1)
go func() {
defer wg.Done()
if err := rs.packetWorker(ch); err != nil {
rs.conn.Close() // shuts down recvPacket
}
}()
}
pktChan := rs.pktMgr.workerChan(runWorker)
var err error
var pkt requestPacket
var pktType uint8
var pktBytes []byte
for {
pktType, pktBytes, err = rs.recvPacket()
if err != nil {
break
}
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
if err != nil {
debug("makePacket err: %v", err)
rs.conn.Close() // shuts down recvPacket
break
}
pktChan <- pkt
}
close(pktChan) // shuts down sftpServerWorkers
wg.Wait() // wait for all workers to exit
return err
}
func (rs *RequestServer) packetWorker(pktChan chan requestPacket) error {
for pkt := range pktChan {
var rpkt responsePacket
switch pkt := pkt.(type) {
case *sshFxInitPacket:
rpkt = sshFxVersionPacket{sftpProtocolVersion, nil}
case *sshFxpClosePacket:
handle := pkt.getHandle()
rs.closeRequest(handle)
rpkt = statusFromError(pkt, nil)
case *sshFxpRealpathPacket:
rpkt = cleanPath(pkt)
case isOpener:
handle := rs.nextRequest(requestFromPacket(pkt))
rpkt = sshFxpHandlePacket{pkt.id(), handle}
case *sshFxpFstatPacket:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
request = requestFromPacket(
&sshFxpStatPacket{ID: pkt.id(), Path: request.Filepath})
rpkt = rs.handle(request, pkt)
}
case *sshFxpFsetstatPacket:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
request = requestFromPacket(
&sshFxpSetstatPacket{ID: pkt.id(), Path: request.Filepath,
Flags: pkt.Flags, Attrs: pkt.Attrs,
})
rpkt = rs.handle(request, pkt)
}
case hasHandle:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
request.update(pkt)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
rpkt = rs.handle(request, pkt)
}
case hasPath:
request := requestFromPacket(pkt)
rpkt = rs.handle(request, pkt)
default:
return errors.Errorf("unexpected packet type %T", pkt)
}
err := rs.sendPacket(rpkt)
if err != nil {
return err
}
}
return nil
}
func cleanPath(pkt *sshFxpRealpathPacket) responsePacket {
path := pkt.getPath()
if !filepath.IsAbs(path) {
path = "/" + path
} // all paths are absolute
cleaned_path := filepath.Clean(path)
return &sshFxpNamePacket{
ID: pkt.id(),
NameAttrs: []sshFxpNameAttr{{
Name: cleaned_path,
LongName: cleaned_path,
Attrs: emptyFileStat,
}},
}
}
func (rs *RequestServer) handle(request Request, pkt requestPacket) responsePacket {
// fmt.Println("Request Method: ", request.Method)
rpkt, err := request.handle(rs.Handlers)
if err != nil {
err = errorAdapter(err)
rpkt = statusFromError(pkt, err)
}
return rpkt
}
// Wrap underlying connection methods to use packetManager
func (rs *RequestServer) sendPacket(m encoding.BinaryMarshaler) error {
if pkt, ok := m.(responsePacket); ok {
rs.pktMgr.readyPacket(pkt)
} else {
return errors.Errorf("unexpected packet type %T", m)
}
return nil
}
func (rs *RequestServer) sendError(p ider, err error) error {
return rs.sendPacket(statusFromError(p, err))
}
// os.ErrNotExist should convert to ssh_FX_NO_SUCH_FILE, but is not recognized
// by statusFromError. So we convert to syscall.ENOENT which it does.
func errorAdapter(err error) error {
if err == os.ErrNotExist {
return syscall.ENOENT
}
return err
}