lib/connections, lib/model: Connection service should expose a single interface

Makes testing easier, which we'll need

GitHub-Pull-Request: https://github.com/syncthing/syncthing/pull/3771
This commit is contained in:
Jakob Borg 2016-11-30 07:54:20 +00:00
parent 9da422f1c5
commit ed4f6fc4b3
9 changed files with 127 additions and 155 deletions

View File

@ -28,21 +28,21 @@ type relayDialer struct {
tlsCfg *tls.Config tlsCfg *tls.Config
} }
func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (IntermediateConnection, error) { func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) {
inv, err := client.GetInvitationFromRelay(uri, id, d.tlsCfg.Certificates, 10*time.Second) inv, err := client.GetInvitationFromRelay(uri, id, d.tlsCfg.Certificates, 10*time.Second)
if err != nil { if err != nil {
return IntermediateConnection{}, err return internalConn{}, err
} }
conn, err := client.JoinSession(inv) conn, err := client.JoinSession(inv)
if err != nil { if err != nil {
return IntermediateConnection{}, err return internalConn{}, err
} }
err = dialer.SetTCPOptions(conn) err = dialer.SetTCPOptions(conn)
if err != nil { if err != nil {
conn.Close() conn.Close()
return IntermediateConnection{}, err return internalConn{}, err
} }
var tc *tls.Conn var tc *tls.Conn
@ -55,10 +55,10 @@ func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (IntermediateConn
err = tlsTimedHandshake(tc) err = tlsTimedHandshake(tc)
if err != nil { if err != nil {
tc.Close() tc.Close()
return IntermediateConnection{}, err return internalConn{}, err
} }
return IntermediateConnection{tc, "Relay (Client)", relayPriority}, nil return internalConn{tc, connTypeRelayClient, relayPriority}, nil
} }
func (relayDialer) Priority() int { func (relayDialer) Priority() int {

View File

@ -30,7 +30,7 @@ type relayListener struct {
uri *url.URL uri *url.URL
tlsCfg *tls.Config tlsCfg *tls.Config
conns chan IntermediateConnection conns chan internalConn
factory listenerFactory factory listenerFactory
err error err error
@ -93,7 +93,7 @@ func (t *relayListener) Serve() {
continue continue
} }
t.conns <- IntermediateConnection{tc, "Relay (Server)", relayPriority} t.conns <- internalConn{tc, connTypeRelayServer, relayPriority}
// Poor mans notifier that informs the connection service that the // Poor mans notifier that informs the connection service that the
// relay URI has changed. This can only happen when we connect to a // relay URI has changed. This can only happen when we connect to a
@ -167,7 +167,7 @@ func (t *relayListener) String() string {
type relayListenerFactory struct{} type relayListenerFactory struct{}
func (f *relayListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan IntermediateConnection, natService *nat.Service) genericListener { func (f *relayListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan internalConn, natService *nat.Service) genericListener {
return &relayListener{ return &relayListener{
uri: uri, uri: uri,
tlsCfg: tlsCfg, tlsCfg: tlsCfg,

View File

@ -50,7 +50,7 @@ type Service struct {
model Model model Model
tlsCfg *tls.Config tlsCfg *tls.Config
discoverer discover.Finder discoverer discover.Finder
conns chan IntermediateConnection conns chan internalConn
bepProtocolName string bepProtocolName string
tlsDefaultCommonName string tlsDefaultCommonName string
lans []*net.IPNet lans []*net.IPNet
@ -65,7 +65,7 @@ type Service struct {
listenerSupervisor *suture.Supervisor listenerSupervisor *suture.Supervisor
curConMut sync.Mutex curConMut sync.Mutex
currentConnection map[protocol.DeviceID]Connection currentConnection map[protocol.DeviceID]completeConn
} }
func NewService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder, func NewService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder,
@ -82,7 +82,7 @@ func NewService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *
model: mdl, model: mdl,
tlsCfg: tlsCfg, tlsCfg: tlsCfg,
discoverer: discoverer, discoverer: discoverer,
conns: make(chan IntermediateConnection), conns: make(chan internalConn),
bepProtocolName: bepProtocolName, bepProtocolName: bepProtocolName,
tlsDefaultCommonName: tlsDefaultCommonName, tlsDefaultCommonName: tlsDefaultCommonName,
lans: lans, lans: lans,
@ -105,7 +105,7 @@ func NewService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *
}), }),
curConMut: sync.NewMutex(), curConMut: sync.NewMutex(),
currentConnection: make(map[protocol.DeviceID]Connection), currentConnection: make(map[protocol.DeviceID]completeConn),
} }
cfg.Subscribe(service) cfg.Subscribe(service)
@ -218,7 +218,7 @@ next:
priorityKnown := ok && connected priorityKnown := ok && connected
// Lower priority is better, just like nice etc. // Lower priority is better, just like nice etc.
if priorityKnown && ct.Priority > c.Priority { if priorityKnown && ct.internalConn.priority > c.priority {
l.Debugln("Switching connections", remoteID) l.Debugln("Switching connections", remoteID)
} else if connected { } else if connected {
// We should not already be connected to the other party. TODO: This // We should not already be connected to the other party. TODO: This
@ -268,9 +268,9 @@ next:
rd = NewReadLimiter(c, s.readRateLimit) rd = NewReadLimiter(c, s.readRateLimit)
} }
name := fmt.Sprintf("%s-%s (%s)", c.LocalAddr(), c.RemoteAddr(), c.Type) name := fmt.Sprintf("%s-%s (%s)", c.LocalAddr(), c.RemoteAddr(), c.Type())
protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression) protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
modelConn := Connection{c, protoConn} modelConn := completeConn{c, protoConn}
l.Infof("Established secure connection to %s at %s", remoteID, name) l.Infof("Established secure connection to %s at %s", remoteID, name)
l.Debugf("cipher suite: %04X in lan: %t", c.ConnectionState().CipherSuite, !limit) l.Debugf("cipher suite: %04X in lan: %t", c.ConnectionState().CipherSuite, !limit)
@ -329,7 +329,7 @@ func (s *Service) connect() {
s.curConMut.Unlock() s.curConMut.Unlock()
priorityKnown := ok && connected priorityKnown := ok && connected
if priorityKnown && ct.Priority == bestDialerPrio { if priorityKnown && ct.internalConn.priority == bestDialerPrio {
// Things are already as good as they can get. // Things are already as good as they can get.
continue continue
} }
@ -377,8 +377,8 @@ func (s *Service) connect() {
continue continue
} }
if priorityKnown && dialerFactory.Priority() >= ct.Priority { if priorityKnown && dialerFactory.Priority() >= ct.internalConn.priority {
l.Debugf("Not dialing using %s as priority is less than current connection (%d >= %d)", dialerFactory, dialerFactory.Priority(), ct.Priority) l.Debugf("Not dialing using %s as priority is less than current connection (%d >= %d)", dialerFactory, dialerFactory.Priority(), ct.internalConn.priority)
continue continue
} }

View File

@ -9,6 +9,7 @@ package connections
import ( import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io"
"net" "net"
"net/url" "net/url"
"time" "time"
@ -18,19 +19,61 @@ import (
"github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/protocol"
) )
type IntermediateConnection struct { // Connection is what we expose to the outside. It is a protocol.Connection
*tls.Conn // that can be closed and has some metadata.
Type string type Connection interface {
Priority int protocol.Connection
io.Closer
Type() string
RemoteAddr() net.Addr
} }
type Connection struct { // completeConn is the aggregation of an internalConn and the
IntermediateConnection // protocol.Connection running on top of it. It implements the Connection
// interface.
type completeConn struct {
internalConn
protocol.Connection protocol.Connection
} }
func (c Connection) String() string { // internalConn is the raw TLS connection plus some metadata on where it
return fmt.Sprintf("%s-%s/%s", c.LocalAddr(), c.RemoteAddr(), c.Type) // came from (type, priority).
type internalConn struct {
*tls.Conn
connType connType
priority int
}
type connType int
const (
connTypeRelayClient connType = iota
connTypeRelayServer
connTypeTCPClient
connTypeTCPServer
)
func (t connType) String() string {
switch t {
case connTypeRelayClient:
return "relay-client"
case connTypeRelayServer:
return "relay-server"
case connTypeTCPClient:
return "tcp-client"
case connTypeTCPServer:
return "tcp-server"
default:
return "unknown-type"
}
}
func (c internalConn) Type() string {
return c.connType.String()
}
func (c internalConn) String() string {
return fmt.Sprintf("%s-%s/%s", c.LocalAddr(), c.RemoteAddr(), c.connType.String())
} }
type dialerFactory interface { type dialerFactory interface {
@ -41,12 +84,12 @@ type dialerFactory interface {
} }
type genericDialer interface { type genericDialer interface {
Dial(protocol.DeviceID, *url.URL) (IntermediateConnection, error) Dial(protocol.DeviceID, *url.URL) (internalConn, error)
RedialFrequency() time.Duration RedialFrequency() time.Duration
} }
type listenerFactory interface { type listenerFactory interface {
New(*url.URL, *config.Wrapper, *tls.Config, chan IntermediateConnection, *nat.Service) genericListener New(*url.URL, *config.Wrapper, *tls.Config, chan internalConn, *nat.Service) genericListener
Enabled(config.Configuration) bool Enabled(config.Configuration) bool
} }

View File

@ -30,23 +30,23 @@ type tcpDialer struct {
tlsCfg *tls.Config tlsCfg *tls.Config
} }
func (d *tcpDialer) Dial(id protocol.DeviceID, uri *url.URL) (IntermediateConnection, error) { func (d *tcpDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) {
uri = fixupPort(uri) uri = fixupPort(uri)
conn, err := dialer.DialTimeout(uri.Scheme, uri.Host, 10*time.Second) conn, err := dialer.DialTimeout(uri.Scheme, uri.Host, 10*time.Second)
if err != nil { if err != nil {
l.Debugln(err) l.Debugln(err)
return IntermediateConnection{}, err return internalConn{}, err
} }
tc := tls.Client(conn, d.tlsCfg) tc := tls.Client(conn, d.tlsCfg)
err = tlsTimedHandshake(tc) err = tlsTimedHandshake(tc)
if err != nil { if err != nil {
tc.Close() tc.Close()
return IntermediateConnection{}, err return internalConn{}, err
} }
return IntermediateConnection{tc, "TCP (Client)", tcpPriority}, nil return internalConn{tc, connTypeTCPClient, tcpPriority}, nil
} }
func (d *tcpDialer) RedialFrequency() time.Duration { func (d *tcpDialer) RedialFrequency() time.Duration {

View File

@ -32,7 +32,7 @@ type tcpListener struct {
uri *url.URL uri *url.URL
tlsCfg *tls.Config tlsCfg *tls.Config
stop chan struct{} stop chan struct{}
conns chan IntermediateConnection conns chan internalConn
factory listenerFactory factory listenerFactory
natService *nat.Service natService *nat.Service
@ -115,7 +115,7 @@ func (t *tcpListener) Serve() {
continue continue
} }
t.conns <- IntermediateConnection{tc, "TCP (Server)", tcpPriority} t.conns <- internalConn{tc, connTypeTCPServer, tcpPriority}
} }
} }
@ -173,7 +173,7 @@ func (t *tcpListener) Factory() listenerFactory {
type tcpListenerFactory struct{} type tcpListenerFactory struct{}
func (f *tcpListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan IntermediateConnection, natService *nat.Service) genericListener { func (f *tcpListenerFactory) New(uri *url.URL, cfg *config.Wrapper, tlsCfg *tls.Config, conns chan internalConn, natService *nat.Service) genericListener {
return &tcpListener{ return &tcpListener{
uri: fixupPort(uri), uri: fixupPort(uri),
tlsCfg: tlsCfg, tlsCfg: tlsCfg,

View File

@ -412,7 +412,7 @@ func (m *Model) ConnectionStats() map[string]interface{} {
Paused: m.devicePaused[device], Paused: m.devicePaused[device],
} }
if conn, ok := m.conn[device]; ok { if conn, ok := m.conn[device]; ok {
ci.Type = conn.Type ci.Type = conn.Type()
ci.Connected = ok ci.Connected = ok
ci.Statistics = conn.Statistics() ci.Statistics = conn.Statistics()
if addr := conn.RemoteAddr(); addr != nil { if addr := conn.RemoteAddr(); addr != nil {
@ -1334,7 +1334,7 @@ func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloR
"deviceName": hello.DeviceName, "deviceName": hello.DeviceName,
"clientName": hello.ClientName, "clientName": hello.ClientName,
"clientVersion": hello.ClientVersion, "clientVersion": hello.ClientVersion,
"type": conn.Type, "type": conn.Type(),
} }
addr := conn.RemoteAddr() addr := conn.RemoteAddr()

View File

@ -8,7 +8,6 @@ package model
import ( import (
"bytes" "bytes"
"crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -18,12 +17,12 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"sync"
"testing" "testing"
"time" "time"
"github.com/d4l3k/messagediff" "github.com/d4l3k/messagediff"
"github.com/syncthing/syncthing/lib/config" "github.com/syncthing/syncthing/lib/config"
"github.com/syncthing/syncthing/lib/connections"
"github.com/syncthing/syncthing/lib/db" "github.com/syncthing/syncthing/lib/db"
"github.com/syncthing/syncthing/lib/ignore" "github.com/syncthing/syncthing/lib/ignore"
"github.com/syncthing/syncthing/lib/osutil" "github.com/syncthing/syncthing/lib/osutil"
@ -218,58 +217,75 @@ type downloadProgressMessage struct {
updates []protocol.FileDownloadProgressUpdate updates []protocol.FileDownloadProgressUpdate
} }
type FakeConnection struct { type fakeConnection struct {
id protocol.DeviceID id protocol.DeviceID
requestData []byte requestData []byte
downloadProgressMessages []downloadProgressMessage downloadProgressMessages []downloadProgressMessage
closed bool
mut sync.Mutex
} }
func (FakeConnection) Close() error { func (f *fakeConnection) Close() error {
f.mut.Lock()
defer f.mut.Unlock()
f.closed = true
return nil return nil
} }
func (f FakeConnection) Start() { func (f *fakeConnection) Start() {
} }
func (f FakeConnection) ID() protocol.DeviceID { func (f *fakeConnection) ID() protocol.DeviceID {
return f.id return f.id
} }
func (f FakeConnection) Name() string { func (f *fakeConnection) Name() string {
return "" return ""
} }
func (f FakeConnection) Option(string) string { func (f *fakeConnection) Option(string) string {
return "" return ""
} }
func (FakeConnection) Index(string, []protocol.FileInfo) error { func (f *fakeConnection) Index(string, []protocol.FileInfo) error {
return nil return nil
} }
func (FakeConnection) IndexUpdate(string, []protocol.FileInfo) error { func (f *fakeConnection) IndexUpdate(string, []protocol.FileInfo) error {
return nil return nil
} }
func (f FakeConnection) Request(folder, name string, offset int64, size int, hash []byte, fromTemporary bool) ([]byte, error) { func (f *fakeConnection) Request(folder, name string, offset int64, size int, hash []byte, fromTemporary bool) ([]byte, error) {
return f.requestData, nil return f.requestData, nil
} }
func (FakeConnection) ClusterConfig(protocol.ClusterConfig) {} func (f *fakeConnection) ClusterConfig(protocol.ClusterConfig) {}
func (FakeConnection) Ping() bool { func (f *fakeConnection) Ping() bool {
return true f.mut.Lock()
defer f.mut.Unlock()
return f.closed
} }
func (FakeConnection) Closed() bool { func (f *fakeConnection) Closed() bool {
return false f.mut.Lock()
defer f.mut.Unlock()
return f.closed
} }
func (FakeConnection) Statistics() protocol.Statistics { func (f *fakeConnection) Statistics() protocol.Statistics {
return protocol.Statistics{} return protocol.Statistics{}
} }
func (f *FakeConnection) DownloadProgress(folder string, updates []protocol.FileDownloadProgressUpdate) { func (f *fakeConnection) RemoteAddr() net.Addr {
return &fakeAddr{}
}
func (f *fakeConnection) Type() string {
return "fake"
}
func (f *fakeConnection) DownloadProgress(folder string, updates []protocol.FileDownloadProgressUpdate) {
f.downloadProgressMessages = append(f.downloadProgressMessages, downloadProgressMessage{ f.downloadProgressMessages = append(f.downloadProgressMessages, downloadProgressMessage{
folder: folder, folder: folder,
updates: updates, updates: updates,
@ -287,18 +303,11 @@ func BenchmarkRequest(b *testing.B) {
const n = 1000 const n = 1000
files := genFiles(n) files := genFiles(n)
fc := &FakeConnection{ fc := &fakeConnection{
id: device1, id: device1,
requestData: []byte("some data to return"), requestData: []byte("some data to return"),
} }
m.AddConnection(connections.Connection{ m.AddConnection(fc, protocol.HelloResult{})
IntermediateConnection: connections.IntermediateConnection{
Conn: tls.Client(&fakeConn{}, nil),
Type: "foo",
Priority: 10,
},
Connection: fc,
}, protocol.HelloResult{})
m.Index(device1, "default", files) m.Index(device1, "default", files)
b.ResetTimer() b.ResetTimer()
@ -335,16 +344,9 @@ func TestDeviceRename(t *testing.T) {
t.Errorf("Device already has a name") t.Errorf("Device already has a name")
} }
conn := connections.Connection{ conn := &fakeConnection{
IntermediateConnection: connections.IntermediateConnection{
Conn: tls.Client(&fakeConn{}, nil),
Type: "foo",
Priority: 10,
},
Connection: &FakeConnection{
id: device1, id: device1,
requestData: []byte("some data to return"), requestData: []byte("some data to return"),
},
} }
m.AddConnection(conn, hello) m.AddConnection(conn, hello)
@ -504,14 +506,7 @@ func TestIntroducer(t *testing.T) {
m.AddFolder(folder) m.AddFolder(folder)
} }
m.ServeBackground() m.ServeBackground()
m.AddConnection(connections.Connection{ m.AddConnection(&fakeConnection{id: device1}, protocol.HelloResult{})
IntermediateConnection: connections.IntermediateConnection{
Conn: tls.Client(&fakeConn{}, nil),
},
Connection: &FakeConnection{
id: device1,
},
}, protocol.HelloResult{})
return wcfg, m return wcfg, m
} }
@ -1904,34 +1899,14 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
wcfg := config.Wrap("/tmp/test", cfg) wcfg := config.Wrap("/tmp/test", cfg)
d2c := &fakeConn{}
m := NewModel(wcfg, protocol.LocalDeviceID, "device", "syncthing", "dev", dbi, nil) m := NewModel(wcfg, protocol.LocalDeviceID, "device", "syncthing", "dev", dbi, nil)
m.AddFolder(fcfg) m.AddFolder(fcfg)
m.StartFolder(fcfg.ID) m.StartFolder(fcfg.ID)
m.ServeBackground() m.ServeBackground()
conn1 := connections.Connection{ conn1 := &fakeConnection{id: device1}
IntermediateConnection: connections.IntermediateConnection{
Conn: tls.Client(&fakeConn{}, nil),
Type: "foo",
Priority: 10,
},
Connection: &FakeConnection{
id: device1,
},
}
m.AddConnection(conn1, protocol.HelloResult{}) m.AddConnection(conn1, protocol.HelloResult{})
conn2 := connections.Connection{ conn2 := &fakeConnection{id: device2}
IntermediateConnection: connections.IntermediateConnection{
Conn: tls.Client(d2c, nil),
Type: "foo",
Priority: 10,
},
Connection: &FakeConnection{
id: device2,
},
}
m.AddConnection(conn2, protocol.HelloResult{}) m.AddConnection(conn2, protocol.HelloResult{})
m.ClusterConfig(device1, protocol.ClusterConfig{ m.ClusterConfig(device1, protocol.ClusterConfig{
@ -1964,7 +1939,7 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
t.Error("not shared with device2") t.Error("not shared with device2")
} }
if d2c.closed { if conn2.Closed() {
t.Error("conn already closed") t.Error("conn already closed")
} }
@ -1984,7 +1959,7 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
t.Error("shared with device2") t.Error("shared with device2")
} }
if !d2c.closed { if !conn2.Closed() {
t.Error("connection not closed") t.Error("connection not closed")
} }
@ -2108,16 +2083,7 @@ func TestIssue3496(t *testing.T) {
} }
func addFakeConn(m *Model, dev protocol.DeviceID) { func addFakeConn(m *Model, dev protocol.DeviceID) {
conn1 := connections.Connection{ conn1 := &fakeConnection{id: dev}
IntermediateConnection: connections.IntermediateConnection{
Conn: tls.Client(&fakeConn{}, nil),
Type: "foo",
Priority: 10,
},
Connection: &FakeConnection{
id: dev,
},
}
m.AddConnection(conn1, protocol.HelloResult{}) m.AddConnection(conn1, protocol.HelloResult{})
m.ClusterConfig(device1, protocol.ClusterConfig{ m.ClusterConfig(device1, protocol.ClusterConfig{
@ -2142,40 +2108,3 @@ func (fakeAddr) Network() string {
func (fakeAddr) String() string { func (fakeAddr) String() string {
return "address" return "address"
} }
type fakeConn struct {
closed bool
}
func (c *fakeConn) Close() error {
c.closed = true
return nil
}
func (fakeConn) LocalAddr() net.Addr {
return &fakeAddr{}
}
func (fakeConn) RemoteAddr() net.Addr {
return &fakeAddr{}
}
func (fakeConn) Read([]byte) (int, error) {
return 0, nil
}
func (fakeConn) Write([]byte) (int, error) {
return 0, nil
}
func (fakeConn) SetDeadline(time.Time) error {
return nil
}
func (fakeConn) SetReadDeadline(time.Time) error {
return nil
}
func (fakeConn) SetWriteDeadline(time.Time) error {
return nil
}

View File

@ -107,7 +107,7 @@ func TestSendDownloadProgressMessages(t *testing.T) {
TempIndexMinBlocks: 10, TempIndexMinBlocks: 10,
}) })
fc := &FakeConnection{} fc := &fakeConnection{}
p := NewProgressEmitter(c) p := NewProgressEmitter(c)
p.temporaryIndexSubscribe(fc, []string{"folder", "folder2"}) p.temporaryIndexSubscribe(fc, []string{"folder", "folder2"})