diff --git a/lib/connections/relay_dial.go b/lib/connections/relay_dial.go index d8fe3d3e7..bd2377ccc 100644 --- a/lib/connections/relay_dial.go +++ b/lib/connections/relay_dial.go @@ -28,21 +28,21 @@ type relayDialer struct { tlsCfg *tls.Config } -func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (IntermediateConnection, error) { +func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) { inv, err := client.GetInvitationFromRelay(uri, id, d.tlsCfg.Certificates, 10*time.Second) if err != nil { - return IntermediateConnection{}, err + return internalConn{}, err } conn, err := client.JoinSession(inv) if err != nil { - return IntermediateConnection{}, err + return internalConn{}, err } err = dialer.SetTCPOptions(conn) if err != nil { conn.Close() - return IntermediateConnection{}, err + return internalConn{}, err } var tc *tls.Conn @@ -55,10 +55,10 @@ func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (IntermediateConn err = tlsTimedHandshake(tc) if err != nil { tc.Close() - return IntermediateConnection{}, err + return internalConn{}, err } - return IntermediateConnection{tc, "Relay (Client)", relayPriority}, nil + return internalConn{tc, connTypeRelayClient, relayPriority}, nil } func (relayDialer) Priority() int { diff --git a/lib/connections/relay_listen.go b/lib/connections/relay_listen.go index 7fc67b62b..f23968c1d 100644 --- a/lib/connections/relay_listen.go +++ b/lib/connections/relay_listen.go @@ -30,7 +30,7 @@ type relayListener struct { uri *url.URL tlsCfg *tls.Config - conns chan IntermediateConnection + conns chan internalConn factory listenerFactory err error @@ -93,7 +93,7 @@ func (t *relayListener) Serve() { continue } - t.conns <- IntermediateConnection{tc, "Relay (Server)", relayPriority} + t.conns <- internalConn{tc, connTypeRelayServer, relayPriority} // Poor mans notifier that informs the connection service that the // relay URI has changed. This can only happen when we connect to a @@ -167,7 +167,7 @@ func (t *relayListener) String() string { type relayListenerFactory struct{} -func (f *relayListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan IntermediateConnection, natService *nat.Service) genericListener { +func (f *relayListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan internalConn, natService *nat.Service) genericListener { return &relayListener{ uri: uri, tlsCfg: tlsCfg, diff --git a/lib/connections/service.go b/lib/connections/service.go index 5c19e9472..a0d88dc1b 100644 --- a/lib/connections/service.go +++ b/lib/connections/service.go @@ -50,7 +50,7 @@ type Service struct { model Model tlsCfg *tls.Config discoverer discover.Finder - conns chan IntermediateConnection + conns chan internalConn bepProtocolName string tlsDefaultCommonName string lans []*net.IPNet @@ -65,7 +65,7 @@ type Service struct { listenerSupervisor *suture.Supervisor curConMut sync.Mutex - currentConnection map[protocol.DeviceID]Connection + currentConnection map[protocol.DeviceID]completeConn } func NewService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder, @@ -82,7 +82,7 @@ func NewService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg * model: mdl, tlsCfg: tlsCfg, discoverer: discoverer, - conns: make(chan IntermediateConnection), + conns: make(chan internalConn), bepProtocolName: bepProtocolName, tlsDefaultCommonName: tlsDefaultCommonName, lans: lans, @@ -105,7 +105,7 @@ func NewService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg * }), curConMut: sync.NewMutex(), - currentConnection: make(map[protocol.DeviceID]Connection), + currentConnection: make(map[protocol.DeviceID]completeConn), } cfg.Subscribe(service) @@ -218,7 +218,7 @@ next: priorityKnown := ok && connected // Lower priority is better, just like nice etc. - if priorityKnown && ct.Priority > c.Priority { + if priorityKnown && ct.internalConn.priority > c.priority { l.Debugln("Switching connections", remoteID) } else if connected { // We should not already be connected to the other party. TODO: This @@ -268,9 +268,9 @@ next: rd = NewReadLimiter(c, s.readRateLimit) } - name := fmt.Sprintf("%s-%s (%s)", c.LocalAddr(), c.RemoteAddr(), c.Type) + name := fmt.Sprintf("%s-%s (%s)", c.LocalAddr(), c.RemoteAddr(), c.Type()) protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression) - modelConn := Connection{c, protoConn} + modelConn := completeConn{c, protoConn} l.Infof("Established secure connection to %s at %s", remoteID, name) l.Debugf("cipher suite: %04X in lan: %t", c.ConnectionState().CipherSuite, !limit) @@ -329,7 +329,7 @@ func (s *Service) connect() { s.curConMut.Unlock() priorityKnown := ok && connected - if priorityKnown && ct.Priority == bestDialerPrio { + if priorityKnown && ct.internalConn.priority == bestDialerPrio { // Things are already as good as they can get. continue } @@ -377,8 +377,8 @@ func (s *Service) connect() { continue } - if priorityKnown && dialerFactory.Priority() >= ct.Priority { - l.Debugf("Not dialing using %s as priority is less than current connection (%d >= %d)", dialerFactory, dialerFactory.Priority(), ct.Priority) + if priorityKnown && dialerFactory.Priority() >= ct.internalConn.priority { + l.Debugf("Not dialing using %s as priority is less than current connection (%d >= %d)", dialerFactory, dialerFactory.Priority(), ct.internalConn.priority) continue } diff --git a/lib/connections/structs.go b/lib/connections/structs.go index 53c432afb..37de9fba7 100644 --- a/lib/connections/structs.go +++ b/lib/connections/structs.go @@ -9,6 +9,7 @@ package connections import ( "crypto/tls" "fmt" + "io" "net" "net/url" "time" @@ -18,19 +19,61 @@ import ( "github.com/syncthing/syncthing/lib/protocol" ) -type IntermediateConnection struct { - *tls.Conn - Type string - Priority int +// Connection is what we expose to the outside. It is a protocol.Connection +// that can be closed and has some metadata. +type Connection interface { + protocol.Connection + io.Closer + Type() string + RemoteAddr() net.Addr } -type Connection struct { - IntermediateConnection +// completeConn is the aggregation of an internalConn and the +// protocol.Connection running on top of it. It implements the Connection +// interface. +type completeConn struct { + internalConn protocol.Connection } -func (c Connection) String() string { - return fmt.Sprintf("%s-%s/%s", c.LocalAddr(), c.RemoteAddr(), c.Type) +// internalConn is the raw TLS connection plus some metadata on where it +// came from (type, priority). +type internalConn struct { + *tls.Conn + connType connType + priority int +} + +type connType int + +const ( + connTypeRelayClient connType = iota + connTypeRelayServer + connTypeTCPClient + connTypeTCPServer +) + +func (t connType) String() string { + switch t { + case connTypeRelayClient: + return "relay-client" + case connTypeRelayServer: + return "relay-server" + case connTypeTCPClient: + return "tcp-client" + case connTypeTCPServer: + return "tcp-server" + default: + return "unknown-type" + } +} + +func (c internalConn) Type() string { + return c.connType.String() +} + +func (c internalConn) String() string { + return fmt.Sprintf("%s-%s/%s", c.LocalAddr(), c.RemoteAddr(), c.connType.String()) } type dialerFactory interface { @@ -41,12 +84,12 @@ type dialerFactory interface { } type genericDialer interface { - Dial(protocol.DeviceID, *url.URL) (IntermediateConnection, error) + Dial(protocol.DeviceID, *url.URL) (internalConn, error) RedialFrequency() time.Duration } type listenerFactory interface { - New(*url.URL, *config.Wrapper, *tls.Config, chan IntermediateConnection, *nat.Service) genericListener + New(*url.URL, *config.Wrapper, *tls.Config, chan internalConn, *nat.Service) genericListener Enabled(config.Configuration) bool } diff --git a/lib/connections/tcp_dial.go b/lib/connections/tcp_dial.go index d1bf52175..fa157dd3b 100644 --- a/lib/connections/tcp_dial.go +++ b/lib/connections/tcp_dial.go @@ -30,23 +30,23 @@ type tcpDialer struct { tlsCfg *tls.Config } -func (d *tcpDialer) Dial(id protocol.DeviceID, uri *url.URL) (IntermediateConnection, error) { +func (d *tcpDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) { uri = fixupPort(uri) conn, err := dialer.DialTimeout(uri.Scheme, uri.Host, 10*time.Second) if err != nil { l.Debugln(err) - return IntermediateConnection{}, err + return internalConn{}, err } tc := tls.Client(conn, d.tlsCfg) err = tlsTimedHandshake(tc) if err != nil { tc.Close() - return IntermediateConnection{}, err + return internalConn{}, err } - return IntermediateConnection{tc, "TCP (Client)", tcpPriority}, nil + return internalConn{tc, connTypeTCPClient, tcpPriority}, nil } func (d *tcpDialer) RedialFrequency() time.Duration { diff --git a/lib/connections/tcp_listen.go b/lib/connections/tcp_listen.go index 013f919a4..e0011a068 100644 --- a/lib/connections/tcp_listen.go +++ b/lib/connections/tcp_listen.go @@ -32,7 +32,7 @@ type tcpListener struct { uri *url.URL tlsCfg *tls.Config stop chan struct{} - conns chan IntermediateConnection + conns chan internalConn factory listenerFactory natService *nat.Service @@ -115,7 +115,7 @@ func (t *tcpListener) Serve() { continue } - t.conns <- IntermediateConnection{tc, "TCP (Server)", tcpPriority} + t.conns <- internalConn{tc, connTypeTCPServer, tcpPriority} } } @@ -173,7 +173,7 @@ func (t *tcpListener) Factory() listenerFactory { type tcpListenerFactory struct{} -func (f *tcpListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan IntermediateConnection, natService *nat.Service) genericListener { +func (f *tcpListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan internalConn, natService *nat.Service) genericListener { return &tcpListener{ uri: fixupPort(uri), tlsCfg: tlsCfg, diff --git a/lib/model/model.go b/lib/model/model.go index 80018aba6..0f5a7bdb5 100644 --- a/lib/model/model.go +++ b/lib/model/model.go @@ -412,7 +412,7 @@ func (m *Model) ConnectionStats() map[string]interface{} { Paused: m.devicePaused[device], } if conn, ok := m.conn[device]; ok { - ci.Type = conn.Type + ci.Type = conn.Type() ci.Connected = ok ci.Statistics = conn.Statistics() if addr := conn.RemoteAddr(); addr != nil { @@ -1334,7 +1334,7 @@ func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloR "deviceName": hello.DeviceName, "clientName": hello.ClientName, "clientVersion": hello.ClientVersion, - "type": conn.Type, + "type": conn.Type(), } addr := conn.RemoteAddr() diff --git a/lib/model/model_test.go b/lib/model/model_test.go index 5cd670e29..525720cb0 100644 --- a/lib/model/model_test.go +++ b/lib/model/model_test.go @@ -8,7 +8,6 @@ package model import ( "bytes" - "crypto/tls" "encoding/json" "fmt" "io/ioutil" @@ -18,12 +17,12 @@ import ( "path/filepath" "runtime" "strconv" + "sync" "testing" "time" "github.com/d4l3k/messagediff" "github.com/syncthing/syncthing/lib/config" - "github.com/syncthing/syncthing/lib/connections" "github.com/syncthing/syncthing/lib/db" "github.com/syncthing/syncthing/lib/ignore" "github.com/syncthing/syncthing/lib/osutil" @@ -218,58 +217,75 @@ type downloadProgressMessage struct { updates []protocol.FileDownloadProgressUpdate } -type FakeConnection struct { +type fakeConnection struct { id protocol.DeviceID requestData []byte downloadProgressMessages []downloadProgressMessage + closed bool + mut sync.Mutex } -func (FakeConnection) Close() error { +func (f *fakeConnection) Close() error { + f.mut.Lock() + defer f.mut.Unlock() + f.closed = true return nil } -func (f FakeConnection) Start() { +func (f *fakeConnection) Start() { } -func (f FakeConnection) ID() protocol.DeviceID { +func (f *fakeConnection) ID() protocol.DeviceID { return f.id } -func (f FakeConnection) Name() string { +func (f *fakeConnection) Name() string { return "" } -func (f FakeConnection) Option(string) string { +func (f *fakeConnection) Option(string) string { return "" } -func (FakeConnection) Index(string, []protocol.FileInfo) error { +func (f *fakeConnection) Index(string, []protocol.FileInfo) error { return nil } -func (FakeConnection) IndexUpdate(string, []protocol.FileInfo) error { +func (f *fakeConnection) IndexUpdate(string, []protocol.FileInfo) error { return nil } -func (f FakeConnection) Request(folder, name string, offset int64, size int, hash []byte, fromTemporary bool) ([]byte, error) { +func (f *fakeConnection) Request(folder, name string, offset int64, size int, hash []byte, fromTemporary bool) ([]byte, error) { return f.requestData, nil } -func (FakeConnection) ClusterConfig(protocol.ClusterConfig) {} +func (f *fakeConnection) ClusterConfig(protocol.ClusterConfig) {} -func (FakeConnection) Ping() bool { - return true +func (f *fakeConnection) Ping() bool { + f.mut.Lock() + defer f.mut.Unlock() + return f.closed } -func (FakeConnection) Closed() bool { - return false +func (f *fakeConnection) Closed() bool { + f.mut.Lock() + defer f.mut.Unlock() + return f.closed } -func (FakeConnection) Statistics() protocol.Statistics { +func (f *fakeConnection) Statistics() protocol.Statistics { return protocol.Statistics{} } -func (f *FakeConnection) DownloadProgress(folder string, updates []protocol.FileDownloadProgressUpdate) { +func (f *fakeConnection) RemoteAddr() net.Addr { + return &fakeAddr{} +} + +func (f *fakeConnection) Type() string { + return "fake" +} + +func (f *fakeConnection) DownloadProgress(folder string, updates []protocol.FileDownloadProgressUpdate) { f.downloadProgressMessages = append(f.downloadProgressMessages, downloadProgressMessage{ folder: folder, updates: updates, @@ -287,18 +303,11 @@ func BenchmarkRequest(b *testing.B) { const n = 1000 files := genFiles(n) - fc := &FakeConnection{ + fc := &fakeConnection{ id: device1, requestData: []byte("some data to return"), } - m.AddConnection(connections.Connection{ - IntermediateConnection: connections.IntermediateConnection{ - Conn: tls.Client(&fakeConn{}, nil), - Type: "foo", - Priority: 10, - }, - Connection: fc, - }, protocol.HelloResult{}) + m.AddConnection(fc, protocol.HelloResult{}) m.Index(device1, "default", files) b.ResetTimer() @@ -335,16 +344,9 @@ func TestDeviceRename(t *testing.T) { t.Errorf("Device already has a name") } - conn := connections.Connection{ - IntermediateConnection: connections.IntermediateConnection{ - Conn: tls.Client(&fakeConn{}, nil), - Type: "foo", - Priority: 10, - }, - Connection: &FakeConnection{ - id: device1, - requestData: []byte("some data to return"), - }, + conn := &fakeConnection{ + id: device1, + requestData: []byte("some data to return"), } m.AddConnection(conn, hello) @@ -504,14 +506,7 @@ func TestIntroducer(t *testing.T) { m.AddFolder(folder) } m.ServeBackground() - m.AddConnection(connections.Connection{ - IntermediateConnection: connections.IntermediateConnection{ - Conn: tls.Client(&fakeConn{}, nil), - }, - Connection: &FakeConnection{ - id: device1, - }, - }, protocol.HelloResult{}) + m.AddConnection(&fakeConnection{id: device1}, protocol.HelloResult{}) return wcfg, m } @@ -1904,34 +1899,14 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) { wcfg := config.Wrap("/tmp/test", cfg) - d2c := &fakeConn{} - m := NewModel(wcfg, protocol.LocalDeviceID, "device", "syncthing", "dev", dbi, nil) m.AddFolder(fcfg) m.StartFolder(fcfg.ID) m.ServeBackground() - conn1 := connections.Connection{ - IntermediateConnection: connections.IntermediateConnection{ - Conn: tls.Client(&fakeConn{}, nil), - Type: "foo", - Priority: 10, - }, - Connection: &FakeConnection{ - id: device1, - }, - } + conn1 := &fakeConnection{id: device1} m.AddConnection(conn1, protocol.HelloResult{}) - conn2 := connections.Connection{ - IntermediateConnection: connections.IntermediateConnection{ - Conn: tls.Client(d2c, nil), - Type: "foo", - Priority: 10, - }, - Connection: &FakeConnection{ - id: device2, - }, - } + conn2 := &fakeConnection{id: device2} m.AddConnection(conn2, protocol.HelloResult{}) m.ClusterConfig(device1, protocol.ClusterConfig{ @@ -1964,7 +1939,7 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) { t.Error("not shared with device2") } - if d2c.closed { + if conn2.Closed() { t.Error("conn already closed") } @@ -1984,7 +1959,7 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) { t.Error("shared with device2") } - if !d2c.closed { + if !conn2.Closed() { t.Error("connection not closed") } @@ -2108,16 +2083,7 @@ func TestIssue3496(t *testing.T) { } func addFakeConn(m *Model, dev protocol.DeviceID) { - conn1 := connections.Connection{ - IntermediateConnection: connections.IntermediateConnection{ - Conn: tls.Client(&fakeConn{}, nil), - Type: "foo", - Priority: 10, - }, - Connection: &FakeConnection{ - id: dev, - }, - } + conn1 := &fakeConnection{id: dev} m.AddConnection(conn1, protocol.HelloResult{}) m.ClusterConfig(device1, protocol.ClusterConfig{ @@ -2142,40 +2108,3 @@ func (fakeAddr) Network() string { func (fakeAddr) String() string { return "address" } - -type fakeConn struct { - closed bool -} - -func (c *fakeConn) Close() error { - c.closed = true - return nil -} - -func (fakeConn) LocalAddr() net.Addr { - return &fakeAddr{} -} - -func (fakeConn) RemoteAddr() net.Addr { - return &fakeAddr{} -} - -func (fakeConn) Read([]byte) (int, error) { - return 0, nil -} - -func (fakeConn) Write([]byte) (int, error) { - return 0, nil -} - -func (fakeConn) SetDeadline(time.Time) error { - return nil -} - -func (fakeConn) SetReadDeadline(time.Time) error { - return nil -} - -func (fakeConn) SetWriteDeadline(time.Time) error { - return nil -} diff --git a/lib/model/progressemitter_test.go b/lib/model/progressemitter_test.go index 595d12873..00ba1d154 100644 --- a/lib/model/progressemitter_test.go +++ b/lib/model/progressemitter_test.go @@ -107,7 +107,7 @@ func TestSendDownloadProgressMessages(t *testing.T) { TempIndexMinBlocks: 10, }) - fc := &FakeConnection{} + fc := &fakeConnection{} p := NewProgressEmitter(c) p.temporaryIndexSubscribe(fc, []string{"folder", "folder2"})