diff --git a/fuzz.go b/fuzz.go new file mode 100644 index 000000000..9b82abe7c --- /dev/null +++ b/fuzz.go @@ -0,0 +1,70 @@ +// Copyright (C) 2015 The Protocol Authors. + +// +build gofuzz + +package protocol + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "reflect" + "sync" +) + +func Fuzz(data []byte) int { + // Regenerate the length, or we'll most commonly exit quickly due to an + // unexpected eof which is unintestering. + if len(data) > 8 { + binary.BigEndian.PutUint32(data[4:], uint32(len(data))-8) + } + + // Setup a rawConnection we'll use to parse the message. + c := rawConnection{ + cr: &countingReader{Reader: bytes.NewReader(data)}, + closed: make(chan struct{}), + pool: sync.Pool{ + New: func() interface{} { + return make([]byte, BlockSize) + }, + }, + } + + // Attempt to parse the message. + hdr, msg, err := c.readMessage() + if err != nil { + return 0 + } + + // If parsing worked, attempt to encode it again. + newBs, err := msg.AppendXDR(nil) + if err != nil { + panic("not encodable") + } + + // Create an appriate header for the re-encoding. + newMsg := make([]byte, 8) + binary.BigEndian.PutUint32(newMsg, encodeHeader(hdr)) + binary.BigEndian.PutUint32(newMsg[4:], uint32(len(newBs))) + newMsg = append(newMsg, newBs...) + + // Use the rawConnection to parse the re-encoding. + c.cr = &countingReader{Reader: bytes.NewReader(newMsg)} + hdr2, msg2, err := c.readMessage() + if err != nil { + fmt.Println("Initial:\n" + hex.Dump(data)) + fmt.Println("New:\n" + hex.Dump(newMsg)) + panic("not parseable after re-encode: " + err.Error()) + } + + // Make sure the data is the same as it was before. + if hdr != hdr2 { + panic("headers differ") + } + if !reflect.DeepEqual(msg, msg2) { + panic("contents differ") + } + + return 1 +} diff --git a/fuzz_test.go b/fuzz_test.go new file mode 100644 index 000000000..65c2d9010 --- /dev/null +++ b/fuzz_test.go @@ -0,0 +1,89 @@ +// Copyright (C) 2015 The Protocol Authors. + +// +build gofuzz + +package protocol + +import ( + "encoding/binary" + "fmt" + "io/ioutil" + "os" + "strings" + "testing" + "testing/quick" +) + +// This can be used to generate a corpus of valid messages as a starting point +// for the fuzzer. +func TestGenerateCorpus(t *testing.T) { + t.Skip("Use to generate initial corpus only") + + n := 0 + check := func(idx IndexMessage) bool { + for i := range idx.Options { + if len(idx.Options[i].Key) > 64 { + idx.Options[i].Key = idx.Options[i].Key[:64] + } + } + hdr := header{ + version: 0, + msgID: 42, + msgType: messageTypeIndex, + compression: false, + } + + msgBs := idx.MustMarshalXDR() + + buf := make([]byte, 8) + binary.BigEndian.PutUint32(buf, encodeHeader(hdr)) + binary.BigEndian.PutUint32(buf[4:], uint32(len(msgBs))) + buf = append(buf, msgBs...) + + ioutil.WriteFile(fmt.Sprintf("testdata/corpus/test-%03d.xdr", n), buf, 0644) + n++ + return true + } + + if err := quick.Check(check, &quick.Config{MaxCount: 1000}); err != nil { + t.Fatal(err) + } +} + +// Tests any crashers found by the fuzzer, for closer investigation. +func TestCrashers(t *testing.T) { + testFiles(t, "testdata/crashers") +} + +// Tests the entire corpus, which should PASS before the fuzzer starts +// fuzzing. +func TestCorpus(t *testing.T) { + testFiles(t, "testdata/corpus") +} + +func testFiles(t *testing.T, dir string) { + fd, err := os.Open(dir) + if err != nil { + t.Fatal(err) + } + crashers, err := fd.Readdirnames(-1) + if err != nil { + t.Fatal(err) + } + for _, name := range crashers { + if strings.HasSuffix(name, ".output") { + continue + } + if strings.HasSuffix(name, ".quoted") { + continue + } + + t.Log(name) + crasher, err := ioutil.ReadFile(dir + "/" + name) + if err != nil { + t.Fatal(err) + } + + Fuzz(crasher) + } +} diff --git a/message.go b/message.go index 49df7d4fa..0cfeaa381 100644 --- a/message.go +++ b/message.go @@ -9,7 +9,7 @@ import "fmt" type IndexMessage struct { Folder string - Files []FileInfo + Files []FileInfo // max:1000000 Flags uint32 Options []Option // max:64 } @@ -20,7 +20,7 @@ type FileInfo struct { Modified int64 Version Vector LocalVersion int64 - Blocks []BlockInfo + Blocks []BlockInfo // max:1000000 } func (f FileInfo) String() string { @@ -109,9 +109,9 @@ type ResponseMessage struct { } type ClusterConfigMessage struct { - ClientName string // max:64 - ClientVersion string // max:64 - Folders []Folder + ClientName string // max:64 + ClientVersion string // max:64 + Folders []Folder // max:1000000 Options []Option // max:64 } @@ -125,8 +125,8 @@ func (o *ClusterConfigMessage) GetOption(key string) string { } type Folder struct { - ID string // max:64 - Devices []Device + ID string // max:64 + Devices []Device // max:1000000 Flags uint32 Options []Option // max:64 } diff --git a/message_xdr.go b/message_xdr.go index 68d01b696..876fbb77c 100644 --- a/message_xdr.go +++ b/message_xdr.go @@ -42,7 +42,7 @@ IndexMessage Structure: struct IndexMessage { string Folder<>; - FileInfo Files<>; + FileInfo Files<1000000>; unsigned int Flags; Option Options<64>; } @@ -75,6 +75,9 @@ func (o IndexMessage) AppendXDR(bs []byte) ([]byte, error) { func (o IndexMessage) EncodeXDRInto(xw *xdr.Writer) (int, error) { xw.WriteString(o.Folder) + if l := len(o.Files); l > 1000000 { + return xw.Tot(), xdr.ElementSizeExceeded("Files", l, 1000000) + } xw.WriteUint32(uint32(len(o.Files))) for i := range o.Files { _, err := o.Files[i].EncodeXDRInto(xw) @@ -111,7 +114,10 @@ func (o *IndexMessage) DecodeXDRFrom(xr *xdr.Reader) error { o.Folder = xr.ReadString() _FilesSize := int(xr.ReadUint32()) if _FilesSize < 0 { - return xdr.ElementSizeExceeded("Files", _FilesSize, 0) + return xdr.ElementSizeExceeded("Files", _FilesSize, 1000000) + } + if _FilesSize > 1000000 { + return xdr.ElementSizeExceeded("Files", _FilesSize, 1000000) } o.Files = make([]FileInfo, _FilesSize) for i := range o.Files { @@ -173,7 +179,7 @@ struct FileInfo { hyper Modified; Vector Version; hyper LocalVersion; - BlockInfo Blocks<>; + BlockInfo Blocks<1000000>; } */ @@ -214,6 +220,9 @@ func (o FileInfo) EncodeXDRInto(xw *xdr.Writer) (int, error) { return xw.Tot(), err } xw.WriteUint64(uint64(o.LocalVersion)) + if l := len(o.Blocks); l > 1000000 { + return xw.Tot(), xdr.ElementSizeExceeded("Blocks", l, 1000000) + } xw.WriteUint32(uint32(len(o.Blocks))) for i := range o.Blocks { _, err := o.Blocks[i].EncodeXDRInto(xw) @@ -243,7 +252,10 @@ func (o *FileInfo) DecodeXDRFrom(xr *xdr.Reader) error { o.LocalVersion = int64(xr.ReadUint64()) _BlocksSize := int(xr.ReadUint32()) if _BlocksSize < 0 { - return xdr.ElementSizeExceeded("Blocks", _BlocksSize, 0) + return xdr.ElementSizeExceeded("Blocks", _BlocksSize, 1000000) + } + if _BlocksSize > 1000000 { + return xdr.ElementSizeExceeded("Blocks", _BlocksSize, 1000000) } o.Blocks = make([]BlockInfo, _BlocksSize) for i := range o.Blocks { @@ -571,7 +583,7 @@ ClusterConfigMessage Structure: struct ClusterConfigMessage { string ClientName<64>; string ClientVersion<64>; - Folder Folders<>; + Folder Folders<1000000>; Option Options<64>; } @@ -610,6 +622,9 @@ func (o ClusterConfigMessage) EncodeXDRInto(xw *xdr.Writer) (int, error) { return xw.Tot(), xdr.ElementSizeExceeded("ClientVersion", l, 64) } xw.WriteString(o.ClientVersion) + if l := len(o.Folders); l > 1000000 { + return xw.Tot(), xdr.ElementSizeExceeded("Folders", l, 1000000) + } xw.WriteUint32(uint32(len(o.Folders))) for i := range o.Folders { _, err := o.Folders[i].EncodeXDRInto(xw) @@ -646,7 +661,10 @@ func (o *ClusterConfigMessage) DecodeXDRFrom(xr *xdr.Reader) error { o.ClientVersion = xr.ReadStringMax(64) _FoldersSize := int(xr.ReadUint32()) if _FoldersSize < 0 { - return xdr.ElementSizeExceeded("Folders", _FoldersSize, 0) + return xdr.ElementSizeExceeded("Folders", _FoldersSize, 1000000) + } + if _FoldersSize > 1000000 { + return xdr.ElementSizeExceeded("Folders", _FoldersSize, 1000000) } o.Folders = make([]Folder, _FoldersSize) for i := range o.Folders { @@ -697,7 +715,7 @@ Folder Structure: struct Folder { string ID<64>; - Device Devices<>; + Device Devices<1000000>; unsigned int Flags; Option Options<64>; } @@ -733,6 +751,9 @@ func (o Folder) EncodeXDRInto(xw *xdr.Writer) (int, error) { return xw.Tot(), xdr.ElementSizeExceeded("ID", l, 64) } xw.WriteString(o.ID) + if l := len(o.Devices); l > 1000000 { + return xw.Tot(), xdr.ElementSizeExceeded("Devices", l, 1000000) + } xw.WriteUint32(uint32(len(o.Devices))) for i := range o.Devices { _, err := o.Devices[i].EncodeXDRInto(xw) @@ -769,7 +790,10 @@ func (o *Folder) DecodeXDRFrom(xr *xdr.Reader) error { o.ID = xr.ReadStringMax(64) _DevicesSize := int(xr.ReadUint32()) if _DevicesSize < 0 { - return xdr.ElementSizeExceeded("Devices", _DevicesSize, 0) + return xdr.ElementSizeExceeded("Devices", _DevicesSize, 1000000) + } + if _DevicesSize > 1000000 { + return xdr.ElementSizeExceeded("Devices", _DevicesSize, 1000000) } o.Devices = make([]Device, _DevicesSize) for i := range o.Devices { diff --git a/protocol.go b/protocol.go index 8b41c0138..8e73afea5 100644 --- a/protocol.go +++ b/protocol.go @@ -15,7 +15,11 @@ import ( ) const ( - BlockSize = 128 * 1024 + // Data block size (128 KiB) + BlockSize = 128 << 10 + + // We reject messages larger than this when encountered on the wire. (64 MiB) + MaxMessageLen = 64 << 20 ) const ( @@ -383,6 +387,11 @@ func (c *rawConnection) readMessage() (hdr header, msg encodable, err error) { l.Debugf("read header %v (msglen=%d)", hdr, msglen) } + if msglen > MaxMessageLen { + err = fmt.Errorf("message length %d exceeds maximum %d", msglen, MaxMessageLen) + return + } + if hdr.version != 0 { err = fmt.Errorf("unknown protocol version 0x%x", hdr.version) return @@ -403,7 +412,7 @@ func (c *rawConnection) readMessage() (hdr header, msg encodable, err error) { } msgBuf := c.rdbuf0 - if hdr.compression { + if hdr.compression && msglen > 0 { c.rdbuf1 = c.rdbuf1[:cap(c.rdbuf1)] c.rdbuf1, err = lz4.Decode(c.rdbuf1, c.rdbuf0) if err != nil { diff --git a/vector_xdr.go b/vector_xdr.go index a4b6b132b..01efa7e4e 100644 --- a/vector_xdr.go +++ b/vector_xdr.go @@ -2,6 +2,8 @@ package protocol +import "github.com/calmh/xdr" + // This stuff is hacked up manually because genxdr doesn't support 'type // Vector []Counter' declarations and it was tricky when I tried to add it... @@ -28,6 +30,9 @@ func (v Vector) EncodeXDRInto(w xdrWriter) (int, error) { // DecodeXDRFrom decodes the XDR objects from the given reader into itself. func (v *Vector) DecodeXDRFrom(r xdrReader) error { l := int(r.ReadUint32()) + if l > 1e6 { + return xdr.ElementSizeExceeded("number of counters", l, 1e6) + } n := make(Vector, l) for i := range n { n[i].ID = r.ReadUint64()