mirror of
https://github.com/octoleo/restic.git
synced 2024-11-18 02:55:18 +00:00
b9f0f031b6
Closes #2129
220 lines
5.5 KiB
Go
220 lines
5.5 KiB
Go
package sftp
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"path"
|
|
"path/filepath"
|
|
"strconv"
|
|
"sync"
|
|
"syscall"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// maxTxPacket is 1<<15 (32768). It is not referenced in this part of the
// file; presumably it bounds the data size of a single transfer packet —
// TODO confirm at its use sites.
var maxTxPacket uint32 = 1 << 15
|
|
|
|
// Handlers contains the 4 SFTP server request handlers.
//
// A RequestServer keeps its Handlers in RequestServer.Handlers and passes
// them to every Request it processes (open, opendir and call).
type Handlers struct {
	// FileGet is the FileReader handler (presumably used for reads — confirm
	// in request.go).
	FileGet FileReader
	// FilePut is the FileWriter handler (presumably used for writes).
	FilePut FileWriter
	// FileCmd is the FileCmder handler (presumably used for path commands
	// such as rename/remove — confirm in request.go).
	FileCmd FileCmder
	// FileList is the FileLister handler (presumably used for directory
	// listings and stat — confirm in request.go).
	FileList FileLister
}
|
|
|
|
// RequestServer abstracts the sftp protocol with an http request-like protocol.
//
// It receives packets through the embedded serverConn, fans them out to
// worker goroutines via pktMgr and keeps track of the handles clients have
// opened.
type RequestServer struct {
	*serverConn
	// Handlers are passed to every Request this server processes.
	Handlers Handlers
	// pktMgr hands incoming packets to the workers (workerChan) and
	// collects their responses (newOrderedResponse / readyPacket).
	pktMgr *packetManager
	// openRequests maps each handle string to the Request opened under it.
	// nextRequest, getRequest and closeRequest access it under
	// openRequestLock.
	openRequests    map[string]*Request
	openRequestLock sync.RWMutex
	// handleCount is the number of the last handle issued by nextRequest;
	// handles are its decimal string form.
	handleCount int
}
|
|
|
|
// NewRequestServer creates/allocates/returns new RequestServer.
|
|
// Normally there there will be one server per user-session.
|
|
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer {
|
|
svrConn := &serverConn{
|
|
conn: conn{
|
|
Reader: rwc,
|
|
WriteCloser: rwc,
|
|
},
|
|
}
|
|
return &RequestServer{
|
|
serverConn: svrConn,
|
|
Handlers: h,
|
|
pktMgr: newPktMgr(svrConn),
|
|
openRequests: make(map[string]*Request),
|
|
}
|
|
}
|
|
|
|
// New Open packet/Request
|
|
func (rs *RequestServer) nextRequest(r *Request) string {
|
|
rs.openRequestLock.Lock()
|
|
defer rs.openRequestLock.Unlock()
|
|
rs.handleCount++
|
|
handle := strconv.Itoa(rs.handleCount)
|
|
r.handle = handle
|
|
rs.openRequests[handle] = r
|
|
return handle
|
|
}
|
|
|
|
// Returns Request from openRequests, bool is false if it is missing.
|
|
//
|
|
// The Requests in openRequests work essentially as open file descriptors that
|
|
// you can do different things with. What you are doing with it are denoted by
|
|
// the first packet of that type (read/write/etc).
|
|
func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
|
|
rs.openRequestLock.RLock()
|
|
defer rs.openRequestLock.RUnlock()
|
|
r, ok := rs.openRequests[handle]
|
|
return r, ok
|
|
}
|
|
|
|
// Close the Request and clear from openRequests map
|
|
func (rs *RequestServer) closeRequest(handle string) error {
|
|
rs.openRequestLock.Lock()
|
|
defer rs.openRequestLock.Unlock()
|
|
if r, ok := rs.openRequests[handle]; ok {
|
|
delete(rs.openRequests, handle)
|
|
return r.close()
|
|
}
|
|
return syscall.EBADF
|
|
}
|
|
|
|
// Close the read/write/closer to trigger exiting the main server loop
|
|
func (rs *RequestServer) Close() error { return rs.conn.Close() }
|
|
|
|
// Serve requests for user session
|
|
func (rs *RequestServer) Serve() error {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
var wg sync.WaitGroup
|
|
runWorker := func(ch chan orderedRequest) {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
if err := rs.packetWorker(ctx, ch); err != nil {
|
|
rs.conn.Close() // shuts down recvPacket
|
|
}
|
|
}()
|
|
}
|
|
pktChan := rs.pktMgr.workerChan(runWorker)
|
|
|
|
var err error
|
|
var pkt requestPacket
|
|
var pktType uint8
|
|
var pktBytes []byte
|
|
for {
|
|
pktType, pktBytes, err = rs.recvPacket()
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
|
|
if err != nil {
|
|
switch errors.Cause(err) {
|
|
case errUnknownExtendedPacket:
|
|
if err := rs.serverConn.sendError(pkt, ErrSshFxOpUnsupported); err != nil {
|
|
debug("failed to send err packet: %v", err)
|
|
rs.conn.Close() // shuts down recvPacket
|
|
break
|
|
}
|
|
default:
|
|
debug("makePacket err: %v", err)
|
|
rs.conn.Close() // shuts down recvPacket
|
|
break
|
|
}
|
|
}
|
|
|
|
pktChan <- rs.pktMgr.newOrderedRequest(pkt)
|
|
}
|
|
|
|
close(pktChan) // shuts down sftpServerWorkers
|
|
wg.Wait() // wait for all workers to exit
|
|
|
|
// make sure all open requests are properly closed
|
|
// (eg. possible on dropped connections, client crashes, etc.)
|
|
for handle, req := range rs.openRequests {
|
|
delete(rs.openRequests, handle)
|
|
req.close()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (rs *RequestServer) packetWorker(
|
|
ctx context.Context, pktChan chan orderedRequest,
|
|
) error {
|
|
for pkt := range pktChan {
|
|
var rpkt responsePacket
|
|
switch pkt := pkt.requestPacket.(type) {
|
|
case *sshFxInitPacket:
|
|
rpkt = sshFxVersionPacket{Version: sftpProtocolVersion}
|
|
case *sshFxpClosePacket:
|
|
handle := pkt.getHandle()
|
|
rpkt = statusFromError(pkt, rs.closeRequest(handle))
|
|
case *sshFxpRealpathPacket:
|
|
rpkt = cleanPacketPath(pkt)
|
|
case *sshFxpOpendirPacket:
|
|
request := requestFromPacket(ctx, pkt)
|
|
rs.nextRequest(request)
|
|
rpkt = request.opendir(rs.Handlers, pkt)
|
|
case *sshFxpOpenPacket:
|
|
request := requestFromPacket(ctx, pkt)
|
|
rs.nextRequest(request)
|
|
rpkt = request.open(rs.Handlers, pkt)
|
|
case *sshFxpFstatPacket:
|
|
handle := pkt.getHandle()
|
|
request, ok := rs.getRequest(handle)
|
|
if !ok {
|
|
rpkt = statusFromError(pkt, syscall.EBADF)
|
|
} else {
|
|
request = NewRequest("Stat", request.Filepath)
|
|
rpkt = request.call(rs.Handlers, pkt)
|
|
}
|
|
case hasHandle:
|
|
handle := pkt.getHandle()
|
|
request, ok := rs.getRequest(handle)
|
|
if !ok {
|
|
rpkt = statusFromError(pkt, syscall.EBADF)
|
|
} else {
|
|
rpkt = request.call(rs.Handlers, pkt)
|
|
}
|
|
case hasPath:
|
|
request := requestFromPacket(ctx, pkt)
|
|
rpkt = request.call(rs.Handlers, pkt)
|
|
request.close()
|
|
default:
|
|
return errors.Errorf("unexpected packet type %T", pkt)
|
|
}
|
|
|
|
rs.pktMgr.readyPacket(
|
|
rs.pktMgr.newOrderedResponse(rpkt, pkt.orderId()))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// clean and return name packet for file
|
|
func cleanPacketPath(pkt *sshFxpRealpathPacket) responsePacket {
|
|
path := cleanPath(pkt.getPath())
|
|
return &sshFxpNamePacket{
|
|
ID: pkt.id(),
|
|
NameAttrs: []sshFxpNameAttr{{
|
|
Name: path,
|
|
LongName: path,
|
|
Attrs: emptyFileStat,
|
|
}},
|
|
}
|
|
}
|
|
|
|
// cleanPath makes sure we have a clean POSIX (/) absolute path to work with.
//
// OS separators are converted to '/', a leading '/' is added when missing,
// and the result is cleaned with path.Clean (so "" becomes "/").
// Absoluteness is tested with path.IsAbs rather than filepath.IsAbs. The
// value is already slash-separated at that point, and filepath.IsAbs would
// apply OS rules: on Windows it treats "C:/x" as absolute and leaves it
// without a leading '/'.
func cleanPath(p string) string {
	p = filepath.ToSlash(p)
	if !path.IsAbs(p) {
		p = "/" + p
	}
	return path.Clean(p)
}
|