diff --git a/model_test.go b/model_test.go index fe2a5c38e..d197a6542 100644 --- a/model_test.go +++ b/model_test.go @@ -47,7 +47,7 @@ var testDataExpected = map[string]File{ func TestUpdateLocal(t *testing.T) { m := NewModel("foo") - fs := Walk("testdata", m) + fs := Walk("testdata", m, false) m.ReplaceLocal(fs) if len(m.need) > 0 { @@ -89,7 +89,7 @@ func TestUpdateLocal(t *testing.T) { func TestRemoteUpdateExisting(t *testing.T) { m := NewModel("foo") - fs := Walk("testdata", m) + fs := Walk("testdata", m, false) m.ReplaceLocal(fs) newFile := protocol.FileInfo{ @@ -106,7 +106,7 @@ func TestRemoteUpdateExisting(t *testing.T) { func TestRemoteAddNew(t *testing.T) { m := NewModel("foo") - fs := Walk("testdata", m) + fs := Walk("testdata", m, false) m.ReplaceLocal(fs) newFile := protocol.FileInfo{ @@ -123,7 +123,7 @@ func TestRemoteAddNew(t *testing.T) { func TestRemoteUpdateOld(t *testing.T) { m := NewModel("foo") - fs := Walk("testdata", m) + fs := Walk("testdata", m, false) m.ReplaceLocal(fs) oldTimeStamp := int64(1234) @@ -141,7 +141,7 @@ func TestRemoteUpdateOld(t *testing.T) { func TestDelete(t *testing.T) { m := NewModel("foo") - fs := Walk("testdata", m) + fs := Walk("testdata", m, false) m.ReplaceLocal(fs) if l1, l2 := len(m.local), len(fs); l1 != l2 { @@ -231,7 +231,7 @@ func TestDelete(t *testing.T) { func TestForgetNode(t *testing.T) { m := NewModel("foo") - fs := Walk("testdata", m) + fs := Walk("testdata", m, false) m.ReplaceLocal(fs) if l1, l2 := len(m.local), len(fs); l1 != l2 { diff --git a/protocol/common_test.go b/protocol/common_test.go new file mode 100644 index 000000000..deaed2dd6 --- /dev/null +++ b/protocol/common_test.go @@ -0,0 +1,49 @@ +package protocol + +import "io" + +type TestModel struct { + data []byte + name string + offset uint64 + size uint32 + hash []byte + closed bool +} + +func (t *TestModel) Index(nodeID string, files []FileInfo) { +} + +func (t *TestModel) Request(nodeID, name string, offset uint64, size uint32, hash []byte) ([]byte, error) { + t.name = name + t.offset = offset + t.size = size + t.hash = hash + return t.data, nil +} + +func (t *TestModel) Close(nodeID string) { + t.closed = true +} + +type ErrPipe struct { + io.PipeWriter + written int + max int + err error + closed bool +} + +func (e *ErrPipe) Write(data []byte) (int, error) { + if e.closed { + return 0, e.err + } + if e.written+len(data) > e.max { + n, _ := e.PipeWriter.Write(data[:e.max-e.written]) + e.PipeWriter.CloseWithError(e.err) + e.closed = true + return n, e.err + } else { + return e.PipeWriter.Write(data) + } +} diff --git a/protocol/marshal.go b/protocol/marshal.go index 74e554d57..f363fb3c7 100644 --- a/protocol/marshal.go +++ b/protocol/marshal.go @@ -35,8 +35,8 @@ func (w *marshalWriter) writeBytes(bs []byte) { return } _, w.err = w.w.Write(bs) - if p := pad(len(bs)); p > 0 { - w.w.Write(padBytes[:p]) + if p := pad(len(bs)); w.err == nil && p > 0 { + _, w.err = w.w.Write(padBytes[:p]) } w.tot += len(bs) + pad(len(bs)) } diff --git a/protocol/protocol.go b/protocol/protocol.go index 963e86459..86d05b594 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -49,7 +49,6 @@ type Connection struct { mwriter *marshalWriter wLock sync.RWMutex closed bool - closedLock sync.RWMutex awaiting map[int]chan asyncResult nextId int lastReceive time.Time @@ -74,13 +73,14 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M } c := Connection{ - receiver: receiver, - reader: flrd, - mreader: &marshalReader{flrd, 0, nil}, - writer: flwr, - mwriter: &marshalWriter{flwr, 0, nil}, - awaiting: make(map[int]chan asyncResult), - ID: nodeID, + receiver: receiver, + reader: flrd, + mreader: &marshalReader{flrd, 0, nil}, + writer: flwr, + mwriter: &marshalWriter{flwr, 0, nil}, + awaiting: make(map[int]chan asyncResult), + lastReceive: time.Now(), + ID: nodeID, } go c.readerLoop() @@ -92,12 +92,15 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M // Index writes the list of file information to the connected peer node func (c *Connection) Index(idx []FileInfo) { c.wLock.Lock() - defer c.wLock.Unlock() - c.mwriter.writeHeader(header{0, c.nextId, messageTypeIndex}) - c.nextId = (c.nextId + 1) & 0xfff c.mwriter.writeIndex(idx) - c.flush() + err := c.flush() + c.nextId = (c.nextId + 1) & 0xfff + c.wLock.Unlock() + if err != nil || c.mwriter.err != nil { + c.close() + return + } } // Request returns the bytes for the specified block after fetching them from the connected peer. @@ -107,7 +110,17 @@ func (c *Connection) Request(name string, offset uint64, size uint32, hash []byt c.awaiting[c.nextId] = rc c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest}) c.mwriter.writeRequest(request{name, offset, size, hash}) - c.flush() + if c.mwriter.err != nil { + c.wLock.Unlock() + c.close() + return nil, c.mwriter.err + } + err := c.flush() + if err != nil { + c.wLock.Unlock() + c.close() + return nil, err + } c.nextId = (c.nextId + 1) & 0xfff c.wLock.Unlock() @@ -123,7 +136,12 @@ func (c *Connection) Ping() bool { rc := make(chan asyncResult) c.awaiting[c.nextId] = rc c.mwriter.writeHeader(header{0, c.nextId, messageTypePing}) - c.flush() + err := c.flush() + if err != nil || c.mwriter.err != nil { + c.wLock.Unlock() + c.close() + return false + } c.nextId = (c.nextId + 1) & 0xfff c.wLock.Unlock() @@ -138,18 +156,20 @@ type flusher interface { Flush() error } -func (c *Connection) flush() { +func (c *Connection) flush() error { if f, ok := c.writer.(flusher); ok { - f.Flush() + return f.Flush() } + return nil } func (c *Connection) close() { - c.closedLock.Lock() - c.closed = true - c.closedLock.Unlock() - c.wLock.Lock() + if c.closed { + c.wLock.Unlock() + return + } + c.closed = true for _, ch := range c.awaiting { close(ch) } @@ -160,8 +180,8 @@ func (c *Connection) close() { } func (c *Connection) isClosed() bool { - c.closedLock.RLock() - defer c.closedLock.RUnlock() + c.wLock.RLock() + defer c.wLock.RUnlock() return c.closed } @@ -215,9 +235,9 @@ func (c *Connection) readerLoop() { case messageTypePing: c.wLock.Lock() c.mwriter.writeUint32(encodeHeader(header{0, hdr.msgID, messageTypePong})) - c.flush() + err := c.flush() c.wLock.Unlock() - if c.mwriter.err != nil { + if err != nil || c.mwriter.err != nil { c.close() } @@ -248,9 +268,12 @@ func (c *Connection) processRequest(msgID int) { c.wLock.Lock() c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse})) c.mwriter.writeResponse(data) - buffers.Put(data) - c.flush() + err := c.flush() c.wLock.Unlock() + buffers.Put(data) + if c.mwriter.err != nil || err != nil { + c.close() + } }() } } diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go index 99d2618c2..4492df25f 100644 --- a/protocol/protocol_test.go +++ b/protocol/protocol_test.go @@ -1,8 +1,11 @@ package protocol import ( + "errors" + "io" "testing" "testing/quick" + "time" ) func TestHeaderFunctions(t *testing.T) { @@ -35,3 +38,100 @@ func TestPad(t *testing.T) { } } } + +func TestPing(t *testing.T) { + ar, aw := io.Pipe() + br, bw := io.Pipe() + + c0 := NewConnection("c0", ar, bw, nil) + c1 := NewConnection("c1", br, aw, nil) + + if !c0.Ping() { + t.Error("c0 ping failed") + } + if !c1.Ping() { + t.Error("c1 ping failed") + } +} + +func TestPingErr(t *testing.T) { + e := errors.New("Something broke") + + for i := 0; i < 12; i++ { + for j := 0; j < 12; j++ { + m0 := &TestModel{} + m1 := &TestModel{} + + ar, aw := io.Pipe() + br, bw := io.Pipe() + eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e} + ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e} + + c0 := NewConnection("c0", ar, ebw, m0) + NewConnection("c1", br, eaw, m1) + + res := c0.Ping() + if (i < 4 || j < 4) && res { + t.Errorf("Unexpected ping success; i=%d, j=%d", i, j) + } else if (i >= 8 && j >= 8) && !res { + t.Errorf("Unexpected ping fail; i=%d, j=%d", i, j) + } + } + } +} + +func TestRequestResponseErr(t *testing.T) { + e := errors.New("Something broke") + + var pass bool + for i := 0; i < 36; i++ { + for j := 0; j < 26; j++ { + m0 := &TestModel{data: []byte("response data")} + m1 := &TestModel{} + + ar, aw := io.Pipe() + br, bw := io.Pipe() + eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e} + ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e} + + NewConnection("c0", ar, ebw, m0) + c1 := NewConnection("c1", br, eaw, m1) + + d, err := c1.Request("tn", 1234, 3456, []byte("hashbytes")) + if err == e || err == ErrClosed { + t.Logf("Error at %d+%d bytes", i, j) + if !m1.closed { + t.Error("c1 not closed") + } + time.Sleep(1 * time.Millisecond) + if !m0.closed { + t.Error("c0 not closed") + } + continue + } + if err != nil { + t.Error(err) + } + if string(d) != "response data" { + t.Errorf("Incorrect response data %q", string(d)) + } + if m0.name != "tn" { + t.Error("Incorrect name %q", m0.name) + } + if m0.offset != 1234 { + t.Error("Incorrect offset %d", m0.offset) + } + if m0.size != 3456 { + t.Error("Incorrect size %d", m0.size) + } + if string(m0.hash) != "hashbytes" { + t.Error("Incorrect hash %q", m0.hash) + } + t.Logf("Pass at %d+%d bytes", i, j) + pass = true + } + } + if !pass { + t.Error("Never passed") + } +} diff --git a/walk_test.go b/walk_test.go index 4f41daedc..e88ab0857 100644 --- a/walk_test.go +++ b/walk_test.go @@ -18,7 +18,7 @@ var testdata = []struct { func TestWalk(t *testing.T) { m := new(Model) - files := Walk("testdata", m) + files := Walk("testdata", m, false) if l1, l2 := len(files), len(testdata); l1 != l2 { t.Fatalf("Incorrect number of walked files %d != %d", l1, l2)