syncthing/cmd/strelaysrv/listener.go

347 lines
8.2 KiB
Go
Raw Normal View History

2016-06-02 14:16:02 +02:00
// Copyright (C) 2015 Audrius Butkevicius and Contributors.
2015-06-24 12:39:46 +01:00
package main
import (
"crypto/tls"
"encoding/hex"
2015-06-24 12:39:46 +01:00
"log"
"net"
2015-06-28 01:52:01 +01:00
"sync"
2015-08-20 14:02:52 +02:00
"sync/atomic"
2015-06-24 12:39:46 +01:00
"time"
syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
2015-09-02 21:35:52 +01:00
"github.com/syncthing/syncthing/lib/tlsutil"
2015-06-24 12:39:46 +01:00
"github.com/syncthing/syncthing/lib/relay/protocol"
2015-06-24 12:39:46 +01:00
)
2015-06-28 01:52:01 +01:00
var (
2015-08-20 14:02:52 +02:00
outboxesMut = sync.RWMutex{}
outboxes = make(map[syncthingprotocol.DeviceID]chan interface{})
numConnections int64
2015-06-28 01:52:01 +01:00
)
2015-06-24 12:39:46 +01:00
func listener(proto, addr string, config *tls.Config) {
2015-09-02 21:35:52 +01:00
tcpListener, err := net.Listen("tcp", addr)
2015-06-24 12:39:46 +01:00
if err != nil {
log.Fatalln(err)
}
listener := tlsutil.DowngradingListener{
Listener: tcpListener,
}
2015-09-02 21:35:52 +01:00
2015-06-24 12:39:46 +01:00
for {
2015-09-02 22:02:17 +01:00
conn, isTLS, err := listener.AcceptNoWrapTLS()
2015-06-24 12:39:46 +01:00
if err != nil {
if debug {
log.Println("Listener failed to accept connection from", conn.RemoteAddr(), ". Possibly a TCP Ping.")
2015-06-24 12:39:46 +01:00
}
continue
}
2015-07-20 11:38:00 +02:00
setTCPOptions(conn)
2015-06-24 12:39:46 +01:00
if debug {
2015-09-02 21:35:52 +01:00
log.Println("Listener accepted connection from", conn.RemoteAddr(), "tls", isTLS)
}
if isTLS {
go protocolConnectionHandler(conn, config)
} else {
go sessionConnectionHandler(conn)
2015-06-24 12:39:46 +01:00
}
}
}
// protocolConnectionHandler serves one relay-protocol (TLS) connection.
//
// It completes the TLS handshake on tcpConn using config and requires
// exactly one peer certificate, from which the peer's device ID is derived.
// It then loops, multiplexing:
//   - messages read from the peer (join, connect, ping/pong, anything else
//     is rejected);
//   - a ping ticker, which also drops peers that never joined and, while the
//     relay is over its limits, peers that hold no sessions;
//   - an idle timer reset on every received message;
//   - messages other handlers queue on this peer's outbox.
//
// Error handling follows one pattern throughout: close the connection and
// wait for messageReader to report the resulting read error. Only that error
// path returns from the function, and it also unregisters the outbox if the
// peer had joined.
func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
	conn := tls.Server(tcpConn, config)

	// Bound the handshake so a client that connects and then stalls cannot
	// hold this goroutine and its socket indefinitely. networkTimeout is the
	// same idle limit applied once the connection is established.
	if err := conn.SetDeadline(time.Now().Add(networkTimeout)); err != nil {
		if debug {
			log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
		}
		conn.Close()
		return
	}
	if err := conn.Handshake(); err != nil {
		if debug {
			log.Println("Protocol connection TLS handshake:", conn.RemoteAddr(), err)
		}
		conn.Close()
		return
	}
	// Idle detection after the handshake is done by idleTimer below, not by
	// socket deadlines.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		if debug {
			log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
		}
		conn.Close()
		return
	}

	state := conn.ConnectionState()
	if (!state.NegotiatedProtocolIsMutual || state.NegotiatedProtocol != protocol.ProtocolName) && debug {
		log.Println("Protocol negotiation error")
	}

	certs := state.PeerCertificates
	if len(certs) != 1 {
		if debug {
			log.Println("Certificate list error")
		}
		conn.Close()
		return
	}

	id := syncthingprotocol.NewDeviceID(certs[0].Raw)

	messages := make(chan interface{})
	errors := make(chan error, 1)
	outbox := make(chan interface{})

	// Read messages from the connection and send them on the messages
	// channel. When there is an error, send it on the error channel and
	// return. Applies also when the connection gets closed, so the pattern
	// below is to close the connection on error, then wait for the error
	// signal from messageReader to exit.
	go messageReader(conn, messages, errors)

	pingTicker := time.NewTicker(pingInterval)
	defer pingTicker.Stop()
	idleTimer := time.NewTimer(networkTimeout)
	defer idleTimer.Stop()

	joined := false

	for {
		select {
		case message := <-messages:
			idleTimer.Reset(networkTimeout)
			if debug {
				log.Printf("Message %T from %s", message, id)
			}

			switch msg := message.(type) {
			case protocol.JoinRelayRequest:
				if atomic.LoadInt32(&overLimit) > 0 {
					protocol.WriteMessage(conn, protocol.RelayFull{})
					if debug {
						log.Println("Refusing join request from", id, "due to being over limits")
					}
					conn.Close()
					limitCheckTimer.Reset(time.Second)
					continue
				}

				// Check and register under one write lock. Otherwise two
				// connections with the same device ID could both pass the
				// check, the second would overwrite the first's outbox, and
				// whichever disconnected first would delete the survivor's
				// registration.
				outboxesMut.Lock()
				_, exists := outboxes[id]
				if !exists {
					outboxes[id] = outbox
				}
				outboxesMut.Unlock()

				if exists {
					protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
					if debug {
						log.Println("Already have a peer with the same ID", id, conn.RemoteAddr())
					}
					conn.Close()
					continue
				}

				joined = true
				protocol.WriteMessage(conn, protocol.ResponseSuccess)

			case protocol.ConnectRequest:
				// msg.ID comes straight off the wire, so validate its length
				// before converting it. NOTE(review): DeviceIDFromBytes in this
				// version of lib/protocol is believed to panic on a
				// wrong-length slice, which would let any client crash the
				// relay. Confirm.
				var requestedPeer syncthingprotocol.DeviceID
				if len(msg.ID) != len(requestedPeer) {
					if debug {
						log.Println(id, "is looking for an invalid peer ID")
					}
					protocol.WriteMessage(conn, protocol.ResponseNotFound)
					conn.Close()
					continue
				}
				requestedPeer = syncthingprotocol.DeviceIDFromBytes(msg.ID)

				outboxesMut.RLock()
				peerOutbox, ok := outboxes[requestedPeer]
				outboxesMut.RUnlock()
				if !ok {
					if debug {
						log.Println(id, "is looking for", requestedPeer, "which does not exist")
					}
					protocol.WriteMessage(conn, protocol.ResponseNotFound)
					conn.Close()
					continue
				}

				// requestedPeer is the server, id is the client
				ses := newSession(requestedPeer, id, sessionLimiter, globalLimiter)

				go ses.Serve()

				clientInvitation := ses.GetClientInvitationMessage()
				serverInvitation := ses.GetServerInvitationMessage()

				if err := protocol.WriteMessage(conn, clientInvitation); err != nil {
					if debug {
						log.Printf("Error sending invitation from %s to client: %s", id, err)
					}
					conn.Close()
					continue
				}

				// Hand the server invitation to the peer's handler, but do
				// not block forever if that handler has gone away.
				select {
				case peerOutbox <- serverInvitation:
					if debug {
						log.Println("Sent invitation from", id, "to", requestedPeer)
					}
				case <-time.After(time.Second):
					if debug {
						log.Println("Could not send invitation from", id, "to", requestedPeer, "as peer disconnected")
					}
				}
				conn.Close()

			case protocol.Ping:
				if err := protocol.WriteMessage(conn, protocol.Pong{}); err != nil {
					if debug {
						log.Println("Error writing pong:", err)
					}
					conn.Close()
					continue
				}

			case protocol.Pong:
				// Nothing

			default:
				if debug {
					log.Printf("Unknown message %s: %T", id, message)
				}
				protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
				conn.Close()
			}

		case err := <-errors:
			if debug {
				log.Printf("Closing connection %s: %s", id, err)
			}

			// Potentially closing a second time.
			conn.Close()

			if joined {
				// Only delete the outbox if the client is joined, as it might be
				// a lookup request coming from the same client.
				outboxesMut.Lock()
				delete(outboxes, id)
				outboxesMut.Unlock()
				// Also, kill all sessions related to this node, as it probably
				// went offline. This is for the other end to realize the client
				// is no longer there faster. This also helps resolve
				// 'already connected' errors when one of the sides is
				// restarting, and connecting to the other peer before the other
				// peer even realised that the node has gone away.
				dropSessions(id)
			}
			return

		case <-pingTicker.C:
			if !joined {
				if debug {
					log.Println(id, "didn't join within", pingInterval)
				}
				conn.Close()
				continue
			}

			if err := protocol.WriteMessage(conn, protocol.Ping{}); err != nil {
				if debug {
					log.Println(id, err)
				}
				conn.Close()
			}

			if atomic.LoadInt32(&overLimit) > 0 && !hasSessions(id) {
				if debug {
					log.Println("Dropping", id, "as it has no sessions and we are over our limits")
				}
				protocol.WriteMessage(conn, protocol.RelayFull{})
				conn.Close()
				limitCheckTimer.Reset(time.Second)
			}

		case <-idleTimer.C:
			// We should receive a error from the reader loop, which will cause
			// us to quit this loop.
			if debug {
				log.Printf("%s timed out", id)
			}
			conn.Close()

		case msg := <-outbox:
			if debug {
				log.Printf("Sending message %T to %s", msg, id)
			}
			if err := protocol.WriteMessage(conn, msg); err != nil {
				if debug {
					log.Println(id, err)
				}
				conn.Close()
			}
		}
	}
}
2015-07-20 11:38:00 +02:00
2015-09-02 21:35:52 +01:00
// sessionConnectionHandler serves one plain (non-TLS) connection that is
// expected to join an existing relay session.
//
// It reads a single message under a messageTimeout deadline. A
// JoinSessionRequest whose key matches a session found by findSession is
// added to that session with ses.AddConnection. After that it is answered
// with ResponseSuccess and has its deadline cleared. Unknown keys,
// duplicate joins and unexpected messages are answered with an error
// response. Every path that does not hand conn to a session closes it.
func sessionConnectionHandler(conn net.Conn) {
	if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil {
		if debug {
			log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
		}
		conn.Close()
		return
	}

	message, err := protocol.ReadMessage(conn)
	if err != nil {
		// Not handed to any session, so release the socket here rather
		// than leaking it.
		conn.Close()
		return
	}

	switch msg := message.(type) {
	case protocol.JoinSessionRequest:
		ses := findSession(string(msg.Key))
		if debug {
			// The key is client-supplied and may encode to fewer than the
			// five hex characters logged here; slicing it blindly would panic.
			keyHex := hex.EncodeToString(msg.Key)
			if len(keyHex) > 5 {
				keyHex = keyHex[:5]
			}
			log.Println(conn.RemoteAddr(), "session lookup", ses, keyHex)
		}

		if ses == nil {
			protocol.WriteMessage(conn, protocol.ResponseNotFound)
			conn.Close()
			return
		}

		if !ses.AddConnection(conn) {
			if debug {
				log.Println("Failed to add", conn.RemoteAddr(), "to session", ses)
			}
			protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
			conn.Close()
			return
		}

		// conn has been added to the session at this point. As in the
		// original, a failed success-response write does not close it here.
		if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil {
			if debug {
				log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses)
			}
			return
		}

		if err := conn.SetDeadline(time.Time{}); err != nil {
			if debug {
				log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
			}
			conn.Close()
			return
		}

	default:
		if debug {
			log.Println("Unexpected message from", conn.RemoteAddr(), message)
		}
		protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
		conn.Close()
	}
}
2015-07-20 11:38:00 +02:00
func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) {
2015-08-20 14:02:52 +02:00
atomic.AddInt64(&numConnections, 1)
defer atomic.AddInt64(&numConnections, -1)
2015-07-20 11:38:00 +02:00
for {
msg, err := protocol.ReadMessage(conn)
if err != nil {
errors <- err
return
}
messages <- msg
}
}