2
2
mirror of https://github.com/octoleo/restic.git synced 2025-01-26 08:38:27 +00:00
restic/vendor/github.com/pkg/sftp/request-server.go

220 lines
5.5 KiB
Go
Raw Normal View History

2017-07-23 14:24:45 +02:00
package sftp
import (
"context"
2017-07-23 14:24:45 +02:00
"io"
"path"
2017-07-23 14:24:45 +02:00
"path/filepath"
"strconv"
"sync"
"syscall"
"github.com/pkg/errors"
)
// maxTxPacket is 32768 (1 << 15). It is not referenced in this part of the
// file; presumably it bounds the size of an outgoing packet payload — confirm
// against its users elsewhere in the package.
var maxTxPacket = uint32(1 << 15)
// Handlers contains the 4 SFTP server request handlers.
type Handlers struct {
FileGet FileReader
FilePut FileWriter
FileCmd FileCmder
2017-09-13 14:09:48 +02:00
FileList FileLister
2017-07-23 14:24:45 +02:00
}
// RequestServer abstracts the sftp protocol with an http request-like protocol
type RequestServer struct {
*serverConn
Handlers Handlers
2017-09-13 14:09:48 +02:00
pktMgr *packetManager
openRequests map[string]*Request
2017-07-23 14:24:45 +02:00
openRequestLock sync.RWMutex
handleCount int
}
// NewRequestServer creates/allocates/returns new RequestServer.
// Normally there there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer {
svrConn := &serverConn{
conn: conn{
Reader: rwc,
WriteCloser: rwc,
},
}
return &RequestServer{
serverConn: svrConn,
Handlers: h,
pktMgr: newPktMgr(svrConn),
openRequests: make(map[string]*Request),
2017-07-23 14:24:45 +02:00
}
}
// New Open packet/Request
2017-09-13 14:09:48 +02:00
func (rs *RequestServer) nextRequest(r *Request) string {
2017-07-23 14:24:45 +02:00
rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()
rs.handleCount++
handle := strconv.Itoa(rs.handleCount)
2019-01-27 21:07:57 +01:00
r.handle = handle
rs.openRequests[handle] = r
2017-07-23 14:24:45 +02:00
return handle
}
2019-01-27 21:07:57 +01:00
// Returns Request from openRequests, bool is false if it is missing.
//
// The Requests in openRequests work essentially as open file descriptors that
// you can do different things with. What you are doing with it are denoted by
2019-01-27 21:07:57 +01:00
// the first packet of that type (read/write/etc).
func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
2017-07-23 14:24:45 +02:00
rs.openRequestLock.RLock()
2019-01-27 21:07:57 +01:00
defer rs.openRequestLock.RUnlock()
2017-07-23 14:24:45 +02:00
r, ok := rs.openRequests[handle]
return r, ok
2017-07-23 14:24:45 +02:00
}
2019-01-27 21:07:57 +01:00
// Close the Request and clear from openRequests map
func (rs *RequestServer) closeRequest(handle string) error {
2017-07-23 14:24:45 +02:00
rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()
if r, ok := rs.openRequests[handle]; ok {
delete(rs.openRequests, handle)
return r.close()
2017-07-23 14:24:45 +02:00
}
return syscall.EBADF
2017-07-23 14:24:45 +02:00
}
// Close shuts the underlying read/write/closer, which makes the main server
// loop exit. The connection's close error is returned unchanged.
func (rs *RequestServer) Close() error {
	closeErr := rs.conn.Close()
	return closeErr
}
// Serve requests for user session
func (rs *RequestServer) Serve() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
2017-07-23 14:24:45 +02:00
var wg sync.WaitGroup
2018-09-03 20:23:56 +02:00
runWorker := func(ch chan orderedRequest) {
2017-07-23 14:24:45 +02:00
wg.Add(1)
go func() {
defer wg.Done()
if err := rs.packetWorker(ctx, ch); err != nil {
2017-07-23 14:24:45 +02:00
rs.conn.Close() // shuts down recvPacket
}
}()
}
pktChan := rs.pktMgr.workerChan(runWorker)
var err error
var pkt requestPacket
var pktType uint8
var pktBytes []byte
for {
pktType, pktBytes, err = rs.recvPacket()
if err != nil {
break
}
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
if err != nil {
switch errors.Cause(err) {
case errUnknownExtendedPacket:
if err := rs.serverConn.sendError(pkt, ErrSshFxOpUnsupported); err != nil {
debug("failed to send err packet: %v", err)
rs.conn.Close() // shuts down recvPacket
break
}
default:
debug("makePacket err: %v", err)
rs.conn.Close() // shuts down recvPacket
break
}
2017-07-23 14:24:45 +02:00
}
2018-09-03 20:23:56 +02:00
pktChan <- rs.pktMgr.newOrderedRequest(pkt)
2017-07-23 14:24:45 +02:00
}
close(pktChan) // shuts down sftpServerWorkers
wg.Wait() // wait for all workers to exit
// make sure all open requests are properly closed
// (eg. possible on dropped connections, client crashes, etc.)
for handle, req := range rs.openRequests {
delete(rs.openRequests, handle)
req.close()
}
2017-07-23 14:24:45 +02:00
return err
}
func (rs *RequestServer) packetWorker(
2018-09-03 20:23:56 +02:00
ctx context.Context, pktChan chan orderedRequest,
) error {
2017-07-23 14:24:45 +02:00
for pkt := range pktChan {
var rpkt responsePacket
2018-09-03 20:23:56 +02:00
switch pkt := pkt.requestPacket.(type) {
2017-07-23 14:24:45 +02:00
case *sshFxInitPacket:
2018-09-03 20:23:56 +02:00
rpkt = sshFxVersionPacket{Version: sftpProtocolVersion}
2017-07-23 14:24:45 +02:00
case *sshFxpClosePacket:
handle := pkt.getHandle()
rpkt = statusFromError(pkt, rs.closeRequest(handle))
2017-07-23 14:24:45 +02:00
case *sshFxpRealpathPacket:
2017-09-13 14:09:48 +02:00
rpkt = cleanPacketPath(pkt)
case *sshFxpOpendirPacket:
request := requestFromPacket(ctx, pkt)
2019-01-27 21:07:57 +01:00
rs.nextRequest(request)
rpkt = request.opendir(rs.Handlers, pkt)
case *sshFxpOpenPacket:
request := requestFromPacket(ctx, pkt)
2019-01-27 21:07:57 +01:00
rs.nextRequest(request)
rpkt = request.open(rs.Handlers, pkt)
case *sshFxpFstatPacket:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
request = NewRequest("Stat", request.Filepath)
rpkt = request.call(rs.Handlers, pkt)
2017-07-23 14:24:45 +02:00
}
case hasHandle:
handle := pkt.getHandle()
2019-01-27 21:07:57 +01:00
request, ok := rs.getRequest(handle)
if !ok {
2017-07-23 14:24:45 +02:00
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
2017-09-13 14:09:48 +02:00
rpkt = request.call(rs.Handlers, pkt)
2017-07-23 14:24:45 +02:00
}
case hasPath:
request := requestFromPacket(ctx, pkt)
2017-09-13 14:09:48 +02:00
rpkt = request.call(rs.Handlers, pkt)
request.close()
2017-07-23 14:24:45 +02:00
default:
return errors.Errorf("unexpected packet type %T", pkt)
}
2018-09-03 20:23:56 +02:00
rs.pktMgr.readyPacket(
rs.pktMgr.newOrderedResponse(rpkt, pkt.orderId()))
2017-07-23 14:24:45 +02:00
}
return nil
}
// clean and return name packet for file
2017-09-13 14:09:48 +02:00
func cleanPacketPath(pkt *sshFxpRealpathPacket) responsePacket {
path := cleanPath(pkt.getPath())
2017-07-23 14:24:45 +02:00
return &sshFxpNamePacket{
ID: pkt.id(),
NameAttrs: []sshFxpNameAttr{{
2017-09-13 14:09:48 +02:00
Name: path,
LongName: path,
2017-07-23 14:24:45 +02:00
Attrs: emptyFileStat,
}},
}
}
// cleanPath makes sure we have a clean POSIX (/) absolute path to work with.
//
// OS-specific separators in p are converted to "/". A "/" is prepended when
// the result does not already start with one, and the path is then
// normalized with path.Clean (removing ".", "..", duplicate and trailing
// slashes). The empty string yields "/".
func cleanPath(p string) string {
	p = filepath.ToSlash(p)
	// Use the slash-only path.IsAbs, not filepath.IsAbs: the latter follows
	// host OS rules (e.g. it treats "C:/x" as absolute on Windows), which
	// would let a path without a leading "/" through.
	if !path.IsAbs(p) {
		p = "/" + p
	}
	return path.Clean(p)
}