mirror of
https://github.com/octoleo/syncthing.git
synced 2024-12-23 11:28:59 +00:00
commit
efa0a06947
@ -14,27 +14,6 @@ import (
|
|||||||
"github.com/syncthing/relaysrv/protocol"
|
"github.com/syncthing/relaysrv/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) ProtocolClient {
|
|
||||||
closeInvitationsOnFinish := false
|
|
||||||
if invitations == nil {
|
|
||||||
closeInvitationsOnFinish = true
|
|
||||||
invitations = make(chan protocol.SessionInvitation)
|
|
||||||
}
|
|
||||||
return ProtocolClient{
|
|
||||||
URI: uri,
|
|
||||||
Invitations: invitations,
|
|
||||||
|
|
||||||
closeInvitationsOnFinish: closeInvitationsOnFinish,
|
|
||||||
|
|
||||||
config: configForCerts(certs),
|
|
||||||
|
|
||||||
timeout: time.Minute * 2,
|
|
||||||
|
|
||||||
stop: make(chan struct{}),
|
|
||||||
stopped: make(chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ProtocolClient struct {
|
type ProtocolClient struct {
|
||||||
URI *url.URL
|
URI *url.URL
|
||||||
Invitations chan protocol.SessionInvitation
|
Invitations chan protocol.SessionInvitation
|
||||||
@ -51,6 +30,129 @@ type ProtocolClient struct {
|
|||||||
conn *tls.Conn
|
conn *tls.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) *ProtocolClient {
|
||||||
|
closeInvitationsOnFinish := false
|
||||||
|
if invitations == nil {
|
||||||
|
closeInvitationsOnFinish = true
|
||||||
|
invitations = make(chan protocol.SessionInvitation)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ProtocolClient{
|
||||||
|
URI: uri,
|
||||||
|
Invitations: invitations,
|
||||||
|
|
||||||
|
closeInvitationsOnFinish: closeInvitationsOnFinish,
|
||||||
|
|
||||||
|
config: configForCerts(certs),
|
||||||
|
|
||||||
|
timeout: time.Minute * 2,
|
||||||
|
|
||||||
|
stop: make(chan struct{}),
|
||||||
|
stopped: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProtocolClient) Serve() {
|
||||||
|
c.stop = make(chan struct{})
|
||||||
|
c.stopped = make(chan struct{})
|
||||||
|
defer close(c.stopped)
|
||||||
|
|
||||||
|
if err := c.connect(); err != nil {
|
||||||
|
l.Infoln("Relay connect:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
l.Debugln(c, "connected", c.conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.join(); err != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
l.Infoln("Relay join:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.conn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
l.Infoln("Relay set deadline:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
defer c.cleanup()
|
||||||
|
|
||||||
|
messages := make(chan interface{})
|
||||||
|
errors := make(chan error, 1)
|
||||||
|
|
||||||
|
go messageReader(c.conn, messages, errors)
|
||||||
|
|
||||||
|
timeout := time.NewTimer(c.timeout)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case message := <-messages:
|
||||||
|
timeout.Reset(c.timeout)
|
||||||
|
if debug {
|
||||||
|
log.Printf("%s received message %T", c, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := message.(type) {
|
||||||
|
case protocol.Ping:
|
||||||
|
if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil {
|
||||||
|
l.Infoln("Relay write:", err)
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
if debug {
|
||||||
|
l.Debugln(c, "sent pong")
|
||||||
|
}
|
||||||
|
|
||||||
|
case protocol.SessionInvitation:
|
||||||
|
ip := net.IP(msg.Address)
|
||||||
|
if len(ip) == 0 || ip.IsUnspecified() {
|
||||||
|
msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:]
|
||||||
|
}
|
||||||
|
c.Invitations <- msg
|
||||||
|
|
||||||
|
default:
|
||||||
|
l.Infoln("Relay: protocol error: unexpected message %v", msg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-c.stop:
|
||||||
|
if debug {
|
||||||
|
l.Debugln(c, "stopping")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
case err := <-errors:
|
||||||
|
l.Infoln("Relay received:", err)
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-timeout.C:
|
||||||
|
if debug {
|
||||||
|
l.Debugln(c, "timed out")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProtocolClient) Stop() {
|
||||||
|
if c.stop == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(c.stop)
|
||||||
|
<-c.stopped
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProtocolClient) String() string {
|
||||||
|
return fmt.Sprintf("ProtocolClient@%p", c)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ProtocolClient) connect() error {
|
func (c *ProtocolClient) connect() error {
|
||||||
if c.URI.Scheme != "relay" {
|
if c.URI.Scheme != "relay" {
|
||||||
return fmt.Errorf("Unsupported relay schema:", c.URI.Scheme)
|
return fmt.Errorf("Unsupported relay schema:", c.URI.Scheme)
|
||||||
@ -61,9 +163,13 @@ func (c *ProtocolClient) connect() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
if err := conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if err := performHandshakeAndValidation(conn, c.URI); err != nil {
|
if err := performHandshakeAndValidation(conn, c.URI); err != nil {
|
||||||
|
conn.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,101 +177,6 @@ func (c *ProtocolClient) connect() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProtocolClient) Serve() {
|
|
||||||
if err := c.connect(); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if debug {
|
|
||||||
l.Debugln(c, "connected", c.conn.RemoteAddr())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.join(); err != nil {
|
|
||||||
c.conn.Close()
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.conn.SetDeadline(time.Time{})
|
|
||||||
|
|
||||||
if debug {
|
|
||||||
l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr())
|
|
||||||
}
|
|
||||||
|
|
||||||
c.stop = make(chan struct{})
|
|
||||||
c.stopped = make(chan struct{})
|
|
||||||
|
|
||||||
defer c.cleanup()
|
|
||||||
|
|
||||||
messages := make(chan interface{})
|
|
||||||
errors := make(chan error, 1)
|
|
||||||
|
|
||||||
go func(conn net.Conn, message chan<- interface{}, errors chan<- error) {
|
|
||||||
for {
|
|
||||||
msg, err := protocol.ReadMessage(conn)
|
|
||||||
if err != nil {
|
|
||||||
errors <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
messages <- msg
|
|
||||||
}
|
|
||||||
}(c.conn, messages, errors)
|
|
||||||
|
|
||||||
timeout := time.NewTimer(c.timeout)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case message := <-messages:
|
|
||||||
timeout.Reset(c.timeout)
|
|
||||||
if debug {
|
|
||||||
log.Printf("%s received message %T", c, message)
|
|
||||||
}
|
|
||||||
switch msg := message.(type) {
|
|
||||||
case protocol.Ping:
|
|
||||||
if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
if debug {
|
|
||||||
l.Debugln(c, "sent pong")
|
|
||||||
}
|
|
||||||
case protocol.SessionInvitation:
|
|
||||||
ip := net.IP(msg.Address)
|
|
||||||
if len(ip) == 0 || ip.IsUnspecified() {
|
|
||||||
msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:]
|
|
||||||
}
|
|
||||||
c.Invitations <- msg
|
|
||||||
default:
|
|
||||||
panic(fmt.Errorf("protocol error: unexpected message %v", msg))
|
|
||||||
}
|
|
||||||
case <-c.stop:
|
|
||||||
if debug {
|
|
||||||
l.Debugln(c, "stopping")
|
|
||||||
}
|
|
||||||
break
|
|
||||||
case err := <-errors:
|
|
||||||
panic(err)
|
|
||||||
case <-timeout.C:
|
|
||||||
if debug {
|
|
||||||
l.Debugln(c, "timed out")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.stopped <- struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ProtocolClient) Stop() {
|
|
||||||
if c.stop == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.stop <- struct{}{}
|
|
||||||
<-c.stopped
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ProtocolClient) String() string {
|
|
||||||
return fmt.Sprintf("ProtocolClient@%p", c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ProtocolClient) cleanup() {
|
func (c *ProtocolClient) cleanup() {
|
||||||
if c.closeInvitationsOnFinish {
|
if c.closeInvitationsOnFinish {
|
||||||
close(c.Invitations)
|
close(c.Invitations)
|
||||||
@ -176,24 +187,11 @@ func (c *ProtocolClient) cleanup() {
|
|||||||
l.Debugln(c, "cleaning up")
|
l.Debugln(c, "cleaning up")
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.stop != nil {
|
|
||||||
close(c.stop)
|
|
||||||
c.stop = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.stopped != nil {
|
|
||||||
close(c.stopped)
|
|
||||||
c.stopped = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.conn != nil {
|
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProtocolClient) join() error {
|
func (c *ProtocolClient) join() error {
|
||||||
err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{})
|
if err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,6 +205,7 @@ func (c *ProtocolClient) join() error {
|
|||||||
if msg.Code != 0 {
|
if msg.Code != 0 {
|
||||||
return fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message)
|
return fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("protocol error: expecting response got %v", msg)
|
return fmt.Errorf("protocol error: expecting response got %v", msg)
|
||||||
}
|
}
|
||||||
@ -215,15 +214,12 @@ func (c *ProtocolClient) join() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error {
|
func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error {
|
||||||
err := conn.Handshake()
|
if err := conn.Handshake(); err != nil {
|
||||||
if err != nil {
|
|
||||||
conn.Close()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cs := conn.ConnectionState()
|
cs := conn.ConnectionState()
|
||||||
if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != protocol.ProtocolName {
|
if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != protocol.ProtocolName {
|
||||||
conn.Close()
|
|
||||||
return fmt.Errorf("protocol negotiation error")
|
return fmt.Errorf("protocol negotiation error")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -232,22 +228,30 @@ func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error {
|
|||||||
if relayIDs != "" {
|
if relayIDs != "" {
|
||||||
relayID, err := syncthingprotocol.DeviceIDFromString(relayIDs)
|
relayID, err := syncthingprotocol.DeviceIDFromString(relayIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
|
||||||
return fmt.Errorf("relay address contains invalid verification id: %s", err)
|
return fmt.Errorf("relay address contains invalid verification id: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
certs := cs.PeerCertificates
|
certs := cs.PeerCertificates
|
||||||
if cl := len(certs); cl != 1 {
|
if cl := len(certs); cl != 1 {
|
||||||
conn.Close()
|
|
||||||
return fmt.Errorf("unexpected certificate count: %d", cl)
|
return fmt.Errorf("unexpected certificate count: %d", cl)
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteID := syncthingprotocol.NewDeviceID(certs[0].Raw)
|
remoteID := syncthingprotocol.NewDeviceID(certs[0].Raw)
|
||||||
if remoteID != relayID {
|
if remoteID != relayID {
|
||||||
conn.Close()
|
|
||||||
return fmt.Errorf("relay id does not match. Expected %v got %v", relayID, remoteID)
|
return fmt.Errorf("relay id does not match. Expected %v got %v", relayID, remoteID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) {
|
||||||
|
for {
|
||||||
|
msg, err := protocol.ReadMessage(conn)
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
messages <- msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -27,7 +27,6 @@ func protocolListener(addr string, config *tls.Config) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
setTCPOptions(conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if debug {
|
if debug {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
@ -35,6 +34,8 @@ func protocolListener(addr string, config *tls.Config) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setTCPOptions(conn)
|
||||||
|
|
||||||
if debug {
|
if debug {
|
||||||
log.Println("Protocol listener accepted connection from", conn.RemoteAddr())
|
log.Println("Protocol listener accepted connection from", conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
@ -74,16 +75,12 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
|
|||||||
errors := make(chan error, 1)
|
errors := make(chan error, 1)
|
||||||
outbox := make(chan interface{})
|
outbox := make(chan interface{})
|
||||||
|
|
||||||
go func(conn net.Conn, message chan<- interface{}, errors chan<- error) {
|
// Read messages from the connection and send them on the messages
|
||||||
for {
|
// channel. When there is an error, send it on the error channel and
|
||||||
msg, err := protocol.ReadMessage(conn)
|
// return. Applies also when the connection gets closed, so the pattern
|
||||||
if err != nil {
|
// below is to close the connection on error, then wait for the error
|
||||||
errors <- err
|
// signal from messageReader to exit.
|
||||||
return
|
go messageReader(conn, messages, errors)
|
||||||
}
|
|
||||||
messages <- msg
|
|
||||||
}
|
|
||||||
}(conn, messages, errors)
|
|
||||||
|
|
||||||
pingTicker := time.NewTicker(pingInterval)
|
pingTicker := time.NewTicker(pingInterval)
|
||||||
timeoutTicker := time.NewTimer(networkTimeout)
|
timeoutTicker := time.NewTimer(networkTimeout)
|
||||||
@ -96,6 +93,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
|
|||||||
if debug {
|
if debug {
|
||||||
log.Printf("Message %T from %s", message, id)
|
log.Printf("Message %T from %s", message, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch msg := message.(type) {
|
switch msg := message.(type) {
|
||||||
case protocol.JoinRelayRequest:
|
case protocol.JoinRelayRequest:
|
||||||
outboxesMut.RLock()
|
outboxesMut.RLock()
|
||||||
@ -116,6 +114,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
|
|||||||
joined = true
|
joined = true
|
||||||
|
|
||||||
protocol.WriteMessage(conn, protocol.ResponseSuccess)
|
protocol.WriteMessage(conn, protocol.ResponseSuccess)
|
||||||
|
|
||||||
case protocol.ConnectRequest:
|
case protocol.ConnectRequest:
|
||||||
requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID)
|
requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID)
|
||||||
outboxesMut.RLock()
|
outboxesMut.RLock()
|
||||||
@ -151,7 +150,10 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
|
|||||||
log.Println("Sent invitation from", id, "to", requestedPeer)
|
log.Println("Sent invitation from", id, "to", requestedPeer)
|
||||||
}
|
}
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
|
||||||
case protocol.Pong:
|
case protocol.Pong:
|
||||||
|
// Nothing
|
||||||
|
|
||||||
default:
|
default:
|
||||||
if debug {
|
if debug {
|
||||||
log.Printf("Unknown message %s: %T", id, message)
|
log.Printf("Unknown message %s: %T", id, message)
|
||||||
@ -159,21 +161,25 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
|
|||||||
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
|
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
case err := <-errors:
|
case err := <-errors:
|
||||||
if debug {
|
if debug {
|
||||||
log.Printf("Closing connection %s: %s", id, err)
|
log.Printf("Closing connection %s: %s", id, err)
|
||||||
}
|
}
|
||||||
// Potentially closing a second time.
|
|
||||||
close(outbox)
|
close(outbox)
|
||||||
|
|
||||||
|
// Potentially closing a second time.
|
||||||
conn.Close()
|
conn.Close()
|
||||||
// Only delete the outbox if the client join, as it migth be a
|
|
||||||
// lookup request coming from the same client.
|
// Only delete the outbox if the client is joined, as it might be
|
||||||
|
// a lookup request coming from the same client.
|
||||||
if joined {
|
if joined {
|
||||||
outboxesMut.Lock()
|
outboxesMut.Lock()
|
||||||
delete(outboxes, id)
|
delete(outboxes, id)
|
||||||
outboxesMut.Unlock()
|
outboxesMut.Unlock()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-pingTicker.C:
|
case <-pingTicker.C:
|
||||||
if !joined {
|
if !joined {
|
||||||
if debug {
|
if debug {
|
||||||
@ -189,6 +195,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
|
|||||||
}
|
}
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-timeoutTicker.C:
|
case <-timeoutTicker.C:
|
||||||
// We should receive a error from the reader loop, which will cause
|
// We should receive a error from the reader loop, which will cause
|
||||||
// us to quit this loop.
|
// us to quit this loop.
|
||||||
@ -196,6 +203,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
|
|||||||
log.Printf("%s timed out", id)
|
log.Printf("%s timed out", id)
|
||||||
}
|
}
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
|
||||||
case msg := <-outbox:
|
case msg := <-outbox:
|
||||||
if debug {
|
if debug {
|
||||||
log.Printf("Sending message %T to %s", msg, id)
|
log.Printf("Sending message %T to %s", msg, id)
|
||||||
@ -209,3 +217,14 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) {
|
||||||
|
for {
|
||||||
|
msg, err := protocol.ReadMessage(conn)
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
messages <- msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -18,7 +18,6 @@ func sessionListener(addr string) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
setTCPOptions(conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if debug {
|
if debug {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
@ -26,6 +25,8 @@ func sessionListener(addr string) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setTCPOptions(conn)
|
||||||
|
|
||||||
if debug {
|
if debug {
|
||||||
log.Println("Session listener accepted connection from", conn.RemoteAddr())
|
log.Println("Session listener accepted connection from", conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
@ -35,10 +36,17 @@ func sessionListener(addr string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func sessionConnectionHandler(conn net.Conn) {
|
func sessionConnectionHandler(conn net.Conn) {
|
||||||
conn.SetDeadline(time.Now().Add(messageTimeout))
|
defer conn.Close()
|
||||||
|
|
||||||
|
if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil {
|
||||||
|
if debug {
|
||||||
|
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
message, err := protocol.ReadMessage(conn)
|
message, err := protocol.ReadMessage(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,7 +59,6 @@ func sessionConnectionHandler(conn net.Conn) {
|
|||||||
|
|
||||||
if ses == nil {
|
if ses == nil {
|
||||||
protocol.WriteMessage(conn, protocol.ResponseNotFound)
|
protocol.WriteMessage(conn, protocol.ResponseNotFound)
|
||||||
conn.Close()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,24 +67,26 @@ func sessionConnectionHandler(conn net.Conn) {
|
|||||||
log.Println("Failed to add", conn.RemoteAddr(), "to session", ses)
|
log.Println("Failed to add", conn.RemoteAddr(), "to session", ses)
|
||||||
}
|
}
|
||||||
protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
|
protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
|
||||||
conn.Close()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := protocol.WriteMessage(conn, protocol.ResponseSuccess)
|
if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil {
|
||||||
if err != nil {
|
|
||||||
if debug {
|
if debug {
|
||||||
log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses)
|
log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses)
|
||||||
}
|
}
|
||||||
conn.Close()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.SetDeadline(time.Time{})
|
|
||||||
|
if err := conn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
if debug {
|
||||||
|
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
if debug {
|
if debug {
|
||||||
log.Println("Unexpected message from", conn.RemoteAddr(), message)
|
log.Println("Unexpected message from", conn.RemoteAddr(), message)
|
||||||
}
|
}
|
||||||
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
|
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
|
||||||
conn.Close()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user