diff --git a/protocol/protocol.go b/protocol/protocol.go index 20ef89959..9e5a6d3cd 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -28,6 +28,12 @@ const ( messageTypeIndexUpdate = 6 ) +const ( + stateInitial = iota + stateCCRcvd + stateIdxRcvd +) + const ( FlagDeleted uint32 = 1 << 12 FlagInvalid = 1 << 13 @@ -70,6 +76,7 @@ type Connection interface { type rawConnection struct { id string receiver Model + state int reader io.ReadCloser cr *countingReader @@ -116,6 +123,7 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M c := rawConnection{ id: nodeID, receiver: nativeModel{receiver}, + state: stateInitial, reader: flrd, cr: cr, xr: xdr.NewReader(flrd), @@ -257,21 +265,34 @@ func (c *rawConnection) readerLoop() (err error) { switch hdr.msgType { case messageTypeIndex: + if c.state < stateCCRcvd { + return fmt.Errorf("protocol error: index message in state %d", c.state) + } if err := c.handleIndex(); err != nil { return err } + c.state = stateIdxRcvd case messageTypeIndexUpdate: + if c.state < stateIdxRcvd { + return fmt.Errorf("protocol error: index update message in state %d", c.state) + } if err := c.handleIndexUpdate(); err != nil { return err } case messageTypeRequest: + if c.state < stateIdxRcvd { + return fmt.Errorf("protocol error: request message in state %d", c.state) + } if err := c.handleRequest(hdr); err != nil { return err } case messageTypeResponse: + if c.state < stateIdxRcvd { + return fmt.Errorf("protocol error: response message in state %d", c.state) + } if err := c.handleResponse(hdr); err != nil { return err } @@ -283,9 +304,13 @@ func (c *rawConnection) readerLoop() (err error) { c.handlePong(hdr) case messageTypeClusterConfig: + if c.state != stateInitial { + return fmt.Errorf("protocol error: cluster config message in state %d", c.state) + } if err := c.handleClusterConfig(); err != nil { return err } + c.state = stateCCRcvd default: return fmt.Errorf("protocol error: %s: unknown message type %#x", c.id, hdr.msgType)