Merge pull request #1 from syncthing/review

Code review
This commit is contained in:
Audrius Butkevicius 2015-07-20 18:43:31 +01:00
commit 77457e91e9
3 changed files with 195 additions and 163 deletions

View File

@ -14,27 +14,6 @@ import (
"github.com/syncthing/relaysrv/protocol" "github.com/syncthing/relaysrv/protocol"
) )
func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) ProtocolClient {
closeInvitationsOnFinish := false
if invitations == nil {
closeInvitationsOnFinish = true
invitations = make(chan protocol.SessionInvitation)
}
return ProtocolClient{
URI: uri,
Invitations: invitations,
closeInvitationsOnFinish: closeInvitationsOnFinish,
config: configForCerts(certs),
timeout: time.Minute * 2,
stop: make(chan struct{}),
stopped: make(chan struct{}),
}
}
type ProtocolClient struct { type ProtocolClient struct {
URI *url.URL URI *url.URL
Invitations chan protocol.SessionInvitation Invitations chan protocol.SessionInvitation
@ -51,6 +30,129 @@ type ProtocolClient struct {
conn *tls.Conn conn *tls.Conn
} }
func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) *ProtocolClient {
closeInvitationsOnFinish := false
if invitations == nil {
closeInvitationsOnFinish = true
invitations = make(chan protocol.SessionInvitation)
}
return &ProtocolClient{
URI: uri,
Invitations: invitations,
closeInvitationsOnFinish: closeInvitationsOnFinish,
config: configForCerts(certs),
timeout: time.Minute * 2,
stop: make(chan struct{}),
stopped: make(chan struct{}),
}
}
func (c *ProtocolClient) Serve() {
c.stop = make(chan struct{})
c.stopped = make(chan struct{})
defer close(c.stopped)
if err := c.connect(); err != nil {
l.Infoln("Relay connect:", err)
return
}
if debug {
l.Debugln(c, "connected", c.conn.RemoteAddr())
}
if err := c.join(); err != nil {
c.conn.Close()
l.Infoln("Relay join:", err)
return
}
if err := c.conn.SetDeadline(time.Time{}); err != nil {
l.Infoln("Relay set deadline:", err)
return
}
if debug {
l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr())
}
defer c.cleanup()
messages := make(chan interface{})
errors := make(chan error, 1)
go messageReader(c.conn, messages, errors)
timeout := time.NewTimer(c.timeout)
for {
select {
case message := <-messages:
timeout.Reset(c.timeout)
if debug {
log.Printf("%s received message %T", c, message)
}
switch msg := message.(type) {
case protocol.Ping:
if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil {
l.Infoln("Relay write:", err)
return
}
if debug {
l.Debugln(c, "sent pong")
}
case protocol.SessionInvitation:
ip := net.IP(msg.Address)
if len(ip) == 0 || ip.IsUnspecified() {
msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:]
}
c.Invitations <- msg
default:
l.Infoln("Relay: protocol error: unexpected message %v", msg)
return
}
case <-c.stop:
if debug {
l.Debugln(c, "stopping")
}
return
case err := <-errors:
l.Infoln("Relay received:", err)
return
case <-timeout.C:
if debug {
l.Debugln(c, "timed out")
}
return
}
}
}
func (c *ProtocolClient) Stop() {
if c.stop == nil {
return
}
close(c.stop)
<-c.stopped
}
func (c *ProtocolClient) String() string {
return fmt.Sprintf("ProtocolClient@%p", c)
}
func (c *ProtocolClient) connect() error { func (c *ProtocolClient) connect() error {
if c.URI.Scheme != "relay" { if c.URI.Scheme != "relay" {
return fmt.Errorf("Unsupported relay schema:", c.URI.Scheme) return fmt.Errorf("Unsupported relay schema:", c.URI.Scheme)
@ -61,9 +163,13 @@ func (c *ProtocolClient) connect() error {
return err return err
} }
conn.SetDeadline(time.Now().Add(10 * time.Second)) if err := conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
conn.Close()
return err
}
if err := performHandshakeAndValidation(conn, c.URI); err != nil { if err := performHandshakeAndValidation(conn, c.URI); err != nil {
conn.Close()
return err return err
} }
@ -71,101 +177,6 @@ func (c *ProtocolClient) connect() error {
return nil return nil
} }
func (c *ProtocolClient) Serve() {
if err := c.connect(); err != nil {
panic(err)
}
if debug {
l.Debugln(c, "connected", c.conn.RemoteAddr())
}
if err := c.join(); err != nil {
c.conn.Close()
panic(err)
}
c.conn.SetDeadline(time.Time{})
if debug {
l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr())
}
c.stop = make(chan struct{})
c.stopped = make(chan struct{})
defer c.cleanup()
messages := make(chan interface{})
errors := make(chan error, 1)
go func(conn net.Conn, message chan<- interface{}, errors chan<- error) {
for {
msg, err := protocol.ReadMessage(conn)
if err != nil {
errors <- err
return
}
messages <- msg
}
}(c.conn, messages, errors)
timeout := time.NewTimer(c.timeout)
for {
select {
case message := <-messages:
timeout.Reset(c.timeout)
if debug {
log.Printf("%s received message %T", c, message)
}
switch msg := message.(type) {
case protocol.Ping:
if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil {
panic(err)
}
if debug {
l.Debugln(c, "sent pong")
}
case protocol.SessionInvitation:
ip := net.IP(msg.Address)
if len(ip) == 0 || ip.IsUnspecified() {
msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:]
}
c.Invitations <- msg
default:
panic(fmt.Errorf("protocol error: unexpected message %v", msg))
}
case <-c.stop:
if debug {
l.Debugln(c, "stopping")
}
break
case err := <-errors:
panic(err)
case <-timeout.C:
if debug {
l.Debugln(c, "timed out")
}
return
}
}
c.stopped <- struct{}{}
}
func (c *ProtocolClient) Stop() {
if c.stop == nil {
return
}
c.stop <- struct{}{}
<-c.stopped
}
func (c *ProtocolClient) String() string {
return fmt.Sprintf("ProtocolClient@%p", c)
}
func (c *ProtocolClient) cleanup() { func (c *ProtocolClient) cleanup() {
if c.closeInvitationsOnFinish { if c.closeInvitationsOnFinish {
close(c.Invitations) close(c.Invitations)
@ -176,24 +187,11 @@ func (c *ProtocolClient) cleanup() {
l.Debugln(c, "cleaning up") l.Debugln(c, "cleaning up")
} }
if c.stop != nil { c.conn.Close()
close(c.stop)
c.stop = nil
}
if c.stopped != nil {
close(c.stopped)
c.stopped = nil
}
if c.conn != nil {
c.conn.Close()
}
} }
func (c *ProtocolClient) join() error { func (c *ProtocolClient) join() error {
err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}) if err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}); err != nil {
if err != nil {
return err return err
} }
@ -207,6 +205,7 @@ func (c *ProtocolClient) join() error {
if msg.Code != 0 { if msg.Code != 0 {
return fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message) return fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message)
} }
default: default:
return fmt.Errorf("protocol error: expecting response got %v", msg) return fmt.Errorf("protocol error: expecting response got %v", msg)
} }
@ -215,15 +214,12 @@ func (c *ProtocolClient) join() error {
} }
func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error { func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error {
err := conn.Handshake() if err := conn.Handshake(); err != nil {
if err != nil {
conn.Close()
return err return err
} }
cs := conn.ConnectionState() cs := conn.ConnectionState()
if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != protocol.ProtocolName { if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != protocol.ProtocolName {
conn.Close()
return fmt.Errorf("protocol negotiation error") return fmt.Errorf("protocol negotiation error")
} }
@ -232,22 +228,30 @@ func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error {
if relayIDs != "" { if relayIDs != "" {
relayID, err := syncthingprotocol.DeviceIDFromString(relayIDs) relayID, err := syncthingprotocol.DeviceIDFromString(relayIDs)
if err != nil { if err != nil {
conn.Close()
return fmt.Errorf("relay address contains invalid verification id: %s", err) return fmt.Errorf("relay address contains invalid verification id: %s", err)
} }
certs := cs.PeerCertificates certs := cs.PeerCertificates
if cl := len(certs); cl != 1 { if cl := len(certs); cl != 1 {
conn.Close()
return fmt.Errorf("unexpected certificate count: %d", cl) return fmt.Errorf("unexpected certificate count: %d", cl)
} }
remoteID := syncthingprotocol.NewDeviceID(certs[0].Raw) remoteID := syncthingprotocol.NewDeviceID(certs[0].Raw)
if remoteID != relayID { if remoteID != relayID {
conn.Close()
return fmt.Errorf("relay id does not match. Expected %v got %v", relayID, remoteID) return fmt.Errorf("relay id does not match. Expected %v got %v", relayID, remoteID)
} }
} }
return nil return nil
} }
func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) {
for {
msg, err := protocol.ReadMessage(conn)
if err != nil {
errors <- err
return
}
messages <- msg
}
}

View File

@ -27,7 +27,6 @@ func protocolListener(addr string, config *tls.Config) {
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
setTCPOptions(conn)
if err != nil { if err != nil {
if debug { if debug {
log.Println(err) log.Println(err)
@ -35,6 +34,8 @@ func protocolListener(addr string, config *tls.Config) {
continue continue
} }
setTCPOptions(conn)
if debug { if debug {
log.Println("Protocol listener accepted connection from", conn.RemoteAddr()) log.Println("Protocol listener accepted connection from", conn.RemoteAddr())
} }
@ -74,16 +75,12 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
errors := make(chan error, 1) errors := make(chan error, 1)
outbox := make(chan interface{}) outbox := make(chan interface{})
go func(conn net.Conn, message chan<- interface{}, errors chan<- error) { // Read messages from the connection and send them on the messages
for { // channel. When there is an error, send it on the error channel and
msg, err := protocol.ReadMessage(conn) // return. Applies also when the connection gets closed, so the pattern
if err != nil { // below is to close the connection on error, then wait for the error
errors <- err // signal from messageReader to exit.
return go messageReader(conn, messages, errors)
}
messages <- msg
}
}(conn, messages, errors)
pingTicker := time.NewTicker(pingInterval) pingTicker := time.NewTicker(pingInterval)
timeoutTicker := time.NewTimer(networkTimeout) timeoutTicker := time.NewTimer(networkTimeout)
@ -96,6 +93,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
if debug { if debug {
log.Printf("Message %T from %s", message, id) log.Printf("Message %T from %s", message, id)
} }
switch msg := message.(type) { switch msg := message.(type) {
case protocol.JoinRelayRequest: case protocol.JoinRelayRequest:
outboxesMut.RLock() outboxesMut.RLock()
@ -116,6 +114,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
joined = true joined = true
protocol.WriteMessage(conn, protocol.ResponseSuccess) protocol.WriteMessage(conn, protocol.ResponseSuccess)
case protocol.ConnectRequest: case protocol.ConnectRequest:
requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID) requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID)
outboxesMut.RLock() outboxesMut.RLock()
@ -151,7 +150,10 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
log.Println("Sent invitation from", id, "to", requestedPeer) log.Println("Sent invitation from", id, "to", requestedPeer)
} }
conn.Close() conn.Close()
case protocol.Pong: case protocol.Pong:
// Nothing
default: default:
if debug { if debug {
log.Printf("Unknown message %s: %T", id, message) log.Printf("Unknown message %s: %T", id, message)
@ -159,21 +161,25 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
conn.Close() conn.Close()
} }
case err := <-errors: case err := <-errors:
if debug { if debug {
log.Printf("Closing connection %s: %s", id, err) log.Printf("Closing connection %s: %s", id, err)
} }
// Potentially closing a second time.
close(outbox) close(outbox)
// Potentially closing a second time.
conn.Close() conn.Close()
// Only delete the outbox if the client join, as it migth be a
// lookup request coming from the same client. // Only delete the outbox if the client is joined, as it might be
// a lookup request coming from the same client.
if joined { if joined {
outboxesMut.Lock() outboxesMut.Lock()
delete(outboxes, id) delete(outboxes, id)
outboxesMut.Unlock() outboxesMut.Unlock()
} }
return return
case <-pingTicker.C: case <-pingTicker.C:
if !joined { if !joined {
if debug { if debug {
@ -189,6 +195,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
} }
conn.Close() conn.Close()
} }
case <-timeoutTicker.C: case <-timeoutTicker.C:
// We should receive a error from the reader loop, which will cause // We should receive a error from the reader loop, which will cause
// us to quit this loop. // us to quit this loop.
@ -196,6 +203,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
log.Printf("%s timed out", id) log.Printf("%s timed out", id)
} }
conn.Close() conn.Close()
case msg := <-outbox: case msg := <-outbox:
if debug { if debug {
log.Printf("Sending message %T to %s", msg, id) log.Printf("Sending message %T to %s", msg, id)
@ -209,3 +217,14 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
} }
} }
} }
func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) {
for {
msg, err := protocol.ReadMessage(conn)
if err != nil {
errors <- err
return
}
messages <- msg
}
}

View File

@ -18,7 +18,6 @@ func sessionListener(addr string) {
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
setTCPOptions(conn)
if err != nil { if err != nil {
if debug { if debug {
log.Println(err) log.Println(err)
@ -26,6 +25,8 @@ func sessionListener(addr string) {
continue continue
} }
setTCPOptions(conn)
if debug { if debug {
log.Println("Session listener accepted connection from", conn.RemoteAddr()) log.Println("Session listener accepted connection from", conn.RemoteAddr())
} }
@ -35,10 +36,17 @@ func sessionListener(addr string) {
} }
func sessionConnectionHandler(conn net.Conn) { func sessionConnectionHandler(conn net.Conn) {
conn.SetDeadline(time.Now().Add(messageTimeout)) defer conn.Close()
if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil {
if debug {
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
}
return
}
message, err := protocol.ReadMessage(conn) message, err := protocol.ReadMessage(conn)
if err != nil { if err != nil {
conn.Close()
return return
} }
@ -51,7 +59,6 @@ func sessionConnectionHandler(conn net.Conn) {
if ses == nil { if ses == nil {
protocol.WriteMessage(conn, protocol.ResponseNotFound) protocol.WriteMessage(conn, protocol.ResponseNotFound)
conn.Close()
return return
} }
@ -60,24 +67,26 @@ func sessionConnectionHandler(conn net.Conn) {
log.Println("Failed to add", conn.RemoteAddr(), "to session", ses) log.Println("Failed to add", conn.RemoteAddr(), "to session", ses)
} }
protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
conn.Close()
return return
} }
err := protocol.WriteMessage(conn, protocol.ResponseSuccess) if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil {
if err != nil {
if debug { if debug {
log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses) log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses)
} }
conn.Close()
return return
} }
conn.SetDeadline(time.Time{})
if err := conn.SetDeadline(time.Time{}); err != nil {
if debug {
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
}
return
}
default: default:
if debug { if debug {
log.Println("Unexpected message from", conn.RemoteAddr(), message) log.Println("Unexpected message from", conn.RemoteAddr(), message)
} }
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
conn.Close()
} }
} }