From 0f6b34160cd01efcf839b53d0d44830e271354c5 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 30 Dec 2013 21:21:57 -0500 Subject: [PATCH] Propagate and log reason for connection close --- model.go | 8 ++++-- model_test.go | 2 +- protocol/common_test.go | 2 +- protocol/protocol.go | 59 ++++++++++++++++++++++----------------- protocol/protocol_test.go | 2 +- 5 files changed, 43 insertions(+), 30 deletions(-) diff --git a/model.go b/model.go index 054d28c37..619c1d9c0 100644 --- a/model.go +++ b/model.go @@ -188,11 +188,15 @@ func (m *Model) SeedIndex(fs []protocol.FileInfo) { m.printModelStats() } -func (m *Model) Close(node string) { +func (m *Model) Close(node string, err error) { m.Lock() defer m.Unlock() - infoln("Disconnected from node", node) + if err != nil { + warnf("Disconnected from node %s: %v", node, err) + } else { + infoln("Disconnected from node", node) + } delete(m.remote, node) delete(m.nodes, node) diff --git a/model_test.go b/model_test.go index 7b8982254..31b6a7507 100644 --- a/model_test.go +++ b/model_test.go @@ -294,7 +294,7 @@ func TestForgetNode(t *testing.T) { t.Errorf("Model len(need) incorrect (%d != %d)", l1, l2) } - m.Close("42") + m.Close("42", nil) if l1, l2 := len(m.local), len(fs); l1 != l2 { t.Errorf("Model len(local) incorrect (%d != %d)", l1, l2) diff --git a/protocol/common_test.go b/protocol/common_test.go index 3e77f618f..d5b885dbd 100644 --- a/protocol/common_test.go +++ b/protocol/common_test.go @@ -25,7 +25,7 @@ func (t *TestModel) Request(nodeID, name string, offset uint64, size uint32, has return t.data, nil } -func (t *TestModel) Close(nodeID string) { +func (t *TestModel) Close(nodeID string, err error) { t.closed = true } diff --git a/protocol/protocol.go b/protocol/protocol.go index 485fd0220..0eef04e57 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -3,8 +3,8 @@ package protocol import ( "compress/flate" "errors" + "fmt" "io" - "log" "sync" "time" @@ -40,7 +40,7 @@ type Model interface { // A request was made by the peer node Request(nodeID, name string, offset uint64, size uint32, hash []byte) ([]byte, error) // The peer node closed the connection - Close(nodeID string) + Close(nodeID string, err error) } type Connection struct { @@ -130,8 +130,11 @@ func (c *Connection) Index(idx []FileInfo) { err := c.flush() c.nextId = (c.nextId + 1) & 0xfff c.Unlock() - if err != nil || c.mwriter.err != nil { - c.Close() + if err != nil { + c.Close(err) + return + } else if c.mwriter.err != nil { + c.Close(c.mwriter.err) return } } @@ -149,13 +152,13 @@ func (c *Connection) Request(name string, offset uint64, size uint32, hash []byt c.mwriter.writeRequest(request{name, offset, size, hash}) if c.mwriter.err != nil { c.Unlock() - c.Close() + c.Close(c.mwriter.err) return nil, c.mwriter.err } err := c.flush() if err != nil { c.Unlock() - c.Close() + c.Close(err) return nil, err } c.nextId = (c.nextId + 1) & 0xfff @@ -178,9 +181,13 @@ func (c *Connection) Ping() bool { c.awaiting[c.nextId] = rc c.mwriter.writeHeader(header{0, c.nextId, messageTypePing}) err := c.flush() - if err != nil || c.mwriter.err != nil { + if err != nil { c.Unlock() - c.Close() + c.Close(err) + return false + } else if c.mwriter.err != nil { + c.Unlock() + c.Close(c.mwriter.err) return false } c.nextId = (c.nextId + 1) & 0xfff @@ -204,7 +211,7 @@ func (c *Connection) flush() error { return nil } -func (c *Connection) Close() { +func (c *Connection) Close(err error) { c.Lock() if c.closed { c.Unlock() @@ -217,7 +224,7 @@ func (c *Connection) Close() { c.awaiting = nil c.Unlock() - c.receiver.Close(c.ID) + c.receiver.Close(c.ID, err) } func (c *Connection) isClosed() bool { @@ -230,12 +237,11 @@ func (c *Connection) readerLoop() { for !c.isClosed() { hdr := c.mreader.readHeader() if c.mreader.err != nil { - c.Close() + c.Close(c.mreader.err) break } if hdr.version != 0 { - log.Printf("Protocol error: %s: unknown message version %#x", c.ID, hdr.version) - c.Close() + c.Close(fmt.Errorf("Protocol error: %s: unknown message version %#x", c.ID, hdr.version)) break } @@ -247,7 +253,7 @@ func (c *Connection) readerLoop() { case messageTypeIndex: files := c.mreader.readIndex() if c.mreader.err != nil { - c.Close() + c.Close(c.mreader.err) } else { c.receiver.Index(c.ID, files) } @@ -255,7 +261,7 @@ func (c *Connection) readerLoop() { case messageTypeIndexUpdate: files := c.mreader.readIndex() if c.mreader.err != nil { - c.Close() + c.Close(c.mreader.err) } else { c.receiver.IndexUpdate(c.ID, files) } @@ -267,7 +273,7 @@ func (c *Connection) readerLoop() { data := c.mreader.readResponse() if c.mreader.err != nil { - c.Close() + c.Close(c.mreader.err) } else { c.Lock() rc, ok := c.awaiting[hdr.msgID] @@ -285,8 +291,10 @@ func (c *Connection) readerLoop() { c.mwriter.writeUint32(encodeHeader(header{0, hdr.msgID, messageTypePong})) err := c.flush() c.Unlock() - if err != nil || c.mwriter.err != nil { - c.Close() + if err != nil { + c.Close(err) + } else if c.mwriter.err != nil { + c.Close(c.mwriter.err) } case messageTypePong: @@ -304,8 +312,7 @@ func (c *Connection) readerLoop() { } default: - log.Printf("Protocol error: %s: unknown message type %#x", c.ID, hdr.msgType) - c.Close() + c.Close(fmt.Errorf("Protocol error: %s: unknown message type %#x", c.ID, hdr.msgType)) } } } @@ -313,7 +320,7 @@ func (c *Connection) readerLoop() { func (c *Connection) processRequest(msgID int) { req := c.mreader.readRequest() if c.mreader.err != nil { - c.Close() + c.Close(c.mreader.err) } else { go func() { data, _ := c.receiver.Request(c.ID, req.name, req.offset, req.size, req.hash) @@ -323,8 +330,10 @@ func (c *Connection) processRequest(msgID int) { err := c.flush() c.Unlock() buffers.Put(data) - if c.mwriter.err != nil || err != nil { - c.Close() + if err != nil { + c.Close(err) + } else if c.mwriter.err != nil { + c.Close(c.mwriter.err) } }() } @@ -340,10 +349,10 @@ func (c *Connection) pingerLoop() { select { case ok := <-rc: if !ok { - c.Close() + c.Close(fmt.Errorf("Ping failure")) } case <-time.After(pingTimeout): - c.Close() + c.Close(fmt.Errorf("Ping timeout")) } } } diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go index 62e75b450..9a037ff3a 100644 --- a/protocol/protocol_test.go +++ b/protocol/protocol_test.go @@ -190,7 +190,7 @@ func TestClose(t *testing.T) { c0 := NewConnection("c0", ar, bw, m0) NewConnection("c1", br, aw, m1) - c0.Close() + c0.Close(nil) ok := c0.isClosed() if !ok {