diff --git a/protocol/common_test.go b/protocol/common_test.go index deaed2dd6..3e77f618f 100644 --- a/protocol/common_test.go +++ b/protocol/common_test.go @@ -14,6 +14,9 @@ type TestModel struct { func (t *TestModel) Index(nodeID string, files []FileInfo) { } +func (t *TestModel) IndexUpdate(nodeID string, files []FileInfo) { +} + func (t *TestModel) Request(nodeID, name string, offset uint64, size uint32, hash []byte) ([]byte, error) { t.name = name t.offset = offset diff --git a/protocol/protocol.go b/protocol/protocol.go index e510e9fec..564fd28f7 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -134,6 +134,9 @@ func (c *Connection) Index(idx []FileInfo) { // Request returns the bytes for the specified block after fetching them from the connected peer. func (c *Connection) Request(name string, offset uint64, size uint32, hash []byte) ([]byte, error) { + if c.isClosed() { + return nil, ErrClosed + } c.Lock() rc := make(chan asyncResult) c.awaiting[c.nextId] = rc @@ -161,6 +164,9 @@ func (c *Connection) Request(name string, offset uint64, size uint32, hash []byt } func (c *Connection) Ping() (time.Duration, bool) { + if c.isClosed() { + return 0, false + } c.Lock() rc := make(chan asyncResult) c.awaiting[c.nextId] = rc diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go index 1b3c8c8c1..9c1fbf1df 100644 --- a/protocol/protocol_test.go +++ b/protocol/protocol_test.go @@ -179,3 +179,36 @@ func TestTypeErr(t *testing.T) { t.Error("Connection should close due to unknown message type") } } + +func TestClose(t *testing.T) { + m0 := &TestModel{} + m1 := &TestModel{} + + ar, aw := io.Pipe() + br, bw := io.Pipe() + + c0 := NewConnection("c0", ar, bw, m0) + NewConnection("c1", br, aw, m1) + + c0.close() + + ok := c0.isClosed() + if !ok { + t.Fatal("Connection should be closed") + } + + // None of these should panic, some should return an error + + _, ok = c0.Ping() + if ok { + t.Error("Ping should not return true") + } + + c0.Index(nil) + c0.Index(nil) + + _, err := c0.Request("foo", 0, 0, nil) + if err == nil { + t.Error("Request should return an error") + } +}