Drop all sessions when we realize a node has gone away

This commit is contained in:
Audrius Butkevicius 2015-09-11 22:29:50 +01:00
parent 0b7ab0a095
commit 50f0da6793
4 changed files with 114 additions and 47 deletions

View File

@ -4,6 +4,7 @@ package main
import ( import (
"crypto/tls" "crypto/tls"
"encoding/hex"
"log" "log"
"net" "net"
"sync" "sync"
@ -34,7 +35,7 @@ func listener(addr string, config *tls.Config) {
conn, isTLS, err := listener.AcceptNoWrapTLS() conn, isTLS, err := listener.AcceptNoWrapTLS()
if err != nil { if err != nil {
if debug { if debug {
log.Println(err) log.Println("Listener failed to accept connection from", conn.RemoteAddr(), ". Possibly a TCP Ping.")
} }
continue continue
} }
@ -138,13 +139,13 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
conn.Close() conn.Close()
continue continue
} }
// requestedPeer is the server, id is the client
ses := newSession(sessionLimiter, globalLimiter) ses := newSession(requestedPeer, id, sessionLimiter, globalLimiter)
go ses.Serve() go ses.Serve()
clientInvitation := ses.GetClientInvitationMessage(requestedPeer) clientInvitation := ses.GetClientInvitationMessage()
serverInvitation := ses.GetServerInvitationMessage(id) serverInvitation := ses.GetServerInvitationMessage()
if err := protocol.WriteMessage(conn, clientInvitation); err != nil { if err := protocol.WriteMessage(conn, clientInvitation); err != nil {
if debug { if debug {
@ -181,12 +182,19 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
// Potentially closing a second time. // Potentially closing a second time.
conn.Close() conn.Close()
// Only delete the outbox if the client is joined, as it might be
// a lookup request coming from the same client.
if joined { if joined {
// Only delete the outbox if the client is joined, as it might be
// a lookup request coming from the same client.
outboxesMut.Lock() outboxesMut.Lock()
delete(outboxes, id) delete(outboxes, id)
outboxesMut.Unlock() outboxesMut.Unlock()
// Also, kill all sessions related to this node, as it probably
// went offline. This is for the other end to realize the client
// is no longer there faster. This also helps resolve
// 'already connected' errors when one of the sides is
// restarting, and connecting to the other peer before the other
// peer even realised that the node has gone away.
dropSessions(id)
} }
return return
@ -245,7 +253,7 @@ func sessionConnectionHandler(conn net.Conn) {
case protocol.JoinSessionRequest: case protocol.JoinSessionRequest:
ses := findSession(string(msg.Key)) ses := findSession(string(msg.Key))
if debug { if debug {
log.Println(conn.RemoteAddr(), "session lookup", ses) log.Println(conn.RemoteAddr(), "session lookup", ses, hex.EncodeToString(msg.Key)[:5])
} }
if ses == nil { if ses == nil {

View File

@ -42,6 +42,8 @@ var (
) )
func main() { func main() {
log.SetFlags(log.Lshortfile | log.LstdFlags)
var dir, extAddress string var dir, extAddress string
flag.StringVar(&listen, "listen", ":22067", "Protocol listen address") flag.StringVar(&listen, "listen", ":22067", "Protocol listen address")

View File

@ -19,22 +19,14 @@ import (
) )
var ( var (
sessionMut = sync.Mutex{} sessionMut = sync.RWMutex{}
sessions = make(map[string]*session, 0) activeSessions = make([]*session, 0)
numProxies int64 pendingSessions = make(map[string]*session, 0)
bytesProxied int64 numProxies int64
bytesProxied int64
) )
type session struct { func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
serverkey []byte
clientkey []byte
rateLimit func(bytes int64)
conns chan net.Conn
}
func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
serverkey := make([]byte, 32) serverkey := make([]byte, 32)
_, err := rand.Read(serverkey) _, err := rand.Read(serverkey)
if err != nil { if err != nil {
@ -49,9 +41,12 @@ func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
ses := &session{ ses := &session{
serverkey: serverkey, serverkey: serverkey,
serverid: serverid,
clientkey: clientkey, clientkey: clientkey,
clientid: clientid,
rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit), rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit),
conns: make(chan net.Conn), connsChan: make(chan net.Conn),
conns: make([]net.Conn, 0, 2),
} }
if debug { if debug {
@ -59,8 +54,8 @@ func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
} }
sessionMut.Lock() sessionMut.Lock()
sessions[string(ses.serverkey)] = ses pendingSessions[string(ses.serverkey)] = ses
sessions[string(ses.clientkey)] = ses pendingSessions[string(ses.clientkey)] = ses
sessionMut.Unlock() sessionMut.Unlock()
return ses return ses
@ -69,13 +64,41 @@ func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
func findSession(key string) *session { func findSession(key string) *session {
sessionMut.Lock() sessionMut.Lock()
defer sessionMut.Unlock() defer sessionMut.Unlock()
lob, ok := sessions[key] ses, ok := pendingSessions[key]
if !ok { if !ok {
return nil return nil
} }
delete(sessions, key) delete(pendingSessions, key)
return lob return ses
}
func dropSessions(id syncthingprotocol.DeviceID) {
sessionMut.RLock()
for _, session := range activeSessions {
if session.HasParticipant(id) {
if debug {
log.Println("Dropping session", session, "involving", id)
}
session.CloseConns()
}
}
sessionMut.RUnlock()
}
type session struct {
mut sync.Mutex
serverkey []byte
serverid syncthingprotocol.DeviceID
clientkey []byte
clientid syncthingprotocol.DeviceID
rateLimit func(bytes int64)
connsChan chan net.Conn
conns []net.Conn
} }
func (s *session) AddConnection(conn net.Conn) bool { func (s *session) AddConnection(conn net.Conn) bool {
@ -84,7 +107,7 @@ func (s *session) AddConnection(conn net.Conn) bool {
} }
select { select {
case s.conns <- conn: case s.connsChan <- conn:
return true return true
default: default:
} }
@ -98,19 +121,21 @@ func (s *session) Serve() {
log.Println("Session", s, "serving") log.Println("Session", s, "serving")
} }
conns := make([]net.Conn, 0, 2)
for { for {
select { select {
case conn := <-s.conns: case conn := <-s.connsChan:
conns = append(conns, conn) s.mut.Lock()
if len(conns) < 2 { s.conns = append(s.conns, conn)
s.mut.Unlock()
// We're the only ones mutating% s.conns, hence we are free to read it.
if len(s.conns) < 2 {
continue continue
} }
close(s.conns) close(s.connsChan)
if debug { if debug {
log.Println("Session", s, "starting between", conns[0].RemoteAddr(), "and", conns[1].RemoteAddr()) log.Println("Session", s, "starting between", s.conns[0].RemoteAddr(), "and", s.conns[1].RemoteAddr())
} }
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
@ -118,16 +143,20 @@ func (s *session) Serve() {
var err0 error var err0 error
go func() { go func() {
err0 = s.proxy(conns[0], conns[1]) err0 = s.proxy(s.conns[0], s.conns[1])
wg.Done() wg.Done()
}() }()
var err1 error var err1 error
go func() { go func() {
err1 = s.proxy(conns[1], conns[0]) err1 = s.proxy(s.conns[1], s.conns[0])
wg.Done() wg.Done()
}() }()
sessionMut.Lock()
activeSessions = append(activeSessions, s)
sessionMut.Unlock()
wg.Wait() wg.Wait()
if debug { if debug {
@ -143,23 +172,37 @@ func (s *session) Serve() {
} }
} }
done: done:
// We can end up here in 3 cases:
// 1. Timeout joining, in which case there are potentially entries in pendingSessions
// 2. General session end/timeout, in which case there are entries in activeSessions
// 3. Protocol handler calls dropSession as one of it's clients disconnects.
sessionMut.Lock() sessionMut.Lock()
delete(sessions, string(s.serverkey)) delete(pendingSessions, string(s.serverkey))
delete(sessions, string(s.clientkey)) delete(pendingSessions, string(s.clientkey))
for i, session := range activeSessions {
if session == s {
l := len(activeSessions) - 1
activeSessions[i] = activeSessions[l]
activeSessions[l] = nil
activeSessions = activeSessions[:l]
}
}
sessionMut.Unlock() sessionMut.Unlock()
for _, conn := range conns { // If we are here because of case 2 or 3, we are potentially closing some or
conn.Close() // all connections a second time.
} s.CloseConns()
if debug { if debug {
log.Println("Session", s, "stopping") log.Println("Session", s, "stopping")
} }
} }
func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation { func (s *session) GetClientInvitationMessage() protocol.SessionInvitation {
return protocol.SessionInvitation{ return protocol.SessionInvitation{
From: from[:], From: s.serverid[:],
Key: []byte(s.clientkey), Key: []byte(s.clientkey),
Address: sessionAddress, Address: sessionAddress,
Port: sessionPort, Port: sessionPort,
@ -167,9 +210,9 @@ func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) pr
} }
} }
func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation { func (s *session) GetServerInvitationMessage() protocol.SessionInvitation {
return protocol.SessionInvitation{ return protocol.SessionInvitation{
From: from[:], From: s.clientid[:],
Key: []byte(s.serverkey), Key: []byte(s.serverkey),
Address: sessionAddress, Address: sessionAddress,
Port: sessionPort, Port: sessionPort,
@ -177,6 +220,18 @@ func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) pr
} }
} }
func (s *session) HasParticipant(id syncthingprotocol.DeviceID) bool {
return s.clientid == id || s.serverid == id
}
func (s *session) CloseConns() {
s.mut.Lock()
for _, conn := range s.conns {
conn.Close()
}
s.mut.Unlock()
}
func (s *session) proxy(c1, c2 net.Conn) error { func (s *session) proxy(c1, c2 net.Conn) error {
if debug { if debug {
log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr()) log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())

View File

@ -24,7 +24,9 @@ func getStatus(w http.ResponseWriter, r *http.Request) {
status := make(map[string]interface{}) status := make(map[string]interface{})
sessionMut.Lock() sessionMut.Lock()
status["numSessions"] = len(sessions) // This can potentially be double the number of pending sessions, as each session has two keys, one for each side.
status["numPendingSessionKeys"] = len(pendingSessions)
status["numActiveSessions"] = len(activeSessions)
sessionMut.Unlock() sessionMut.Unlock()
status["numConnections"] = atomic.LoadInt64(&numConnections) status["numConnections"] = atomic.LoadInt64(&numConnections)
status["numProxies"] = atomic.LoadInt64(&numProxies) status["numProxies"] = atomic.LoadInt64(&numProxies)