diff --git a/cmd/syncthing/connections.go b/cmd/syncthing/connections.go index 6e826d2b5..2a275cf5e 100644 --- a/cmd/syncthing/connections.go +++ b/cmd/syncthing/connections.go @@ -15,23 +15,84 @@ import ( "time" "github.com/syncthing/protocol" + "github.com/syncthing/syncthing/internal/config" "github.com/syncthing/syncthing/internal/events" "github.com/syncthing/syncthing/internal/model" + "github.com/thejerf/suture" ) -func listenConnect(myID protocol.DeviceID, m *model.Model, tlsCfg *tls.Config) { - var conns = make(chan *tls.Conn) +// The connection service listens on TLS and dials configured unconnected +// devices. Successfull connections are handed to the model. +type connectionSvc struct { + *suture.Supervisor + cfg *config.Wrapper + myID protocol.DeviceID + model *model.Model + tlsCfg *tls.Config + conns chan *tls.Conn +} - // Listen - for _, addr := range cfg.Options().ListenAddress { - go listenTLS(conns, addr, tlsCfg) +func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, model *model.Model, tlsCfg *tls.Config) *connectionSvc { + svc := &connectionSvc{ + Supervisor: suture.NewSimple("connectionSvc"), + cfg: cfg, + myID: myID, + model: model, + tlsCfg: tlsCfg, + conns: make(chan *tls.Conn), } - // Connect - go dialTLS(m, conns, tlsCfg) + // There are several moving parts here; one routine per listening address + // to handle incoming connections, one routine to periodically attempt + // outgoing connections, and lastly one routine to the the common handling + // regardless of whether the connection was incoming or outgoing. It ends + // up as in the diagram below. We embed a Supervisor to manage the + // routines (i.e. log and restart if they crash or exit, etc). + // + // +-----------------+ + // Incoming | +---------------+-+ +-----------------+ + // Connections | | | | | Outgoing + // -------------->| | svc.listen | | | Connections + // | | (1 per listen | | svc.connect |--------------> + // | | address) | | | + // +-+ | | | + // +-----------------+ +-----------------+ + // v v + // | | + // | | + // +------------+-----------+ + // | + // | svc.conns + // v + // +-----------------+ + // | | + // | | + // | svc.handle |------> model.AddConnection() + // | | + // | | + // +-----------------+ + // + // TODO: Clean shutdown, and/or handling config changes on the fly. We + // partly do this now - new devices and addresses will be picked up, but + // not new listen addresses and we don't support disconnecting devices + // that are removed and so on... + svc.Add(serviceFunc(svc.connect)) + for _, addr := range svc.cfg.Options().ListenAddress { + addr := addr + listener := serviceFunc(func() { + svc.listen(addr) + }) + svc.Add(listener) + } + svc.Add(serviceFunc(svc.handle)) + + return svc +} + +func (s *connectionSvc) handle() { next: - for conn := range conns { + for conn := range s.conns { cs := conn.ConnectionState() // We should have negotiated the next level protocol "bep/1.0" as part @@ -69,13 +130,13 @@ next: // this one. But in case we are two devices connecting to each other // in parallell we don't want to do that or we end up with no // connections still established... - if m.ConnectedTo(remoteID) { + if s.model.ConnectedTo(remoteID) { l.Infof("Connected to already connected device (%s)", remoteID) conn.Close() continue } - for deviceID, deviceCfg := range cfg.Devices() { + for deviceID, deviceCfg := range s.cfg.Devices() { if deviceID == remoteID { // Verify the name on the certificate. By default we set it to // "syncthing" when generating, but the user may have replaced @@ -97,7 +158,7 @@ next: // If rate limiting is set, and based on the address we should // limit the connection, then we wrap it in a limiter. - limit := shouldLimit(conn.RemoteAddr()) + limit := s.shouldLimit(conn.RemoteAddr()) wr := io.Writer(conn) if limit && writeRateLimit != nil { @@ -110,7 +171,7 @@ next: } name := fmt.Sprintf("%s-%s", conn.LocalAddr(), conn.RemoteAddr()) - protoConn := protocol.NewConnection(remoteID, rd, wr, m, name, deviceCfg.Compression) + protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression) l.Infof("Established secure connection to %s at %s", remoteID, name) if debugNet { @@ -121,12 +182,12 @@ next: "addr": conn.RemoteAddr().String(), }) - m.AddConnection(conn, protoConn) + s.model.AddConnection(conn, protoConn) continue next } } - if !cfg.IgnoredDevice(remoteID) { + if !s.cfg.IgnoredDevice(remoteID) { events.Default.Log(events.DeviceRejected, map[string]string{ "device": remoteID.String(), "address": conn.RemoteAddr().String(), @@ -140,7 +201,7 @@ next: } } -func listenTLS(conns chan *tls.Conn, addr string, tlsCfg *tls.Config) { +func (s *connectionSvc) listen(addr string) { if debugNet { l.Debugln("listening on", addr) } @@ -166,9 +227,9 @@ func listenTLS(conns chan *tls.Conn, addr string, tlsCfg *tls.Config) { } tcpConn := conn.(*net.TCPConn) - setTCPOptions(tcpConn) + s.setTCPOptions(tcpConn) - tc := tls.Server(conn, tlsCfg) + tc := tls.Server(conn, s.tlsCfg) err = tc.Handshake() if err != nil { l.Infoln("TLS handshake:", err) @@ -176,21 +237,20 @@ func listenTLS(conns chan *tls.Conn, addr string, tlsCfg *tls.Config) { continue } - conns <- tc + s.conns <- tc } - } -func dialTLS(m *model.Model, conns chan *tls.Conn, tlsCfg *tls.Config) { +func (s *connectionSvc) connect() { delay := time.Second for { nextDevice: - for deviceID, deviceCfg := range cfg.Devices() { + for deviceID, deviceCfg := range s.cfg.Devices() { if deviceID == myID { continue } - if m.ConnectedTo(deviceID) { + if s.model.ConnectedTo(deviceID) { continue } @@ -238,9 +298,9 @@ func dialTLS(m *model.Model, conns chan *tls.Conn, tlsCfg *tls.Config) { continue } - setTCPOptions(conn) + s.setTCPOptions(conn) - tc := tls.Client(conn, tlsCfg) + tc := tls.Client(conn, s.tlsCfg) err = tc.Handshake() if err != nil { l.Infoln("TLS handshake:", err) @@ -248,20 +308,20 @@ func dialTLS(m *model.Model, conns chan *tls.Conn, tlsCfg *tls.Config) { continue } - conns <- tc + s.conns <- tc continue nextDevice } } time.Sleep(delay) delay *= 2 - if maxD := time.Duration(cfg.Options().ReconnectIntervalS) * time.Second; delay > maxD { + if maxD := time.Duration(s.cfg.Options().ReconnectIntervalS) * time.Second; delay > maxD { delay = maxD } } } -func setTCPOptions(conn *net.TCPConn) { +func (*connectionSvc) setTCPOptions(conn *net.TCPConn) { var err error if err = conn.SetLinger(0); err != nil { l.Infoln(err) @@ -277,8 +337,8 @@ func setTCPOptions(conn *net.TCPConn) { } } -func shouldLimit(addr net.Addr) bool { - if cfg.Options().LimitBandwidthInLan { +func (s *connectionSvc) shouldLimit(addr net.Addr) bool { + if s.cfg.Options().LimitBandwidthInLan { return true } diff --git a/cmd/syncthing/main.go b/cmd/syncthing/main.go index a953dbf85..ef837dc35 100644 --- a/cmd/syncthing/main.go +++ b/cmd/syncthing/main.go @@ -584,7 +584,9 @@ func syncthingMain() { // Routine to connect out to configured devices discoverer = discovery(externalPort) - go listenConnect(myID, m, tlsCfg) + + connectionSvc := newConnectionSvc(cfg, myID, m, tlsCfg) + mainSvc.Add(connectionSvc) for _, folder := range cfg.Folders() { // Routine to pull blocks from other devices to synchronize the local