lib/protocol: Fix yet another deadlock (fixes #5678) (#5679)

* lib/protocol: Fix yet another deadlock (fixes #5678)

* more consistency

* read deadlock

* naming

* more naming
This commit is contained in:
Simon Frei 2019-05-02 10:21:07 +02:00 committed by Audrius Butkevicius
parent 26e6d94c00
commit ec7c88ca55

View File

@ -240,14 +240,21 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv
// Start creates the goroutines for sending and receiving of messages. It must // Start creates the goroutines for sending and receiving of messages. It must
// be called exactly once after creating a connection. // be called exactly once after creating a connection.
func (c *rawConnection) Start() { func (c *rawConnection) Start() {
c.wg.Add(4) c.startGoroutine(c.readerLoop)
c.startGoroutine(c.writerLoop)
c.startGoroutine(c.pingSender)
c.startGoroutine(c.pingReceiver)
}
func (c *rawConnection) startGoroutine(loop func() error) {
c.wg.Add(1)
go func() { go func() {
err := c.readerLoop() err := loop()
c.internalClose(err) c.wg.Done()
if err != nil && err != ErrClosed {
c.internalClose(err)
}
}() }()
go c.writerLoop()
go c.pingSender()
go c.pingReceiver()
} }
func (c *rawConnection) ID() DeviceID { func (c *rawConnection) ID() DeviceID {
@ -363,25 +370,44 @@ func (c *rawConnection) ping() bool {
return c.send(&Ping{}, nil) return c.send(&Ping{}, nil)
} }
func (c *rawConnection) readerLoop() (err error) { type messageWithError struct {
defer c.wg.Done() msg message
err error
}
func (c *rawConnection) readerLoop() error {
fourByteBuf := make([]byte, 4) fourByteBuf := make([]byte, 4)
inbox := make(chan messageWithError)
// Reading from the wire may block until the underlying connection is closed.
go func() {
for {
msg, err := c.readMessage(fourByteBuf)
select {
case inbox <- messageWithError{msg: msg, err: err}:
case <-c.closed:
return
}
}
}()
state := stateInitial state := stateInitial
var msgWithErr messageWithError
for { for {
if c.Closed() { select {
case msgWithErr = <-inbox:
case <-c.closed:
return ErrClosed return ErrClosed
} }
if msgWithErr.err != nil {
msg, err := c.readMessage(fourByteBuf) if msgWithErr.err == errUnknownMessage {
if err == errUnknownMessage { // Unknown message types are skipped, for future extensibility.
// Unknown message types are skipped, for future extensibility. continue
continue }
} return msgWithErr.err
if err != nil {
return err
} }
switch msg := msg.(type) { switch msg := msgWithErr.msg.(type) {
case *ClusterConfig: case *ClusterConfig:
l.Debugln("read ClusterConfig message") l.Debugln("read ClusterConfig message")
if state != stateInitial { if state != stateInitial {
@ -660,8 +686,7 @@ func (c *rawConnection) send(msg message, done chan struct{}) (sent bool) {
} }
} }
func (c *rawConnection) writerLoop() { func (c *rawConnection) writerLoop() error {
defer c.wg.Done()
for { for {
select { select {
case hm := <-c.outbox: case hm := <-c.outbox:
@ -670,12 +695,11 @@ func (c *rawConnection) writerLoop() {
close(hm.done) close(hm.done)
} }
if err != nil { if err != nil {
c.internalClose(err) return err
return
} }
case <-c.closed: case <-c.closed:
return return ErrClosed
} }
} }
} }
@ -882,9 +906,7 @@ func (c *rawConnection) internalClose(err error) {
// PingSendInterval/2, we do nothing. Otherwise we send a ping message. This // PingSendInterval/2, we do nothing. Otherwise we send a ping message. This
// results in an effecting ping interval of somewhere between // results in an effecting ping interval of somewhere between
// PingSendInterval/2 and PingSendInterval. // PingSendInterval/2 and PingSendInterval.
func (c *rawConnection) pingSender() { func (c *rawConnection) pingSender() error {
defer c.wg.Done()
ticker := time.NewTicker(PingSendInterval / 2) ticker := time.NewTicker(PingSendInterval / 2)
defer ticker.Stop() defer ticker.Stop()
@ -901,7 +923,7 @@ func (c *rawConnection) pingSender() {
c.ping() c.ping()
case <-c.closed: case <-c.closed:
return return ErrClosed
} }
} }
} }
@ -909,9 +931,7 @@ func (c *rawConnection) pingSender() {
// The pingReceiver checks that we've received a message (any message will do, // The pingReceiver checks that we've received a message (any message will do,
// but we expect pings in the absence of other messages) within the last // but we expect pings in the absence of other messages) within the last
// ReceiveTimeout. If not, we close the connection with an ErrTimeout. // ReceiveTimeout. If not, we close the connection with an ErrTimeout.
func (c *rawConnection) pingReceiver() { func (c *rawConnection) pingReceiver() error {
defer c.wg.Done()
ticker := time.NewTicker(ReceiveTimeout / 2) ticker := time.NewTicker(ReceiveTimeout / 2)
defer ticker.Stop() defer ticker.Stop()
@ -921,13 +941,13 @@ func (c *rawConnection) pingReceiver() {
d := time.Since(c.cr.Last()) d := time.Since(c.cr.Last())
if d > ReceiveTimeout { if d > ReceiveTimeout {
l.Debugln(c.id, "ping timeout", d) l.Debugln(c.id, "ping timeout", d)
c.internalClose(ErrTimeout) return ErrTimeout
} }
l.Debugln(c.id, "last read within", d) l.Debugln(c.id, "last read within", d)
case <-c.closed: case <-c.closed:
return return ErrClosed
} }
} }
} }