diff --git a/cmd/relaysrv/listener.go b/cmd/relaysrv/listener.go index 2091b1dd6..9dde76273 100644 --- a/cmd/relaysrv/listener.go +++ b/cmd/relaysrv/listener.go @@ -4,6 +4,7 @@ package main import ( "crypto/tls" + "encoding/hex" "log" "net" "sync" @@ -34,7 +35,7 @@ func listener(addr string, config *tls.Config) { conn, isTLS, err := listener.AcceptNoWrapTLS() if err != nil { if debug { - log.Println(err) + log.Println("Listener failed to accept connection from", conn.RemoteAddr(), ". Possibly a TCP Ping.") } continue } @@ -138,13 +139,13 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { conn.Close() continue } - - ses := newSession(sessionLimiter, globalLimiter) + // requestedPeer is the server, id is the client + ses := newSession(requestedPeer, id, sessionLimiter, globalLimiter) go ses.Serve() - clientInvitation := ses.GetClientInvitationMessage(requestedPeer) - serverInvitation := ses.GetServerInvitationMessage(id) + clientInvitation := ses.GetClientInvitationMessage() + serverInvitation := ses.GetServerInvitationMessage() if err := protocol.WriteMessage(conn, clientInvitation); err != nil { if debug { @@ -181,12 +182,19 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { // Potentially closing a second time. conn.Close() - // Only delete the outbox if the client is joined, as it might be - // a lookup request coming from the same client. if joined { + // Only delete the outbox if the client is joined, as it might be + // a lookup request coming from the same client. outboxesMut.Lock() delete(outboxes, id) outboxesMut.Unlock() + // Also, kill all sessions related to this node, as it probably + // went offline. This is for the other end to realize the client + // is no longer there faster. This also helps resolve + // 'already connected' errors when one of the sides is + // restarting, and connecting to the other peer before the other + // peer even realised that the node has gone away. 
+ dropSessions(id) } return @@ -245,7 +253,7 @@ func sessionConnectionHandler(conn net.Conn) { case protocol.JoinSessionRequest: ses := findSession(string(msg.Key)) if debug { - log.Println(conn.RemoteAddr(), "session lookup", ses) + log.Println(conn.RemoteAddr(), "session lookup", ses, hex.EncodeToString(msg.Key)[:5]) } if ses == nil { diff --git a/cmd/relaysrv/main.go b/cmd/relaysrv/main.go index 1da3c455a..614f82c7f 100644 --- a/cmd/relaysrv/main.go +++ b/cmd/relaysrv/main.go @@ -42,6 +42,8 @@ var ( ) func main() { + log.SetFlags(log.Lshortfile | log.LstdFlags) + var dir, extAddress string flag.StringVar(&listen, "listen", ":22067", "Protocol listen address") diff --git a/cmd/relaysrv/session.go b/cmd/relaysrv/session.go index c94cdc2a0..bbd29d1f5 100644 --- a/cmd/relaysrv/session.go +++ b/cmd/relaysrv/session.go @@ -19,22 +19,14 @@ import ( ) var ( - sessionMut = sync.Mutex{} - sessions = make(map[string]*session, 0) - numProxies int64 - bytesProxied int64 + sessionMut = sync.RWMutex{} + activeSessions = make([]*session, 0) + pendingSessions = make(map[string]*session, 0) + numProxies int64 + bytesProxied int64 ) -type session struct { - serverkey []byte - clientkey []byte - - rateLimit func(bytes int64) - - conns chan net.Conn -} - -func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session { +func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session { serverkey := make([]byte, 32) _, err := rand.Read(serverkey) if err != nil { @@ -49,9 +41,12 @@ func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session { ses := &session{ serverkey: serverkey, + serverid: serverid, clientkey: clientkey, + clientid: clientid, rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit), - conns: make(chan net.Conn), + connsChan: make(chan net.Conn), + conns: make([]net.Conn, 0, 2), } if debug { @@ -59,8 +54,8 @@ func newSession(sessionRateLimit, globalRateLimit 
*ratelimit.Bucket) *session { } sessionMut.Lock() - sessions[string(ses.serverkey)] = ses - sessions[string(ses.clientkey)] = ses + pendingSessions[string(ses.serverkey)] = ses + pendingSessions[string(ses.clientkey)] = ses sessionMut.Unlock() return ses @@ -69,13 +64,41 @@ func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session { func findSession(key string) *session { sessionMut.Lock() defer sessionMut.Unlock() - lob, ok := sessions[key] + ses, ok := pendingSessions[key] if !ok { return nil } - delete(sessions, key) - return lob + delete(pendingSessions, key) + return ses +} + +func dropSessions(id syncthingprotocol.DeviceID) { + sessionMut.RLock() + for _, session := range activeSessions { + if session.HasParticipant(id) { + if debug { + log.Println("Dropping session", session, "involving", id) + } + session.CloseConns() + } + } + sessionMut.RUnlock() +} + +type session struct { + mut sync.Mutex + + serverkey []byte + serverid syncthingprotocol.DeviceID + + clientkey []byte + clientid syncthingprotocol.DeviceID + + rateLimit func(bytes int64) + + connsChan chan net.Conn + conns []net.Conn } func (s *session) AddConnection(conn net.Conn) bool { @@ -84,7 +107,7 @@ func (s *session) AddConnection(conn net.Conn) bool { } select { - case s.conns <- conn: + case s.connsChan <- conn: return true default: } @@ -98,19 +121,21 @@ func (s *session) Serve() { log.Println("Session", s, "serving") } - conns := make([]net.Conn, 0, 2) for { select { - case conn := <-s.conns: - conns = append(conns, conn) - if len(conns) < 2 { + case conn := <-s.connsChan: + s.mut.Lock() + s.conns = append(s.conns, conn) + s.mut.Unlock() + // We're the only ones mutating s.conns, hence we are free to read it. 
+ if len(s.conns) < 2 { continue } - close(s.conns) + close(s.connsChan) if debug { - log.Println("Session", s, "starting between", conns[0].RemoteAddr(), "and", conns[1].RemoteAddr()) + log.Println("Session", s, "starting between", s.conns[0].RemoteAddr(), "and", s.conns[1].RemoteAddr()) } wg := sync.WaitGroup{} @@ -118,16 +143,20 @@ func (s *session) Serve() { var err0 error go func() { - err0 = s.proxy(conns[0], conns[1]) + err0 = s.proxy(s.conns[0], s.conns[1]) wg.Done() }() var err1 error go func() { - err1 = s.proxy(conns[1], conns[0]) + err1 = s.proxy(s.conns[1], s.conns[0]) wg.Done() }() + sessionMut.Lock() + activeSessions = append(activeSessions, s) + sessionMut.Unlock() + wg.Wait() if debug { @@ -143,23 +172,37 @@ func (s *session) Serve() { } } done: + // We can end up here in 3 cases: + // 1. Timeout joining, in which case there are potentially entries in pendingSessions + // 2. General session end/timeout, in which case there are entries in activeSessions + // 3. Protocol handler calls dropSessions as one of its clients disconnects. + sessionMut.Lock() - delete(sessions, string(s.serverkey)) - delete(sessions, string(s.clientkey)) + delete(pendingSessions, string(s.serverkey)) + delete(pendingSessions, string(s.clientkey)) + + for i, session := range activeSessions { + if session == s { + l := len(activeSessions) - 1 + activeSessions[i] = activeSessions[l] + activeSessions[l] = nil + activeSessions = activeSessions[:l] + } + } sessionMut.Unlock() - for _, conn := range conns { - conn.Close() - } + // If we are here because of case 2 or 3, we are potentially closing some or + // all connections a second time. 
+ s.CloseConns() if debug { log.Println("Session", s, "stopping") } } -func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation { +func (s *session) GetClientInvitationMessage() protocol.SessionInvitation { return protocol.SessionInvitation{ - From: from[:], + From: s.serverid[:], Key: []byte(s.clientkey), Address: sessionAddress, Port: sessionPort, @@ -167,9 +210,9 @@ func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) pr } } -func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation { +func (s *session) GetServerInvitationMessage() protocol.SessionInvitation { return protocol.SessionInvitation{ - From: from[:], + From: s.clientid[:], Key: []byte(s.serverkey), Address: sessionAddress, Port: sessionPort, @@ -177,6 +220,18 @@ func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) pr } } +func (s *session) HasParticipant(id syncthingprotocol.DeviceID) bool { + return s.clientid == id || s.serverid == id +} + +func (s *session) CloseConns() { + s.mut.Lock() + for _, conn := range s.conns { + conn.Close() + } + s.mut.Unlock() +} + func (s *session) proxy(c1, c2 net.Conn) error { if debug { log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr()) diff --git a/cmd/relaysrv/status.go b/cmd/relaysrv/status.go index 53cea572c..b18cf3ea7 100644 --- a/cmd/relaysrv/status.go +++ b/cmd/relaysrv/status.go @@ -24,7 +24,9 @@ func getStatus(w http.ResponseWriter, r *http.Request) { status := make(map[string]interface{}) sessionMut.Lock() - status["numSessions"] = len(sessions) + // This can potentially be double the number of pending sessions, as each session has two keys, one for each side. + status["numPendingSessionKeys"] = len(pendingSessions) + status["numActiveSessions"] = len(activeSessions) sessionMut.Unlock() status["numConnections"] = atomic.LoadInt64(&numConnections) status["numProxies"] = atomic.LoadInt64(&numProxies)