diff --git a/protocol/protocol.go b/protocol/protocol.go index 7eca8428a..20a605f93 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -18,8 +18,6 @@ const ( messageTypePong ) -var ErrClosed = errors.New("Connection closed") - type FileInfo struct { Name string Flags uint32 @@ -50,11 +48,18 @@ type Connection struct { wLock sync.RWMutex closed bool closedLock sync.RWMutex - awaiting map[int]chan interface{} + awaiting map[int]chan asyncResult nextId int ID string } +var ErrClosed = errors.New("Connection closed") + +type asyncResult struct { + val []byte + err error +} + func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver Model) *Connection { flrd := flate.NewReader(reader) flwr, err := flate.NewWriter(writer, flate.BestSpeed) @@ -68,7 +73,7 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M mreader: &marshalReader{flrd, 0, nil}, writer: flwr, mwriter: &marshalWriter{flwr, 0, nil}, - awaiting: make(map[int]chan interface{}), + awaiting: make(map[int]chan asyncResult), ID: nodeID, } @@ -91,7 +96,7 @@ func (c *Connection) Index(idx []FileInfo) { // Request returns the bytes for the specified block after fetching them from the connected peer. func (c *Connection) Request(name string, offset uint64, size uint32, hash []byte) ([]byte, error) { c.wLock.Lock() - rc := make(chan interface{}) + rc := make(chan asyncResult) c.awaiting[c.nextId] = rc c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest}) c.mwriter.writeRequest(request{name, offset, size, hash}) @@ -99,30 +104,16 @@ func (c *Connection) Request(name string, offset uint64, size uint32, hash []byt c.nextId = (c.nextId + 1) & 0xfff c.wLock.Unlock() - // Reading something that might be nil from a possibly closed channel... - // r0<~ - - var data []byte - i, ok := <-rc - if ok { - if d, ok := i.([]byte); ok { - data = d - } + res, ok := <-rc + if !ok { + return nil, ErrClosed } - - var err error - i, ok = <-rc - if ok { - if e, ok := i.(error); ok { - err = e - } - } - return data, err + return res.val, res.err } func (c *Connection) Ping() bool { c.wLock.Lock() - rc := make(chan interface{}) + rc := make(chan asyncResult) c.awaiting[c.nextId] = rc c.mwriter.writeHeader(header{0, c.nextId, messageTypePing}) c.flush() @@ -150,12 +141,14 @@ func (c *Connection) close() { c.closedLock.Lock() c.closed = true c.closedLock.Unlock() + c.wLock.Lock() for _, ch := range c.awaiting { close(ch) } c.awaiting = nil c.wLock.Unlock() + c.receiver.Close(c.ID) } @@ -196,10 +189,12 @@ func (c *Connection) readerLoop() { c.wLock.RUnlock() if ok { - rc <- data - rc <- c.mreader.err - delete(c.awaiting, hdr.msgID) + rc <- asyncResult{data, c.mreader.err} close(rc) + + c.wLock.Lock() + delete(c.awaiting, hdr.msgID) + c.wLock.Unlock() } } @@ -210,13 +205,18 @@ func (c *Connection) readerLoop() { c.wLock.Unlock() case messageTypePong: - c.wLock.Lock() - if rc, ok := c.awaiting[hdr.msgID]; ok { - rc <- true + c.wLock.RLock() + rc, ok := c.awaiting[hdr.msgID] + c.wLock.RUnlock() + + if ok { + rc <- asyncResult{} close(rc) + + c.wLock.Lock() delete(c.awaiting, hdr.msgID) + c.wLock.Unlock() } - c.wLock.Unlock() } } }