diff --git a/cmd/relaysrv/protocol_listener.go b/cmd/relaysrv/protocol_listener.go index 8825af827..a7243ff69 100644 --- a/cmd/relaysrv/protocol_listener.go +++ b/cmd/relaysrv/protocol_listener.go @@ -27,7 +27,6 @@ func protocolListener(addr string, config *tls.Config) { for { conn, err := listener.Accept() - setTCPOptions(conn) if err != nil { if debug { log.Println(err) @@ -35,6 +34,8 @@ func protocolListener(addr string, config *tls.Config) { continue } + setTCPOptions(conn) + if debug { log.Println("Protocol listener accepted connection from", conn.RemoteAddr()) } @@ -74,16 +75,12 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { errors := make(chan error, 1) outbox := make(chan interface{}) - go func(conn net.Conn, message chan<- interface{}, errors chan<- error) { - for { - msg, err := protocol.ReadMessage(conn) - if err != nil { - errors <- err - return - } - messages <- msg - } - }(conn, messages, errors) + // Read messages from the connection and send them on the messages + // channel. When there is an error, send it on the error channel and + // return. Applies also when the connection gets closed, so the pattern + // below is to close the connection on error, then wait for the error + // signal from messageReader to exit. + go messageReader(conn, messages, errors) pingTicker := time.NewTicker(pingInterval) timeoutTicker := time.NewTimer(networkTimeout) @@ -96,6 +93,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { if debug { log.Printf("Message %T from %s", message, id) } + switch msg := message.(type) { case protocol.JoinRelayRequest: outboxesMut.RLock() @@ -116,6 +114,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { joined = true protocol.WriteMessage(conn, protocol.ResponseSuccess) + case protocol.ConnectRequest: requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID) outboxesMut.RLock() @@ -151,7 +150,10 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { log.Println("Sent invitation from", id, "to", requestedPeer) } conn.Close() + case protocol.Pong: + // Nothing + default: if debug { log.Printf("Unknown message %s: %T", id, message) @@ -159,21 +161,25 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) conn.Close() } + case err := <-errors: if debug { log.Printf("Closing connection %s: %s", id, err) } - // Potentially closing a second time. close(outbox) + + // Potentially closing a second time. conn.Close() - // Only delete the outbox if the client join, as it migth be a - // lookup request coming from the same client. + + // Only delete the outbox if the client is joined, as it might be + // a lookup request coming from the same client. if joined { outboxesMut.Lock() delete(outboxes, id) outboxesMut.Unlock() } return + case <-pingTicker.C: if !joined { if debug { @@ -189,6 +195,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { } conn.Close() } + case <-timeoutTicker.C: // We should receive a error from the reader loop, which will cause // us to quit this loop. @@ -196,6 +203,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { log.Printf("%s timed out", id) } conn.Close() + case msg := <-outbox: if debug { log.Printf("Sending message %T to %s", msg, id) @@ -209,3 +217,14 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { } } } + +func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) { + for { + msg, err := protocol.ReadMessage(conn) + if err != nil { + errors <- err + return + } + messages <- msg + } +} diff --git a/cmd/relaysrv/session_listener.go b/cmd/relaysrv/session_listener.go index 6159ceef5..2f6bae9ab 100644 --- a/cmd/relaysrv/session_listener.go +++ b/cmd/relaysrv/session_listener.go @@ -18,7 +18,6 @@ func sessionListener(addr string) { for { conn, err := listener.Accept() - setTCPOptions(conn) if err != nil { if debug { log.Println(err) @@ -26,6 +25,8 @@ func sessionListener(addr string) { continue } + setTCPOptions(conn) + if debug { log.Println("Session listener accepted connection from", conn.RemoteAddr()) } @@ -35,10 +36,17 @@ func sessionListener(addr string) { } func sessionConnectionHandler(conn net.Conn) { - conn.SetDeadline(time.Now().Add(messageTimeout)) + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil { + if debug { + log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) + } + return + } + message, err := protocol.ReadMessage(conn) if err != nil { - conn.Close() return } @@ -51,7 +59,6 @@ func sessionConnectionHandler(conn net.Conn) { if ses == nil { protocol.WriteMessage(conn, protocol.ResponseNotFound) - conn.Close() return } @@ -60,24 +67,26 @@ func sessionConnectionHandler(conn net.Conn) { log.Println("Failed to add", conn.RemoteAddr(), "to session", ses) } protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) - conn.Close() return } - err := protocol.WriteMessage(conn, protocol.ResponseSuccess) - if err != nil { + if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil { if debug { log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses) } - conn.Close() return } - conn.SetDeadline(time.Time{}) + + if err := conn.SetDeadline(time.Time{}); err != nil { + if debug { + log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr()) + } + return + } default: if debug { log.Println("Unexpected message from", conn.RemoteAddr(), message) } protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) - conn.Close() } }