mirror of
https://github.com/octoleo/syncthing.git
synced 2025-02-08 14:58:26 +00:00
Drop all sessions when we realize a node has gone away
This commit is contained in:
parent
0b7ab0a095
commit
50f0da6793
@ -4,6 +4,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/hex"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@ -34,7 +35,7 @@ func listener(addr string, config *tls.Config) {
|
|||||||
conn, isTLS, err := listener.AcceptNoWrapTLS()
|
conn, isTLS, err := listener.AcceptNoWrapTLS()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if debug {
|
if debug {
|
||||||
log.Println(err)
|
log.Println("Listener failed to accept connection from", conn.RemoteAddr(), ". Possibly a TCP Ping.")
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -138,13 +139,13 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
|
|||||||
conn.Close()
|
conn.Close()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// requestedPeer is the server, id is the client
|
||||||
ses := newSession(sessionLimiter, globalLimiter)
|
ses := newSession(requestedPeer, id, sessionLimiter, globalLimiter)
|
||||||
|
|
||||||
go ses.Serve()
|
go ses.Serve()
|
||||||
|
|
||||||
clientInvitation := ses.GetClientInvitationMessage(requestedPeer)
|
clientInvitation := ses.GetClientInvitationMessage()
|
||||||
serverInvitation := ses.GetServerInvitationMessage(id)
|
serverInvitation := ses.GetServerInvitationMessage()
|
||||||
|
|
||||||
if err := protocol.WriteMessage(conn, clientInvitation); err != nil {
|
if err := protocol.WriteMessage(conn, clientInvitation); err != nil {
|
||||||
if debug {
|
if debug {
|
||||||
@ -181,12 +182,19 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
|
|||||||
// Potentially closing a second time.
|
// Potentially closing a second time.
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
|
||||||
// Only delete the outbox if the client is joined, as it might be
|
|
||||||
// a lookup request coming from the same client.
|
|
||||||
if joined {
|
if joined {
|
||||||
|
// Only delete the outbox if the client is joined, as it might be
|
||||||
|
// a lookup request coming from the same client.
|
||||||
outboxesMut.Lock()
|
outboxesMut.Lock()
|
||||||
delete(outboxes, id)
|
delete(outboxes, id)
|
||||||
outboxesMut.Unlock()
|
outboxesMut.Unlock()
|
||||||
|
// Also, kill all sessions related to this node, as it probably
|
||||||
|
// went offline. This is for the other end to realize the client
|
||||||
|
// is no longer there faster. This also helps resolve
|
||||||
|
// 'already connected' errors when one of the sides is
|
||||||
|
// restarting, and connecting to the other peer before the other
|
||||||
|
// peer even realised that the node has gone away.
|
||||||
|
dropSessions(id)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -245,7 +253,7 @@ func sessionConnectionHandler(conn net.Conn) {
|
|||||||
case protocol.JoinSessionRequest:
|
case protocol.JoinSessionRequest:
|
||||||
ses := findSession(string(msg.Key))
|
ses := findSession(string(msg.Key))
|
||||||
if debug {
|
if debug {
|
||||||
log.Println(conn.RemoteAddr(), "session lookup", ses)
|
log.Println(conn.RemoteAddr(), "session lookup", ses, hex.EncodeToString(msg.Key)[:5])
|
||||||
}
|
}
|
||||||
|
|
||||||
if ses == nil {
|
if ses == nil {
|
||||||
|
@ -42,6 +42,8 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
log.SetFlags(log.Lshortfile | log.LstdFlags)
|
||||||
|
|
||||||
var dir, extAddress string
|
var dir, extAddress string
|
||||||
|
|
||||||
flag.StringVar(&listen, "listen", ":22067", "Protocol listen address")
|
flag.StringVar(&listen, "listen", ":22067", "Protocol listen address")
|
||||||
|
@ -19,22 +19,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
sessionMut = sync.Mutex{}
|
sessionMut = sync.RWMutex{}
|
||||||
sessions = make(map[string]*session, 0)
|
activeSessions = make([]*session, 0)
|
||||||
numProxies int64
|
pendingSessions = make(map[string]*session, 0)
|
||||||
bytesProxied int64
|
numProxies int64
|
||||||
|
bytesProxied int64
|
||||||
)
|
)
|
||||||
|
|
||||||
type session struct {
|
func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
|
||||||
serverkey []byte
|
|
||||||
clientkey []byte
|
|
||||||
|
|
||||||
rateLimit func(bytes int64)
|
|
||||||
|
|
||||||
conns chan net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
|
|
||||||
serverkey := make([]byte, 32)
|
serverkey := make([]byte, 32)
|
||||||
_, err := rand.Read(serverkey)
|
_, err := rand.Read(serverkey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -49,9 +41,12 @@ func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
|
|||||||
|
|
||||||
ses := &session{
|
ses := &session{
|
||||||
serverkey: serverkey,
|
serverkey: serverkey,
|
||||||
|
serverid: serverid,
|
||||||
clientkey: clientkey,
|
clientkey: clientkey,
|
||||||
|
clientid: clientid,
|
||||||
rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit),
|
rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit),
|
||||||
conns: make(chan net.Conn),
|
connsChan: make(chan net.Conn),
|
||||||
|
conns: make([]net.Conn, 0, 2),
|
||||||
}
|
}
|
||||||
|
|
||||||
if debug {
|
if debug {
|
||||||
@ -59,8 +54,8 @@ func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sessionMut.Lock()
|
sessionMut.Lock()
|
||||||
sessions[string(ses.serverkey)] = ses
|
pendingSessions[string(ses.serverkey)] = ses
|
||||||
sessions[string(ses.clientkey)] = ses
|
pendingSessions[string(ses.clientkey)] = ses
|
||||||
sessionMut.Unlock()
|
sessionMut.Unlock()
|
||||||
|
|
||||||
return ses
|
return ses
|
||||||
@ -69,13 +64,41 @@ func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
|
|||||||
func findSession(key string) *session {
|
func findSession(key string) *session {
|
||||||
sessionMut.Lock()
|
sessionMut.Lock()
|
||||||
defer sessionMut.Unlock()
|
defer sessionMut.Unlock()
|
||||||
lob, ok := sessions[key]
|
ses, ok := pendingSessions[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
delete(sessions, key)
|
delete(pendingSessions, key)
|
||||||
return lob
|
return ses
|
||||||
|
}
|
||||||
|
|
||||||
|
func dropSessions(id syncthingprotocol.DeviceID) {
|
||||||
|
sessionMut.RLock()
|
||||||
|
for _, session := range activeSessions {
|
||||||
|
if session.HasParticipant(id) {
|
||||||
|
if debug {
|
||||||
|
log.Println("Dropping session", session, "involving", id)
|
||||||
|
}
|
||||||
|
session.CloseConns()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sessionMut.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
type session struct {
|
||||||
|
mut sync.Mutex
|
||||||
|
|
||||||
|
serverkey []byte
|
||||||
|
serverid syncthingprotocol.DeviceID
|
||||||
|
|
||||||
|
clientkey []byte
|
||||||
|
clientid syncthingprotocol.DeviceID
|
||||||
|
|
||||||
|
rateLimit func(bytes int64)
|
||||||
|
|
||||||
|
connsChan chan net.Conn
|
||||||
|
conns []net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) AddConnection(conn net.Conn) bool {
|
func (s *session) AddConnection(conn net.Conn) bool {
|
||||||
@ -84,7 +107,7 @@ func (s *session) AddConnection(conn net.Conn) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case s.conns <- conn:
|
case s.connsChan <- conn:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@ -98,19 +121,21 @@ func (s *session) Serve() {
|
|||||||
log.Println("Session", s, "serving")
|
log.Println("Session", s, "serving")
|
||||||
}
|
}
|
||||||
|
|
||||||
conns := make([]net.Conn, 0, 2)
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case conn := <-s.conns:
|
case conn := <-s.connsChan:
|
||||||
conns = append(conns, conn)
|
s.mut.Lock()
|
||||||
if len(conns) < 2 {
|
s.conns = append(s.conns, conn)
|
||||||
|
s.mut.Unlock()
|
||||||
|
// We're the only ones mutating% s.conns, hence we are free to read it.
|
||||||
|
if len(s.conns) < 2 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
close(s.conns)
|
close(s.connsChan)
|
||||||
|
|
||||||
if debug {
|
if debug {
|
||||||
log.Println("Session", s, "starting between", conns[0].RemoteAddr(), "and", conns[1].RemoteAddr())
|
log.Println("Session", s, "starting between", s.conns[0].RemoteAddr(), "and", s.conns[1].RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
@ -118,16 +143,20 @@ func (s *session) Serve() {
|
|||||||
|
|
||||||
var err0 error
|
var err0 error
|
||||||
go func() {
|
go func() {
|
||||||
err0 = s.proxy(conns[0], conns[1])
|
err0 = s.proxy(s.conns[0], s.conns[1])
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var err1 error
|
var err1 error
|
||||||
go func() {
|
go func() {
|
||||||
err1 = s.proxy(conns[1], conns[0])
|
err1 = s.proxy(s.conns[1], s.conns[0])
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
sessionMut.Lock()
|
||||||
|
activeSessions = append(activeSessions, s)
|
||||||
|
sessionMut.Unlock()
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
if debug {
|
if debug {
|
||||||
@ -143,23 +172,37 @@ func (s *session) Serve() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
done:
|
done:
|
||||||
|
// We can end up here in 3 cases:
|
||||||
|
// 1. Timeout joining, in which case there are potentially entries in pendingSessions
|
||||||
|
// 2. General session end/timeout, in which case there are entries in activeSessions
|
||||||
|
// 3. Protocol handler calls dropSession as one of it's clients disconnects.
|
||||||
|
|
||||||
sessionMut.Lock()
|
sessionMut.Lock()
|
||||||
delete(sessions, string(s.serverkey))
|
delete(pendingSessions, string(s.serverkey))
|
||||||
delete(sessions, string(s.clientkey))
|
delete(pendingSessions, string(s.clientkey))
|
||||||
|
|
||||||
|
for i, session := range activeSessions {
|
||||||
|
if session == s {
|
||||||
|
l := len(activeSessions) - 1
|
||||||
|
activeSessions[i] = activeSessions[l]
|
||||||
|
activeSessions[l] = nil
|
||||||
|
activeSessions = activeSessions[:l]
|
||||||
|
}
|
||||||
|
}
|
||||||
sessionMut.Unlock()
|
sessionMut.Unlock()
|
||||||
|
|
||||||
for _, conn := range conns {
|
// If we are here because of case 2 or 3, we are potentially closing some or
|
||||||
conn.Close()
|
// all connections a second time.
|
||||||
}
|
s.CloseConns()
|
||||||
|
|
||||||
if debug {
|
if debug {
|
||||||
log.Println("Session", s, "stopping")
|
log.Println("Session", s, "stopping")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation {
|
func (s *session) GetClientInvitationMessage() protocol.SessionInvitation {
|
||||||
return protocol.SessionInvitation{
|
return protocol.SessionInvitation{
|
||||||
From: from[:],
|
From: s.serverid[:],
|
||||||
Key: []byte(s.clientkey),
|
Key: []byte(s.clientkey),
|
||||||
Address: sessionAddress,
|
Address: sessionAddress,
|
||||||
Port: sessionPort,
|
Port: sessionPort,
|
||||||
@ -167,9 +210,9 @@ func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) pr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation {
|
func (s *session) GetServerInvitationMessage() protocol.SessionInvitation {
|
||||||
return protocol.SessionInvitation{
|
return protocol.SessionInvitation{
|
||||||
From: from[:],
|
From: s.clientid[:],
|
||||||
Key: []byte(s.serverkey),
|
Key: []byte(s.serverkey),
|
||||||
Address: sessionAddress,
|
Address: sessionAddress,
|
||||||
Port: sessionPort,
|
Port: sessionPort,
|
||||||
@ -177,6 +220,18 @@ func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) pr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *session) HasParticipant(id syncthingprotocol.DeviceID) bool {
|
||||||
|
return s.clientid == id || s.serverid == id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) CloseConns() {
|
||||||
|
s.mut.Lock()
|
||||||
|
for _, conn := range s.conns {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
s.mut.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *session) proxy(c1, c2 net.Conn) error {
|
func (s *session) proxy(c1, c2 net.Conn) error {
|
||||||
if debug {
|
if debug {
|
||||||
log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
|
log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
|
||||||
|
@ -24,7 +24,9 @@ func getStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
status := make(map[string]interface{})
|
status := make(map[string]interface{})
|
||||||
|
|
||||||
sessionMut.Lock()
|
sessionMut.Lock()
|
||||||
status["numSessions"] = len(sessions)
|
// This can potentially be double the number of pending sessions, as each session has two keys, one for each side.
|
||||||
|
status["numPendingSessionKeys"] = len(pendingSessions)
|
||||||
|
status["numActiveSessions"] = len(activeSessions)
|
||||||
sessionMut.Unlock()
|
sessionMut.Unlock()
|
||||||
status["numConnections"] = atomic.LoadInt64(&numConnections)
|
status["numConnections"] = atomic.LoadInt64(&numConnections)
|
||||||
status["numProxies"] = atomic.LoadInt64(&numProxies)
|
status["numProxies"] = atomic.LoadInt64(&numProxies)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user