diff --git a/model_test.go b/model_test.go index d197a6542..1d41cfca5 100644 --- a/model_test.go +++ b/model_test.go @@ -97,7 +97,7 @@ func TestRemoteUpdateExisting(t *testing.T) { Modified: time.Now().Unix(), Blocks: []protocol.BlockInfo{{100, []byte("some hash bytes")}}, } - m.Index(string("42"), []protocol.FileInfo{newFile}) + m.Index("42", []protocol.FileInfo{newFile}) if l := len(m.need); l != 1 { t.Errorf("Model missing Need for one file (%d != 1)", l) @@ -114,7 +114,7 @@ func TestRemoteAddNew(t *testing.T) { Modified: time.Now().Unix(), Blocks: []protocol.BlockInfo{{100, []byte("some hash bytes")}}, } - m.Index(string("42"), []protocol.FileInfo{newFile}) + m.Index("42", []protocol.FileInfo{newFile}) if l1, l2 := len(m.need), 1; l1 != l2 { t.Errorf("Model len(m.need) incorrect (%d != %d)", l1, l2) @@ -132,7 +132,7 @@ func TestRemoteUpdateOld(t *testing.T) { Modified: oldTimeStamp, Blocks: []protocol.BlockInfo{{100, []byte("some hash bytes")}}, } - m.Index(string("42"), []protocol.FileInfo{newFile}) + m.Index("42", []protocol.FileInfo{newFile}) if l1, l2 := len(m.need), 0; l1 != l2 { t.Errorf("Model len(need) incorrect (%d != %d)", l1, l2) @@ -249,7 +249,7 @@ func TestForgetNode(t *testing.T) { Modified: time.Now().Unix(), Blocks: []protocol.BlockInfo{{100, []byte("some hash bytes")}}, } - m.Index(string("42"), []protocol.FileInfo{newFile}) + m.Index("42", []protocol.FileInfo{newFile}) if l1, l2 := len(m.local), len(fs); l1 != l2 { t.Errorf("Model len(local) incorrect (%d != %d)", l1, l2) @@ -261,7 +261,7 @@ func TestForgetNode(t *testing.T) { t.Errorf("Model len(need) incorrect (%d != %d)", l1, l2) } - m.Close(string("42")) + m.Close("42") if l1, l2 := len(m.local), len(fs); l1 != l2 { t.Errorf("Model len(local) incorrect (%d != %d)", l1, l2) diff --git a/protocol/protocol.go b/protocol/protocol.go index 82cb46d22..2cd57135c 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -4,6 +4,7 @@ import ( "compress/flate" "errors" "io" + "log" "sync" "time" @@ -193,6 +194,7 @@ func (c *Connection) readerLoop() { break } if hdr.version != 0 { + log.Printf("Protocol error: %s: unknown message version %#x", c.ID, hdr.version) c.close() break } @@ -258,6 +260,10 @@ func (c *Connection) readerLoop() { delete(c.awaiting, hdr.msgID) c.wLock.Unlock() } + + default: + log.Printf("Protocol error: %s: unknown message type %#x", c.ID, hdr.msgType) + c.close() } } } diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go index c0cb8a945..83ca69135 100644 --- a/protocol/protocol_test.go +++ b/protocol/protocol_test.go @@ -157,3 +157,25 @@ func TestVersionErr(t *testing.T) { t.Error("Connection should close due to unknown version") } } + +func TestTypeErr(t *testing.T) { + m0 := &TestModel{} + m1 := &TestModel{} + + ar, aw := io.Pipe() + br, bw := io.Pipe() + + c0 := NewConnection("c0", ar, bw, m0) + NewConnection("c1", br, aw, m1) + + c0.mwriter.writeHeader(header{ + version: 0, + msgID: 0, + msgType: 42, + }) + c0.flush() + + if !m1.closed { + t.Error("Connection should close due to unknown message type") + } +}