diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index a55256799..0c59ef56e 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -71,6 +71,10 @@ var ( ErrClosed = errors.New("connection closed") ) +// Specific variants of empty messages... +type pingMessage struct{ EmptyMessage } +type pongMessage struct{ EmptyMessage } + type Model interface { // An index was received from the peer device Index(deviceID DeviceID, folder string, files []FileInfo) @@ -289,48 +293,60 @@ func (c *rawConnection) readerLoop() (err error) { return err } - switch hdr.msgType { - case messageTypeIndex: - if c.state < stateCCRcvd { - return fmt.Errorf("protocol error: index message in state %d", c.state) + switch msg := msg.(type) { + case IndexMessage: + if msg.Flags != 0 { + // We don't currently support or expect any flags. + return fmt.Errorf("protocol error: unknown flags 0x%x in Index(Update) message", msg.Flags) } - c.handleIndex(msg.(IndexMessage)) - c.state = stateIdxRcvd - case messageTypeIndexUpdate: - if c.state < stateIdxRcvd { - return fmt.Errorf("protocol error: index update message in state %d", c.state) + switch hdr.msgType { + case messageTypeIndex: + if c.state < stateCCRcvd { + return fmt.Errorf("protocol error: index message in state %d", c.state) + } + c.handleIndex(msg) + c.state = stateIdxRcvd + + case messageTypeIndexUpdate: + if c.state < stateIdxRcvd { + return fmt.Errorf("protocol error: index update message in state %d", c.state) + } + c.handleIndexUpdate(msg) } - c.handleIndexUpdate(msg.(IndexMessage)) - case messageTypeRequest: + case RequestMessage: + if msg.Flags != 0 { + // We don't currently support or expect any flags. + return fmt.Errorf("protocol error: unknown flags 0x%x in Request message", msg.Flags) + } if c.state < stateIdxRcvd { return fmt.Errorf("protocol error: request message in state %d", c.state) } // Requests are handled asynchronously - go c.handleRequest(hdr.msgID, msg.(RequestMessage)) + go c.handleRequest(hdr.msgID, msg) - case messageTypeResponse: + case ResponseMessage: if c.state < stateIdxRcvd { return fmt.Errorf("protocol error: response message in state %d", c.state) } - c.handleResponse(hdr.msgID, msg.(ResponseMessage)) + c.handleResponse(hdr.msgID, msg) - case messageTypePing: - c.send(hdr.msgID, messageTypePong, EmptyMessage{}) + case pingMessage: + c.send(hdr.msgID, messageTypePong, pongMessage{}) - case messageTypePong: + case pongMessage: c.handlePong(hdr.msgID) - case messageTypeClusterConfig: + case ClusterConfigMessage: if c.state != stateInitial { return fmt.Errorf("protocol error: cluster config message in state %d", c.state) } - go c.receiver.ClusterConfig(c.id, msg.(ClusterConfigMessage)) + go c.receiver.ClusterConfig(c.id, msg) c.state = stateCCRcvd - case messageTypeClose: - return errors.New(msg.(CloseMessage).Reason) + case CloseMessage: + return errors.New(msg.Reason) default: return fmt.Errorf("protocol error: %s: unknown message type %#x", c.id, hdr.msgType) @@ -428,8 +444,11 @@ func (c *rawConnection) readMessage() (hdr header, msg encodable, err error) { } msg = resp - case messageTypePing, messageTypePong: - msg = EmptyMessage{} + case messageTypePing: + msg = pingMessage{} + + case messageTypePong: + msg = pongMessage{} case messageTypeClusterConfig: var cc ClusterConfigMessage