lib: Close underlying conn in protocol (fixes #7165) (#7212)

This commit is contained in:
Simon Frei 2020-12-21 11:40:51 +01:00 committed by GitHub
parent 4a787986cd
commit c845e245a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 106 additions and 124 deletions

View File

@ -11,7 +11,6 @@ import (
"net" "net"
"time" "time"
"github.com/syncthing/syncthing/lib/connections"
"github.com/syncthing/syncthing/lib/db" "github.com/syncthing/syncthing/lib/db"
"github.com/syncthing/syncthing/lib/model" "github.com/syncthing/syncthing/lib/model"
"github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/protocol"
@ -114,7 +113,7 @@ func (m *mockedModel) ScanFolderSubdirs(folder string, subs []string) error {
func (m *mockedModel) BringToFront(folder, file string) {} func (m *mockedModel) BringToFront(folder, file string) {}
func (m *mockedModel) Connection(deviceID protocol.DeviceID) (connections.Connection, bool) { func (m *mockedModel) Connection(deviceID protocol.DeviceID) (protocol.Connection, bool) {
return nil, false return nil, false
} }
@ -165,7 +164,7 @@ func (m *mockedModel) DownloadProgress(deviceID protocol.DeviceID, folder string
return nil return nil
} }
func (m *mockedModel) AddConnection(conn connections.Connection, hello protocol.Hello) {} func (m *mockedModel) AddConnection(conn protocol.Connection, hello protocol.Hello) {}
func (m *mockedModel) OnHello(protocol.DeviceID, net.Addr, protocol.Hello) error { func (m *mockedModel) OnHello(protocol.DeviceID, net.Addr, protocol.Hello) error {
return nil return nil

View File

@ -329,15 +329,14 @@ func (s *service) handle(ctx context.Context) error {
var protoConn protocol.Connection var protoConn protocol.Connection
passwords := s.cfg.FolderPasswords(remoteID) passwords := s.cfg.FolderPasswords(remoteID)
if len(passwords) > 0 { if len(passwords) > 0 {
protoConn = protocol.NewEncryptedConnection(passwords, remoteID, rd, wr, s.model, c.String(), deviceCfg.Compression) protoConn = protocol.NewEncryptedConnection(passwords, remoteID, rd, wr, c, s.model, c, deviceCfg.Compression)
} else { } else {
protoConn = protocol.NewConnection(remoteID, rd, wr, s.model, c.String(), deviceCfg.Compression) protoConn = protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression)
} }
modelConn := completeConn{c, protoConn}
l.Infof("Established secure connection to %s at %s", remoteID, c) l.Infof("Established secure connection to %s at %s", remoteID, c)
s.model.AddConnection(modelConn, hello) s.model.AddConnection(protoConn, hello)
continue continue
} }
} }

View File

@ -22,31 +22,6 @@ import (
"github.com/thejerf/suture/v4" "github.com/thejerf/suture/v4"
) )
// Connection is what we expose to the outside. It is a protocol.Connection
// that can be closed and has some metadata.
type Connection interface {
protocol.Connection
Type() string
Transport() string
RemoteAddr() net.Addr
Priority() int
String() string
Crypto() string
}
// completeConn is the aggregation of an internalConn and the
// protocol.Connection running on top of it. It implements the Connection
// interface.
type completeConn struct {
internalConn
protocol.Connection
}
func (c completeConn) Close(err error) {
c.Connection.Close(err)
c.internalConn.Close()
}
type tlsConn interface { type tlsConn interface {
io.ReadWriteCloser io.ReadWriteCloser
ConnectionState() tls.ConnectionState ConnectionState() tls.ConnectionState
@ -107,12 +82,12 @@ func (t connType) Transport() string {
} }
} }
func (c internalConn) Close() { func (c internalConn) Close() error {
// *tls.Conn.Close() does more than it says on the tin. Specifically, it // *tls.Conn.Close() does more than it says on the tin. Specifically, it
// sends a TLS alert message, which might block forever if the // sends a TLS alert message, which might block forever if the
// connection is dead and we don't have a deadline set. // connection is dead and we don't have a deadline set.
_ = c.SetWriteDeadline(time.Now().Add(250 * time.Millisecond)) _ = c.SetWriteDeadline(time.Now().Add(250 * time.Millisecond))
_ = c.tlsConn.Close() return c.tlsConn.Close()
} }
func (c internalConn) Type() string { func (c internalConn) Type() string {
@ -203,8 +178,8 @@ type genericListener interface {
type Model interface { type Model interface {
protocol.Model protocol.Model
AddConnection(conn Connection, hello protocol.Hello) AddConnection(conn protocol.Connection, hello protocol.Hello)
Connection(remoteID protocol.DeviceID) (Connection, bool) Connection(remoteID protocol.DeviceID) (protocol.Connection, bool)
OnHello(protocol.DeviceID, net.Addr, protocol.Hello) error OnHello(protocol.DeviceID, net.Addr, protocol.Hello) error
GetHello(protocol.DeviceID) protocol.HelloIntf GetHello(protocol.DeviceID) protocol.HelloIntf
} }

View File

@ -9,13 +9,12 @@ package model
import ( import (
"bytes" "bytes"
"context" "context"
"net"
"sync" "sync"
"time" "time"
"github.com/syncthing/syncthing/lib/connections"
"github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/scanner" "github.com/syncthing/syncthing/lib/scanner"
"github.com/syncthing/syncthing/lib/testutils"
) )
type downloadProgressMessage struct { type downloadProgressMessage struct {
@ -24,7 +23,7 @@ type downloadProgressMessage struct {
} }
type fakeConnection struct { type fakeConnection struct {
fakeUnderlyingConn testutils.FakeConnectionInfo
id protocol.DeviceID id protocol.DeviceID
downloadProgressMessages []downloadProgressMessage downloadProgressMessages []downloadProgressMessage
closed bool closed bool
@ -219,50 +218,3 @@ func addFakeConn(m *testModel, dev protocol.DeviceID) *fakeConnection {
return fc return fc
} }
type fakeProtoConn struct {
protocol.Connection
fakeUnderlyingConn
}
func newFakeProtoConn(protoConn protocol.Connection) connections.Connection {
return &fakeProtoConn{Connection: protoConn}
}
// fakeUnderlyingConn implements the methods of connections.Connection that are
// not implemented by protocol.Connection
type fakeUnderlyingConn struct{}
func (f *fakeUnderlyingConn) RemoteAddr() net.Addr {
return &fakeAddr{}
}
func (f *fakeUnderlyingConn) Type() string {
return "fake"
}
func (f *fakeUnderlyingConn) Crypto() string {
return "fake"
}
func (f *fakeUnderlyingConn) Transport() string {
return "fake"
}
func (f *fakeUnderlyingConn) Priority() int {
return 9000
}
func (f *fakeUnderlyingConn) String() string {
return ""
}
type fakeAddr struct{}
func (fakeAddr) Network() string {
return "network"
}
func (fakeAddr) String() string {
return "address"
}

View File

@ -150,7 +150,7 @@ type model struct {
// fields protected by pmut // fields protected by pmut
pmut sync.RWMutex pmut sync.RWMutex
conn map[protocol.DeviceID]connections.Connection conn map[protocol.DeviceID]protocol.Connection
connRequestLimiters map[protocol.DeviceID]*byteSemaphore connRequestLimiters map[protocol.DeviceID]*byteSemaphore
closed map[protocol.DeviceID]chan struct{} closed map[protocol.DeviceID]chan struct{}
helloMessages map[protocol.DeviceID]protocol.Hello helloMessages map[protocol.DeviceID]protocol.Hello
@ -232,7 +232,7 @@ func NewModel(cfg config.Wrapper, id protocol.DeviceID, clientName, clientVersio
// fields protected by pmut // fields protected by pmut
pmut: sync.NewRWMutex(), pmut: sync.NewRWMutex(),
conn: make(map[protocol.DeviceID]connections.Connection), conn: make(map[protocol.DeviceID]protocol.Connection),
connRequestLimiters: make(map[protocol.DeviceID]*byteSemaphore), connRequestLimiters: make(map[protocol.DeviceID]*byteSemaphore),
closed: make(map[protocol.DeviceID]chan struct{}), closed: make(map[protocol.DeviceID]chan struct{}),
helloMessages: make(map[protocol.DeviceID]protocol.Hello), helloMessages: make(map[protocol.DeviceID]protocol.Hello),
@ -1660,7 +1660,7 @@ func (m *model) Closed(conn protocol.Connection, err error) {
m.progressEmitter.temporaryIndexUnsubscribe(conn) m.progressEmitter.temporaryIndexUnsubscribe(conn)
l.Infof("Connection to %s at %s closed: %v", device, conn.Name(), err) l.Infof("Connection to %s at %s closed: %v", device, conn, err)
m.evLogger.Log(events.DeviceDisconnected, map[string]string{ m.evLogger.Log(events.DeviceDisconnected, map[string]string{
"id": device.String(), "id": device.String(),
"error": err.Error(), "error": err.Error(),
@ -1912,7 +1912,7 @@ func (m *model) CurrentGlobalFile(folder string, file string) (protocol.FileInfo
} }
// Connection returns the current connection for device, and a boolean whether a connection was found. // Connection returns the current connection for device, and a boolean whether a connection was found.
func (m *model) Connection(deviceID protocol.DeviceID) (connections.Connection, bool) { func (m *model) Connection(deviceID protocol.DeviceID) (protocol.Connection, bool) {
m.pmut.RLock() m.pmut.RLock()
cn, ok := m.conn[deviceID] cn, ok := m.conn[deviceID]
m.pmut.RUnlock() m.pmut.RUnlock()
@ -2039,7 +2039,7 @@ func (m *model) GetHello(id protocol.DeviceID) protocol.HelloIntf {
// AddConnection adds a new peer connection to the model. An initial index will // AddConnection adds a new peer connection to the model. An initial index will
// be sent to the connected peer, thereafter index updates whenever the local // be sent to the connected peer, thereafter index updates whenever the local
// folder changes. // folder changes.
func (m *model) AddConnection(conn connections.Connection, hello protocol.Hello) { func (m *model) AddConnection(conn protocol.Connection, hello protocol.Hello) {
deviceID := conn.ID() deviceID := conn.ID()
device, ok := m.cfg.Device(deviceID) device, ok := m.cfg.Device(deviceID)
if !ok { if !ok {

View File

@ -3297,7 +3297,7 @@ func TestConnCloseOnRestart(t *testing.T) {
br := &testutils.BlockingRW{} br := &testutils.BlockingRW{}
nw := &testutils.NoopRW{} nw := &testutils.NoopRW{}
m.AddConnection(newFakeProtoConn(protocol.NewConnection(device1, br, nw, m, "testConn", protocol.CompressionNever)), protocol.Hello{}) m.AddConnection(protocol.NewConnection(device1, br, nw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"fc"}, protocol.CompressionNever), protocol.Hello{})
m.pmut.RLock() m.pmut.RLock()
if len(m.closed) != 1 { if len(m.closed) != 1 {
t.Fatalf("Expected just one conn (len(m.conn) == %v)", len(m.conn)) t.Fatalf("Expected just one conn (len(m.conn) == %v)", len(m.conn))

View File

@ -10,6 +10,7 @@ import (
"testing" "testing"
"github.com/syncthing/syncthing/lib/dialer" "github.com/syncthing/syncthing/lib/dialer"
"github.com/syncthing/syncthing/lib/testutils"
) )
func BenchmarkRequestsRawTCP(b *testing.B) { func BenchmarkRequestsRawTCP(b *testing.B) {
@ -59,9 +60,9 @@ func benchmarkRequestsTLS(b *testing.B, conn0, conn1 net.Conn) {
func benchmarkRequestsConnPair(b *testing.B, conn0, conn1 net.Conn) { func benchmarkRequestsConnPair(b *testing.B, conn0, conn1 net.Conn) {
// Start up Connections on them // Start up Connections on them
c0 := NewConnection(LocalDeviceID, conn0, conn0, new(fakeModel), "c0", CompressionMetadata) c0 := NewConnection(LocalDeviceID, conn0, conn0, testutils.NoopCloser{}, new(fakeModel), &testutils.FakeConnectionInfo{"c0"}, CompressionMetadata)
c0.Start() c0.Start()
c1 := NewConnection(LocalDeviceID, conn1, conn1, new(fakeModel), "c1", CompressionMetadata) c1 := NewConnection(LocalDeviceID, conn1, conn1, testutils.NoopCloser{}, new(fakeModel), &testutils.FakeConnectionInfo{"c1"}, CompressionMetadata)
c1.Start() c1.Start()
// Satisfy the assertions in the protocol by sending an initial cluster config // Satisfy the assertions in the protocol by sending an initial cluster config

View File

@ -128,6 +128,7 @@ func (e encryptedModel) Closed(conn Connection, err error) {
// The encryptedConnection sits between the model and the encrypted device. It // The encryptedConnection sits between the model and the encrypted device. It
// encrypts outgoing metadata and decrypts incoming responses. // encrypts outgoing metadata and decrypts incoming responses.
type encryptedConnection struct { type encryptedConnection struct {
ConnectionInfo
conn Connection conn Connection
folderKeys map[string]*[keySize]byte // folder ID -> key folderKeys map[string]*[keySize]byte // folder ID -> key
} }
@ -140,10 +141,6 @@ func (e encryptedConnection) ID() DeviceID {
return e.conn.ID() return e.conn.ID()
} }
func (e encryptedConnection) Name() string {
return e.conn.Name()
}
func (e encryptedConnection) Index(ctx context.Context, folder string, files []FileInfo) error { func (e encryptedConnection) Index(ctx context.Context, folder string, files []FileInfo) error {
if folderKey, ok := e.folderKeys[folder]; ok { if folderKey, ok := e.folderKeys[folder]; ok {
encryptFileInfos(files, folderKey) encryptFileInfos(files, folderKey)

View File

@ -8,6 +8,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"net"
"path" "path"
"strings" "strings"
"sync" "sync"
@ -134,7 +135,6 @@ type Connection interface {
Start() Start()
Close(err error) Close(err error)
ID() DeviceID ID() DeviceID
Name() string
Index(ctx context.Context, folder string, files []FileInfo) error Index(ctx context.Context, folder string, files []FileInfo) error
IndexUpdate(ctx context.Context, folder string, files []FileInfo) error IndexUpdate(ctx context.Context, folder string, files []FileInfo) error
Request(ctx context.Context, folder string, name string, blockNo int, offset int64, size int, hash []byte, weakHash uint32, fromTemporary bool) ([]byte, error) Request(ctx context.Context, folder string, name string, blockNo int, offset int64, size int, hash []byte, weakHash uint32, fromTemporary bool) ([]byte, error)
@ -142,16 +142,28 @@ type Connection interface {
DownloadProgress(ctx context.Context, folder string, updates []FileDownloadProgressUpdate) DownloadProgress(ctx context.Context, folder string, updates []FileDownloadProgressUpdate)
Statistics() Statistics Statistics() Statistics
Closed() bool Closed() bool
ConnectionInfo
}
type ConnectionInfo interface {
Type() string
Transport() string
RemoteAddr() net.Addr
Priority() int
String() string
Crypto() string
} }
type rawConnection struct { type rawConnection struct {
ConnectionInfo
id DeviceID id DeviceID
name string
receiver Model receiver Model
startTime time.Time startTime time.Time
cr *countingReader cr *countingReader
cw *countingWriter cw *countingWriter
closer io.Closer // Closing the underlying connection and thus cr and cw
awaiting map[int]chan asyncResult awaiting map[int]chan asyncResult
awaitingMut sync.Mutex awaitingMut sync.Mutex
@ -205,13 +217,13 @@ const (
// Should not be modified in production code, just for testing. // Should not be modified in production code, just for testing.
var CloseTimeout = 10 * time.Second var CloseTimeout = 10 * time.Second
func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection { func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection {
receiver = nativeModel{receiver} receiver = nativeModel{receiver}
rc := newRawConnection(deviceID, reader, writer, receiver, name, compress) rc := newRawConnection(deviceID, reader, writer, closer, receiver, connInfo, compress)
return wireFormatConnection{rc} return wireFormatConnection{rc}
} }
func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection { func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection {
keys := keysFromPasswords(passwords) keys := keysFromPasswords(passwords)
// Encryption / decryption is first (outermost) before conversion to // Encryption / decryption is first (outermost) before conversion to
@ -221,23 +233,24 @@ func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, read
// We do the wire format conversion first (outermost) so that the // We do the wire format conversion first (outermost) so that the
// metadata is in wire format when it reaches the encryption step. // metadata is in wire format when it reaches the encryption step.
rc := newRawConnection(deviceID, reader, writer, em, name, compress) rc := newRawConnection(deviceID, reader, writer, closer, em, connInfo, compress)
ec := encryptedConnection{conn: rc, folderKeys: keys} ec := encryptedConnection{ConnectionInfo: rc, conn: rc, folderKeys: keys}
wc := wireFormatConnection{ec} wc := wireFormatConnection{ec}
return wc return wc
} }
func newRawConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) *rawConnection { func newRawConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) *rawConnection {
cr := &countingReader{Reader: reader} cr := &countingReader{Reader: reader}
cw := &countingWriter{Writer: writer} cw := &countingWriter{Writer: writer}
return &rawConnection{ return &rawConnection{
ConnectionInfo: connInfo,
id: deviceID, id: deviceID,
name: name,
receiver: receiver, receiver: receiver,
cr: cr, cr: cr,
cw: cw, cw: cw,
closer: closer,
awaiting: make(map[int]chan asyncResult), awaiting: make(map[int]chan asyncResult),
inbox: make(chan message), inbox: make(chan message),
outbox: make(chan asyncMessage), outbox: make(chan asyncMessage),
@ -282,10 +295,6 @@ func (c *rawConnection) ID() DeviceID {
return c.id return c.id
} }
func (c *rawConnection) Name() string {
return c.name
}
// Index writes the list of file information to the connected peer device // Index writes the list of file information to the connected peer device
func (c *rawConnection) Index(ctx context.Context, folder string, idx []FileInfo) error { func (c *rawConnection) Index(ctx context.Context, folder string, idx []FileInfo) error {
select { select {
@ -931,6 +940,9 @@ func (c *rawConnection) Close(err error) {
func (c *rawConnection) internalClose(err error) { func (c *rawConnection) internalClose(err error) {
c.closeOnce.Do(func() { c.closeOnce.Do(func() {
l.Debugln("close due to", err) l.Debugln("close due to", err)
if cerr := c.closer.Close(); cerr != nil {
l.Debugln(c.id, "failed to close underlying conn:", cerr)
}
close(c.closed) close(c.closed)
c.awaitingMut.Lock() c.awaitingMut.Lock()

View File

@ -31,10 +31,10 @@ func TestPing(t *testing.T) {
ar, aw := io.Pipe() ar, aw := io.Pipe()
br, bw := io.Pipe() br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, newTestModel(), &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0.Start() c0.Start()
defer closeAndWait(c0, ar, bw) defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, newTestModel(), &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c1.Start() c1.Start()
defer closeAndWait(c1, ar, bw) defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{}) c0.ClusterConfig(ClusterConfig{})
@ -57,10 +57,10 @@ func TestClose(t *testing.T) {
ar, aw := io.Pipe() ar, aw := io.Pipe()
br, bw := io.Pipe() br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0.Start() c0.Start()
defer closeAndWait(c0, ar, bw) defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, m1, "name", CompressionAlways) c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, &testutils.FakeConnectionInfo{"name"}, CompressionAlways)
c1.Start() c1.Start()
defer closeAndWait(c1, ar, bw) defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{}) c0.ClusterConfig(ClusterConfig{})
@ -102,7 +102,7 @@ func TestCloseOnBlockingSend(t *testing.T) {
m := newTestModel() m := newTestModel()
rw := testutils.NewBlockingRW() rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start() c.Start()
defer closeAndWait(c, rw) defer closeAndWait(c, rw)
@ -153,10 +153,10 @@ func TestCloseRace(t *testing.T) {
ar, aw := io.Pipe() ar, aw := io.Pipe()
br, bw := io.Pipe() br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, m0, "c0", CompressionNever).(wireFormatConnection).Connection.(*rawConnection) c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, &testutils.FakeConnectionInfo{"c0"}, CompressionNever).(wireFormatConnection).Connection.(*rawConnection)
c0.Start() c0.Start()
defer closeAndWait(c0, ar, bw) defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, m1, "c1", CompressionNever) c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, &testutils.FakeConnectionInfo{"c1"}, CompressionNever)
c1.Start() c1.Start()
defer closeAndWait(c1, ar, bw) defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{}) c0.ClusterConfig(ClusterConfig{})
@ -193,7 +193,7 @@ func TestClusterConfigFirst(t *testing.T) {
m := newTestModel() m := newTestModel()
rw := testutils.NewBlockingRW() rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start() c.Start()
defer closeAndWait(c, rw) defer closeAndWait(c, rw)
@ -245,7 +245,7 @@ func TestCloseTimeout(t *testing.T) {
m := newTestModel() m := newTestModel()
rw := testutils.NewBlockingRW() rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start() c.Start()
defer closeAndWait(c, rw) defer closeAndWait(c, rw)
@ -865,7 +865,7 @@ func TestClusterConfigAfterClose(t *testing.T) {
m := newTestModel() m := newTestModel()
rw := testutils.NewBlockingRW() rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start() c.Start()
defer closeAndWait(c, rw) defer closeAndWait(c, rw)
@ -889,7 +889,7 @@ func TestDispatcherToCloseDeadlock(t *testing.T) {
// the model callbacks (ClusterConfig). // the model callbacks (ClusterConfig).
m := newTestModel() m := newTestModel()
rw := testutils.NewBlockingRW() rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
m.ccFn = func(devID DeviceID, cc ClusterConfig) { m.ccFn = func(devID DeviceID, cc ClusterConfig) {
c.Close(errManual) c.Close(errManual)
} }

View File

@ -8,6 +8,7 @@ package testutils
import ( import (
"errors" "errors"
"net"
"sync" "sync"
) )
@ -52,3 +53,49 @@ func (rw *NoopRW) Read(p []byte) (n int, err error) {
func (rw *NoopRW) Write(p []byte) (n int, err error) { func (rw *NoopRW) Write(p []byte) (n int, err error) {
return len(p), nil return len(p), nil
} }
type NoopCloser struct{}
func (NoopCloser) Close() error {
return nil
}
// FakeConnectionInfo implements the methods of protocol.Connection that are
// not implemented by protocol.Connection
type FakeConnectionInfo struct {
Name string
}
func (f *FakeConnectionInfo) RemoteAddr() net.Addr {
return &FakeAddr{}
}
func (f *FakeConnectionInfo) Type() string {
return "fake"
}
func (f *FakeConnectionInfo) Crypto() string {
return "fake"
}
func (f *FakeConnectionInfo) Transport() string {
return "fake"
}
func (f *FakeConnectionInfo) Priority() int {
return 9000
}
func (f *FakeConnectionInfo) String() string {
return ""
}
type FakeAddr struct{}
func (FakeAddr) Network() string {
return "network"
}
func (FakeAddr) String() string {
return "address"
}