// Copyright (C) 2014 Jakob Borg and Contributors (see the CONTRIBUTORS file). // // This program is free software: you can redistribute it and/or modify it // under the terms of the GNU General Public License as published by the Free // Software Foundation, either version 3 of the License, or (at your option) // any later version. // // This program is distributed in the hope that it will be useful, but WITHOUT // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or // FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for // more details. // // You should have received a copy of the GNU General Public License along // with this program. If not, see <http://www.gnu.org/licenses/>. package protocol import ( "bytes" "encoding/hex" "encoding/json" "errors" "fmt" "io" "io/ioutil" "os" "reflect" "testing" "testing/quick" "github.com/calmh/xdr" ) var ( c0ID = NewDeviceID([]byte{1}) c1ID = NewDeviceID([]byte{2}) ) func TestHeaderFunctions(t *testing.T) { f := func(ver, id, typ int) bool { ver = int(uint(ver) % 16) id = int(uint(id) % 4096) typ = int(uint(typ) % 256) h0 := header{version: ver, msgID: id, msgType: typ} h1 := decodeHeader(encodeHeader(h0)) return h0 == h1 } if err := quick.Check(f, nil); err != nil { t.Error(err) } } func TestHeaderLayout(t *testing.T) { var e, a uint32 // Version are the first four bits e = 0xf0000000 a = encodeHeader(header{version: 0xf}) if a != e { t.Errorf("Header layout incorrect; %08x != %08x", a, e) } // Message ID are the following 12 bits e = 0x0fff0000 a = encodeHeader(header{msgID: 0xfff}) if a != e { t.Errorf("Header layout incorrect; %08x != %08x", a, e) } // Type are the last 8 bits before reserved e = 0x0000ff00 a = encodeHeader(header{msgType: 0xff}) if a != e { t.Errorf("Header layout incorrect; %08x != %08x", a, e) } } func TestPing(t *testing.T) { ar, aw := io.Pipe() br, bw := io.Pipe() c0 := NewConnection(c0ID, ar, bw, nil, "name", true).(wireFormatConnection).next.(*rawConnection) c1 := NewConnection(c1ID, br, aw, nil, "name", true).(wireFormatConnection).next.(*rawConnection) if ok := c0.ping(); !ok { t.Error("c0 ping failed") } if ok := c1.ping(); !ok { t.Error("c1 ping failed") } } func TestPingErr(t *testing.T) { e := errors.New("something broke") for i := 0; i < 16; i++ { for j := 0; j < 16; j++ { m0 := newTestModel() m1 := newTestModel() ar, aw := io.Pipe() br, bw := io.Pipe() eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e} ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e} c0 := NewConnection(c0ID, ar, ebw, m0, "name", true).(wireFormatConnection).next.(*rawConnection) NewConnection(c1ID, br, eaw, m1, "name", true) res := c0.ping() if (i < 8 || j < 8) && res { t.Errorf("Unexpected ping success; i=%d, j=%d", i, j) } else if (i >= 12 && j >= 12) && !res { t.Errorf("Unexpected ping fail; i=%d, j=%d", i, j) } } } } // func TestRequestResponseErr(t *testing.T) { // e := errors.New("something broke") // var pass bool // for i := 0; i < 48; i++ { // for j := 0; j < 38; j++ { // m0 := newTestModel() // m0.data = []byte("response data") // m1 := newTestModel() // ar, aw := io.Pipe() // br, bw := io.Pipe() // eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e} // ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e} // NewConnection(c0ID, ar, ebw, m0, nil) // c1 := NewConnection(c1ID, br, eaw, m1, nil).(wireFormatConnection).next.(*rawConnection) // d, err := c1.Request("default", "tn", 1234, 5678) // if err == e || err == ErrClosed { // t.Logf("Error at %d+%d bytes", i, j) // if !m1.isClosed() { // t.Fatal("c1 not closed") // } // if !m0.isClosed() { // t.Fatal("c0 not closed") // } // continue // } // if err != nil { // t.Fatal(err) // } // if string(d) != "response data" { // t.Fatalf("Incorrect response data %q", string(d)) // } // if m0.folder != "default" { // t.Fatalf("Incorrect folder %q", m0.folder) // } // if m0.name != "tn" { // t.Fatalf("Incorrect name %q", m0.name) // } // if m0.offset != 1234 { // t.Fatalf("Incorrect offset %d", m0.offset) // } // if m0.size != 5678 { // t.Fatalf("Incorrect size %d", m0.size) // } // t.Logf("Pass at %d+%d bytes", i, j) // pass = true // } // } // if !pass { // t.Fatal("Never passed") // } // } func TestVersionErr(t *testing.T) { m0 := newTestModel() m1 := newTestModel() ar, aw := io.Pipe() br, bw := io.Pipe() c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection) NewConnection(c1ID, br, aw, m1, "name", true) w := xdr.NewWriter(c0.cw) w.WriteUint32(encodeHeader(header{ version: 2, msgID: 0, msgType: 0, })) w.WriteUint32(0) if !m1.isClosed() { t.Error("Connection should close due to unknown version") } } func TestTypeErr(t *testing.T) { m0 := newTestModel() m1 := newTestModel() ar, aw := io.Pipe() br, bw := io.Pipe() c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection) NewConnection(c1ID, br, aw, m1, "name", true) w := xdr.NewWriter(c0.cw) w.WriteUint32(encodeHeader(header{ version: 0, msgID: 0, msgType: 42, })) w.WriteUint32(0) if !m1.isClosed() { t.Error("Connection should close due to unknown message type") } } func TestClose(t *testing.T) { m0 := newTestModel() m1 := newTestModel() ar, aw := io.Pipe() br, bw := io.Pipe() c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection) NewConnection(c1ID, br, aw, m1, "name", true) c0.close(nil) <-c0.closed if !m0.isClosed() { t.Fatal("Connection should be closed") } // None of these should panic, some should return an error if c0.ping() { t.Error("Ping should not return true") } c0.Index("default", nil) c0.Index("default", nil) if _, err := c0.Request("default", "foo", 0, 0); err == nil { t.Error("Request should return an error") } } func TestElementSizeExceededNested(t *testing.T) { m := ClusterConfigMessage{ Folders: []Folder{ {ID: "longstringlongstringlongstringinglongstringlongstringlonlongstringlongstringlon"}, }, } _, err := m.EncodeXDR(ioutil.Discard) if err == nil { t.Errorf("ID length %d > max 64, but no error", len(m.Folders[0].ID)) } } func TestMarshalIndexMessage(t *testing.T) { var quickCfg = &quick.Config{MaxCountScale: 10} if testing.Short() { quickCfg = nil } f := func(m1 IndexMessage) bool { for _, f := range m1.Files { for i := range f.Blocks { f.Blocks[i].Offset = 0 if len(f.Blocks[i].Hash) == 0 { f.Blocks[i].Hash = nil } } } return testMarshal(t, "index", &m1, &IndexMessage{}) } if err := quick.Check(f, quickCfg); err != nil { t.Error(err) } } func TestMarshalRequestMessage(t *testing.T) { var quickCfg = &quick.Config{MaxCountScale: 10} if testing.Short() { quickCfg = nil } f := func(m1 RequestMessage) bool { return testMarshal(t, "request", &m1, &RequestMessage{}) } if err := quick.Check(f, quickCfg); err != nil { t.Error(err) } } func TestMarshalResponseMessage(t *testing.T) { var quickCfg = &quick.Config{MaxCountScale: 10} if testing.Short() { quickCfg = nil } f := func(m1 ResponseMessage) bool { if len(m1.Data) == 0 { m1.Data = nil } return testMarshal(t, "response", &m1, &ResponseMessage{}) } if err := quick.Check(f, quickCfg); err != nil { t.Error(err) } } func TestMarshalClusterConfigMessage(t *testing.T) { var quickCfg = &quick.Config{MaxCountScale: 10} if testing.Short() { quickCfg = nil } f := func(m1 ClusterConfigMessage) bool { return testMarshal(t, "clusterconfig", &m1, &ClusterConfigMessage{}) } if err := quick.Check(f, quickCfg); err != nil { t.Error(err) } } func TestMarshalCloseMessage(t *testing.T) { var quickCfg = &quick.Config{MaxCountScale: 10} if testing.Short() { quickCfg = nil } f := func(m1 CloseMessage) bool { return testMarshal(t, "close", &m1, &CloseMessage{}) } if err := quick.Check(f, quickCfg); err != nil { t.Error(err) } } type message interface { EncodeXDR(io.Writer) (int, error) DecodeXDR(io.Reader) error } func testMarshal(t *testing.T, prefix string, m1, m2 message) bool { var buf bytes.Buffer failed := func(bc []byte) { bs, _ := json.MarshalIndent(m1, "", " ") ioutil.WriteFile(prefix+"-1.txt", bs, 0644) bs, _ = json.MarshalIndent(m2, "", " ") ioutil.WriteFile(prefix+"-2.txt", bs, 0644) if len(bc) > 0 { f, _ := os.Create(prefix + "-data.txt") fmt.Fprint(f, hex.Dump(bc)) f.Close() } } _, err := m1.EncodeXDR(&buf) if err == xdr.ErrElementSizeExceeded { return true } if err != nil { failed(nil) t.Fatal(err) } bc := make([]byte, len(buf.Bytes())) copy(bc, buf.Bytes()) err = m2.DecodeXDR(&buf) if err != nil { failed(bc) t.Fatal(err) } ok := reflect.DeepEqual(m1, m2) if !ok { failed(bc) } return ok }