From ed4f6fc4b3c30fd9430cff89210159def0cd0cf6 Mon Sep 17 00:00:00 2001
From: Jakob Borg <jakob@kastelo.net>
Date: Wed, 30 Nov 2016 07:54:20 +0000
Subject: [PATCH] lib/connections, lib/model: Connection service should expose
 a single interface

Makes testing easier, which we'll need

GitHub-Pull-Request: https://github.com/syncthing/syncthing/pull/3771
---
 lib/connections/relay_dial.go     |  12 +--
 lib/connections/relay_listen.go   |   6 +-
 lib/connections/service.go        |  20 ++--
 lib/connections/structs.go        |  63 ++++++++++--
 lib/connections/tcp_dial.go       |   8 +-
 lib/connections/tcp_listen.go     |   6 +-
 lib/model/model.go                |   4 +-
 lib/model/model_test.go           | 161 +++++++++---------------------
 lib/model/progressemitter_test.go |   2 +-
 9 files changed, 127 insertions(+), 155 deletions(-)

diff --git a/lib/connections/relay_dial.go b/lib/connections/relay_dial.go
index d8fe3d3e7..bd2377ccc 100644
--- a/lib/connections/relay_dial.go
+++ b/lib/connections/relay_dial.go
@@ -28,21 +28,21 @@ type relayDialer struct {
 	tlsCfg *tls.Config
 }
 
-func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (IntermediateConnection, error) {
+func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) {
 	inv, err := client.GetInvitationFromRelay(uri, id, d.tlsCfg.Certificates, 10*time.Second)
 	if err != nil {
-		return IntermediateConnection{}, err
+		return internalConn{}, err
 	}
 
 	conn, err := client.JoinSession(inv)
 	if err != nil {
-		return IntermediateConnection{}, err
+		return internalConn{}, err
 	}
 
 	err = dialer.SetTCPOptions(conn)
 	if err != nil {
 		conn.Close()
-		return IntermediateConnection{}, err
+		return internalConn{}, err
 	}
 
 	var tc *tls.Conn
@@ -55,10 +55,10 @@ func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (IntermediateConn
 	err = tlsTimedHandshake(tc)
 	if err != nil {
 		tc.Close()
-		return IntermediateConnection{}, err
+		return internalConn{}, err
 	}
 
-	return IntermediateConnection{tc, "Relay (Client)", relayPriority}, nil
+	return internalConn{tc, connTypeRelayClient, relayPriority}, nil
 }
 
 func (relayDialer) Priority() int {
diff --git a/lib/connections/relay_listen.go b/lib/connections/relay_listen.go
index 7fc67b62b..f23968c1d 100644
--- a/lib/connections/relay_listen.go
+++ b/lib/connections/relay_listen.go
@@ -30,7 +30,7 @@ type relayListener struct {
 
 	uri     *url.URL
 	tlsCfg  *tls.Config
-	conns   chan IntermediateConnection
+	conns   chan internalConn
 	factory listenerFactory
 
 	err    error
@@ -93,7 +93,7 @@ func (t *relayListener) Serve() {
 				continue
 			}
 
-			t.conns <- IntermediateConnection{tc, "Relay (Server)", relayPriority}
+			t.conns <- internalConn{tc, connTypeRelayServer, relayPriority}
 
 		// Poor mans notifier that informs the connection service that the
 		// relay URI has changed. This can only happen when we connect to a
@@ -167,7 +167,7 @@ func (t *relayListener) String() string {
 
 type relayListenerFactory struct{}
 
-func (f *relayListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan IntermediateConnection, natService *nat.Service) genericListener {
+func (f *relayListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan internalConn, natService *nat.Service) genericListener {
 	return &relayListener{
 		uri:     uri,
 		tlsCfg:  tlsCfg,
diff --git a/lib/connections/service.go b/lib/connections/service.go
index 5c19e9472..a0d88dc1b 100644
--- a/lib/connections/service.go
+++ b/lib/connections/service.go
@@ -50,7 +50,7 @@ type Service struct {
 	model                Model
 	tlsCfg               *tls.Config
 	discoverer           discover.Finder
-	conns                chan IntermediateConnection
+	conns                chan internalConn
 	bepProtocolName      string
 	tlsDefaultCommonName string
 	lans                 []*net.IPNet
@@ -65,7 +65,7 @@ type Service struct {
 	listenerSupervisor *suture.Supervisor
 
 	curConMut         sync.Mutex
-	currentConnection map[protocol.DeviceID]Connection
+	currentConnection map[protocol.DeviceID]completeConn
 }
 
 func NewService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder,
@@ -82,7 +82,7 @@ func NewService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *
 		model:                mdl,
 		tlsCfg:               tlsCfg,
 		discoverer:           discoverer,
-		conns:                make(chan IntermediateConnection),
+		conns:                make(chan internalConn),
 		bepProtocolName:      bepProtocolName,
 		tlsDefaultCommonName: tlsDefaultCommonName,
 		lans:                 lans,
@@ -105,7 +105,7 @@ func NewService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *
 		}),
 
 		curConMut:         sync.NewMutex(),
-		currentConnection: make(map[protocol.DeviceID]Connection),
+		currentConnection: make(map[protocol.DeviceID]completeConn),
 	}
 	cfg.Subscribe(service)
 
@@ -218,7 +218,7 @@ next:
 		priorityKnown := ok && connected
 
 		// Lower priority is better, just like nice etc.
-		if priorityKnown && ct.Priority > c.Priority {
+		if priorityKnown && ct.internalConn.priority > c.priority {
 			l.Debugln("Switching connections", remoteID)
 		} else if connected {
 			// We should not already be connected to the other party. TODO: This
@@ -268,9 +268,9 @@ next:
 			rd = NewReadLimiter(c, s.readRateLimit)
 		}
 
-		name := fmt.Sprintf("%s-%s (%s)", c.LocalAddr(), c.RemoteAddr(), c.Type)
+		name := fmt.Sprintf("%s-%s (%s)", c.LocalAddr(), c.RemoteAddr(), c.Type())
 		protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
-		modelConn := Connection{c, protoConn}
+		modelConn := completeConn{c, protoConn}
 
 		l.Infof("Established secure connection to %s at %s", remoteID, name)
 		l.Debugf("cipher suite: %04X in lan: %t", c.ConnectionState().CipherSuite, !limit)
@@ -329,7 +329,7 @@ func (s *Service) connect() {
 			s.curConMut.Unlock()
 			priorityKnown := ok && connected
 
-			if priorityKnown && ct.Priority == bestDialerPrio {
+			if priorityKnown && ct.internalConn.priority == bestDialerPrio {
 				// Things are already as good as they can get.
 				continue
 			}
@@ -377,8 +377,8 @@ func (s *Service) connect() {
 					continue
 				}
 
-				if priorityKnown && dialerFactory.Priority() >= ct.Priority {
-					l.Debugf("Not dialing using %s as priority is less than current connection (%d >= %d)", dialerFactory, dialerFactory.Priority(), ct.Priority)
+				if priorityKnown && dialerFactory.Priority() >= ct.internalConn.priority {
+					l.Debugf("Not dialing using %s as priority is less than current connection (%d >= %d)", dialerFactory, dialerFactory.Priority(), ct.internalConn.priority)
 					continue
 				}
 
diff --git a/lib/connections/structs.go b/lib/connections/structs.go
index 53c432afb..37de9fba7 100644
--- a/lib/connections/structs.go
+++ b/lib/connections/structs.go
@@ -9,6 +9,7 @@ package connections
 import (
 	"crypto/tls"
 	"fmt"
+	"io"
 	"net"
 	"net/url"
 	"time"
@@ -18,19 +19,61 @@ import (
 	"github.com/syncthing/syncthing/lib/protocol"
 )
 
-type IntermediateConnection struct {
-	*tls.Conn
-	Type     string
-	Priority int
+// Connection is what we expose to the outside. It is a protocol.Connection
+// that can be closed and has some metadata.
+type Connection interface {
+	protocol.Connection
+	io.Closer
+	Type() string
+	RemoteAddr() net.Addr
 }
 
-type Connection struct {
-	IntermediateConnection
+// completeConn is the aggregation of an internalConn and the
+// protocol.Connection running on top of it. It implements the Connection
+// interface.
+type completeConn struct {
+	internalConn
 	protocol.Connection
 }
 
-func (c Connection) String() string {
-	return fmt.Sprintf("%s-%s/%s", c.LocalAddr(), c.RemoteAddr(), c.Type)
+// internalConn is the raw TLS connection plus some metadata on where it
+// came from (type, priority).
+type internalConn struct {
+	*tls.Conn
+	connType connType
+	priority int
+}
+
+type connType int
+
+const (
+	connTypeRelayClient connType = iota
+	connTypeRelayServer
+	connTypeTCPClient
+	connTypeTCPServer
+)
+
+func (t connType) String() string {
+	switch t {
+	case connTypeRelayClient:
+		return "relay-client"
+	case connTypeRelayServer:
+		return "relay-server"
+	case connTypeTCPClient:
+		return "tcp-client"
+	case connTypeTCPServer:
+		return "tcp-server"
+	default:
+		return "unknown-type"
+	}
+}
+
+func (c internalConn) Type() string {
+	return c.connType.String()
+}
+
+func (c internalConn) String() string {
+	return fmt.Sprintf("%s-%s/%s", c.LocalAddr(), c.RemoteAddr(), c.connType.String())
 }
 
 type dialerFactory interface {
@@ -41,12 +84,12 @@ type dialerFactory interface {
 }
 
 type genericDialer interface {
-	Dial(protocol.DeviceID, *url.URL) (IntermediateConnection, error)
+	Dial(protocol.DeviceID, *url.URL) (internalConn, error)
 	RedialFrequency() time.Duration
 }
 
 type listenerFactory interface {
-	New(*url.URL, *config.Wrapper, *tls.Config, chan IntermediateConnection, *nat.Service) genericListener
+	New(*url.URL, *config.Wrapper, *tls.Config, chan internalConn, *nat.Service) genericListener
 	Enabled(config.Configuration) bool
 }
 
diff --git a/lib/connections/tcp_dial.go b/lib/connections/tcp_dial.go
index d1bf52175..fa157dd3b 100644
--- a/lib/connections/tcp_dial.go
+++ b/lib/connections/tcp_dial.go
@@ -30,23 +30,23 @@ type tcpDialer struct {
 	tlsCfg *tls.Config
 }
 
-func (d *tcpDialer) Dial(id protocol.DeviceID, uri *url.URL) (IntermediateConnection, error) {
+func (d *tcpDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) {
 	uri = fixupPort(uri)
 
 	conn, err := dialer.DialTimeout(uri.Scheme, uri.Host, 10*time.Second)
 	if err != nil {
 		l.Debugln(err)
-		return IntermediateConnection{}, err
+		return internalConn{}, err
 	}
 
 	tc := tls.Client(conn, d.tlsCfg)
 	err = tlsTimedHandshake(tc)
 	if err != nil {
 		tc.Close()
-		return IntermediateConnection{}, err
+		return internalConn{}, err
 	}
 
-	return IntermediateConnection{tc, "TCP (Client)", tcpPriority}, nil
+	return internalConn{tc, connTypeTCPClient, tcpPriority}, nil
 }
 
 func (d *tcpDialer) RedialFrequency() time.Duration {
diff --git a/lib/connections/tcp_listen.go b/lib/connections/tcp_listen.go
index 013f919a4..e0011a068 100644
--- a/lib/connections/tcp_listen.go
+++ b/lib/connections/tcp_listen.go
@@ -32,7 +32,7 @@ type tcpListener struct {
 	uri     *url.URL
 	tlsCfg  *tls.Config
 	stop    chan struct{}
-	conns   chan IntermediateConnection
+	conns   chan internalConn
 	factory listenerFactory
 
 	natService *nat.Service
@@ -115,7 +115,7 @@ func (t *tcpListener) Serve() {
 			continue
 		}
 
-		t.conns <- IntermediateConnection{tc, "TCP (Server)", tcpPriority}
+		t.conns <- internalConn{tc, connTypeTCPServer, tcpPriority}
 	}
 }
 
@@ -173,7 +173,7 @@ func (t *tcpListener) Factory() listenerFactory {
 
 type tcpListenerFactory struct{}
 
-func (f *tcpListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan IntermediateConnection, natService *nat.Service) genericListener {
+func (f *tcpListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan internalConn, natService *nat.Service) genericListener {
 	return &tcpListener{
 		uri:        fixupPort(uri),
 		tlsCfg:     tlsCfg,
diff --git a/lib/model/model.go b/lib/model/model.go
index 80018aba6..0f5a7bdb5 100644
--- a/lib/model/model.go
+++ b/lib/model/model.go
@@ -412,7 +412,7 @@ func (m *Model) ConnectionStats() map[string]interface{} {
 			Paused:        m.devicePaused[device],
 		}
 		if conn, ok := m.conn[device]; ok {
-			ci.Type = conn.Type
+			ci.Type = conn.Type()
 			ci.Connected = ok
 			ci.Statistics = conn.Statistics()
 			if addr := conn.RemoteAddr(); addr != nil {
@@ -1334,7 +1334,7 @@ func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloR
 		"deviceName":    hello.DeviceName,
 		"clientName":    hello.ClientName,
 		"clientVersion": hello.ClientVersion,
-		"type":          conn.Type,
+		"type":          conn.Type(),
 	}
 
 	addr := conn.RemoteAddr()
diff --git a/lib/model/model_test.go b/lib/model/model_test.go
index 5cd670e29..525720cb0 100644
--- a/lib/model/model_test.go
+++ b/lib/model/model_test.go
@@ -8,7 +8,6 @@ package model
 
 import (
 	"bytes"
-	"crypto/tls"
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
@@ -18,12 +17,12 @@ import (
 	"path/filepath"
 	"runtime"
 	"strconv"
+	"sync"
 	"testing"
 	"time"
 
 	"github.com/d4l3k/messagediff"
 	"github.com/syncthing/syncthing/lib/config"
-	"github.com/syncthing/syncthing/lib/connections"
 	"github.com/syncthing/syncthing/lib/db"
 	"github.com/syncthing/syncthing/lib/ignore"
 	"github.com/syncthing/syncthing/lib/osutil"
@@ -218,58 +217,75 @@ type downloadProgressMessage struct {
 	updates []protocol.FileDownloadProgressUpdate
 }
 
-type FakeConnection struct {
+type fakeConnection struct {
 	id                       protocol.DeviceID
 	requestData              []byte
 	downloadProgressMessages []downloadProgressMessage
+	closed                   bool
+	mut                      sync.Mutex
 }
 
-func (FakeConnection) Close() error {
+func (f *fakeConnection) Close() error {
+	f.mut.Lock()
+	defer f.mut.Unlock()
+	f.closed = true
 	return nil
 }
 
-func (f FakeConnection) Start() {
+func (f *fakeConnection) Start() {
 }
 
-func (f FakeConnection) ID() protocol.DeviceID {
+func (f *fakeConnection) ID() protocol.DeviceID {
 	return f.id
 }
 
-func (f FakeConnection) Name() string {
+func (f *fakeConnection) Name() string {
 	return ""
 }
 
-func (f FakeConnection) Option(string) string {
+func (f *fakeConnection) Option(string) string {
 	return ""
 }
 
-func (FakeConnection) Index(string, []protocol.FileInfo) error {
+func (f *fakeConnection) Index(string, []protocol.FileInfo) error {
 	return nil
 }
 
-func (FakeConnection) IndexUpdate(string, []protocol.FileInfo) error {
+func (f *fakeConnection) IndexUpdate(string, []protocol.FileInfo) error {
 	return nil
 }
 
-func (f FakeConnection) Request(folder, name string, offset int64, size int, hash []byte, fromTemporary bool) ([]byte, error) {
+func (f *fakeConnection) Request(folder, name string, offset int64, size int, hash []byte, fromTemporary bool) ([]byte, error) {
 	return f.requestData, nil
 }
 
-func (FakeConnection) ClusterConfig(protocol.ClusterConfig) {}
+func (f *fakeConnection) ClusterConfig(protocol.ClusterConfig) {}
 
-func (FakeConnection) Ping() bool {
-	return true
+func (f *fakeConnection) Ping() bool {
+	f.mut.Lock()
+	defer f.mut.Unlock()
+	return f.closed
 }
 
-func (FakeConnection) Closed() bool {
-	return false
+func (f *fakeConnection) Closed() bool {
+	f.mut.Lock()
+	defer f.mut.Unlock()
+	return f.closed
 }
 
-func (FakeConnection) Statistics() protocol.Statistics {
+func (f *fakeConnection) Statistics() protocol.Statistics {
 	return protocol.Statistics{}
 }
 
-func (f *FakeConnection) DownloadProgress(folder string, updates []protocol.FileDownloadProgressUpdate) {
+func (f *fakeConnection) RemoteAddr() net.Addr {
+	return &fakeAddr{}
+}
+
+func (f *fakeConnection) Type() string {
+	return "fake"
+}
+
+func (f *fakeConnection) DownloadProgress(folder string, updates []protocol.FileDownloadProgressUpdate) {
 	f.downloadProgressMessages = append(f.downloadProgressMessages, downloadProgressMessage{
 		folder:  folder,
 		updates: updates,
@@ -287,18 +303,11 @@ func BenchmarkRequest(b *testing.B) {
 	const n = 1000
 	files := genFiles(n)
 
-	fc := &FakeConnection{
+	fc := &fakeConnection{
 		id:          device1,
 		requestData: []byte("some data to return"),
 	}
-	m.AddConnection(connections.Connection{
-		IntermediateConnection: connections.IntermediateConnection{
-			Conn:     tls.Client(&fakeConn{}, nil),
-			Type:     "foo",
-			Priority: 10,
-		},
-		Connection: fc,
-	}, protocol.HelloResult{})
+	m.AddConnection(fc, protocol.HelloResult{})
 	m.Index(device1, "default", files)
 
 	b.ResetTimer()
@@ -335,16 +344,9 @@ func TestDeviceRename(t *testing.T) {
 		t.Errorf("Device already has a name")
 	}
 
-	conn := connections.Connection{
-		IntermediateConnection: connections.IntermediateConnection{
-			Conn:     tls.Client(&fakeConn{}, nil),
-			Type:     "foo",
-			Priority: 10,
-		},
-		Connection: &FakeConnection{
-			id:          device1,
-			requestData: []byte("some data to return"),
-		},
+	conn := &fakeConnection{
+		id:          device1,
+		requestData: []byte("some data to return"),
 	}
 
 	m.AddConnection(conn, hello)
@@ -504,14 +506,7 @@ func TestIntroducer(t *testing.T) {
 			m.AddFolder(folder)
 		}
 		m.ServeBackground()
-		m.AddConnection(connections.Connection{
-			IntermediateConnection: connections.IntermediateConnection{
-				Conn: tls.Client(&fakeConn{}, nil),
-			},
-			Connection: &FakeConnection{
-				id: device1,
-			},
-		}, protocol.HelloResult{})
+		m.AddConnection(&fakeConnection{id: device1}, protocol.HelloResult{})
 		return wcfg, m
 	}
 
@@ -1904,34 +1899,14 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
 
 	wcfg := config.Wrap("/tmp/test", cfg)
 
-	d2c := &fakeConn{}
-
 	m := NewModel(wcfg, protocol.LocalDeviceID, "device", "syncthing", "dev", dbi, nil)
 	m.AddFolder(fcfg)
 	m.StartFolder(fcfg.ID)
 	m.ServeBackground()
 
-	conn1 := connections.Connection{
-		IntermediateConnection: connections.IntermediateConnection{
-			Conn:     tls.Client(&fakeConn{}, nil),
-			Type:     "foo",
-			Priority: 10,
-		},
-		Connection: &FakeConnection{
-			id: device1,
-		},
-	}
+	conn1 := &fakeConnection{id: device1}
 	m.AddConnection(conn1, protocol.HelloResult{})
-	conn2 := connections.Connection{
-		IntermediateConnection: connections.IntermediateConnection{
-			Conn:     tls.Client(d2c, nil),
-			Type:     "foo",
-			Priority: 10,
-		},
-		Connection: &FakeConnection{
-			id: device2,
-		},
-	}
+	conn2 := &fakeConnection{id: device2}
 	m.AddConnection(conn2, protocol.HelloResult{})
 
 	m.ClusterConfig(device1, protocol.ClusterConfig{
@@ -1964,7 +1939,7 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
 		t.Error("not shared with device2")
 	}
 
-	if d2c.closed {
+	if conn2.Closed() {
 		t.Error("conn already closed")
 	}
 
@@ -1984,7 +1959,7 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
 		t.Error("shared with device2")
 	}
 
-	if !d2c.closed {
+	if !conn2.Closed() {
 		t.Error("connection not closed")
 	}
 
@@ -2108,16 +2083,7 @@ func TestIssue3496(t *testing.T) {
 }
 
 func addFakeConn(m *Model, dev protocol.DeviceID) {
-	conn1 := connections.Connection{
-		IntermediateConnection: connections.IntermediateConnection{
-			Conn:     tls.Client(&fakeConn{}, nil),
-			Type:     "foo",
-			Priority: 10,
-		},
-		Connection: &FakeConnection{
-			id: dev,
-		},
-	}
+	conn1 := &fakeConnection{id: dev}
 	m.AddConnection(conn1, protocol.HelloResult{})
 
 	m.ClusterConfig(device1, protocol.ClusterConfig{
@@ -2142,40 +2108,3 @@ func (fakeAddr) Network() string {
 func (fakeAddr) String() string {
 	return "address"
 }
-
-type fakeConn struct {
-	closed bool
-}
-
-func (c *fakeConn) Close() error {
-	c.closed = true
-	return nil
-}
-
-func (fakeConn) LocalAddr() net.Addr {
-	return &fakeAddr{}
-}
-
-func (fakeConn) RemoteAddr() net.Addr {
-	return &fakeAddr{}
-}
-
-func (fakeConn) Read([]byte) (int, error) {
-	return 0, nil
-}
-
-func (fakeConn) Write([]byte) (int, error) {
-	return 0, nil
-}
-
-func (fakeConn) SetDeadline(time.Time) error {
-	return nil
-}
-
-func (fakeConn) SetReadDeadline(time.Time) error {
-	return nil
-}
-
-func (fakeConn) SetWriteDeadline(time.Time) error {
-	return nil
-}
diff --git a/lib/model/progressemitter_test.go b/lib/model/progressemitter_test.go
index 595d12873..00ba1d154 100644
--- a/lib/model/progressemitter_test.go
+++ b/lib/model/progressemitter_test.go
@@ -107,7 +107,7 @@ func TestSendDownloadProgressMessages(t *testing.T) {
 		TempIndexMinBlocks:      10,
 	})
 
-	fc := &FakeConnection{}
+	fc := &fakeConnection{}
 
 	p := NewProgressEmitter(c)
 	p.temporaryIndexSubscribe(fc, []string{"folder", "folder2"})