Error handling, testing

This commit is contained in:
Jakob Borg 2013-12-20 00:01:34 +01:00
parent eba1c9e649
commit f5987fba32
6 changed files with 207 additions and 35 deletions

View File

@ -47,7 +47,7 @@ var testDataExpected = map[string]File{
func TestUpdateLocal(t *testing.T) { func TestUpdateLocal(t *testing.T) {
m := NewModel("foo") m := NewModel("foo")
fs := Walk("testdata", m) fs := Walk("testdata", m, false)
m.ReplaceLocal(fs) m.ReplaceLocal(fs)
if len(m.need) > 0 { if len(m.need) > 0 {
@ -89,7 +89,7 @@ func TestUpdateLocal(t *testing.T) {
func TestRemoteUpdateExisting(t *testing.T) { func TestRemoteUpdateExisting(t *testing.T) {
m := NewModel("foo") m := NewModel("foo")
fs := Walk("testdata", m) fs := Walk("testdata", m, false)
m.ReplaceLocal(fs) m.ReplaceLocal(fs)
newFile := protocol.FileInfo{ newFile := protocol.FileInfo{
@ -106,7 +106,7 @@ func TestRemoteUpdateExisting(t *testing.T) {
func TestRemoteAddNew(t *testing.T) { func TestRemoteAddNew(t *testing.T) {
m := NewModel("foo") m := NewModel("foo")
fs := Walk("testdata", m) fs := Walk("testdata", m, false)
m.ReplaceLocal(fs) m.ReplaceLocal(fs)
newFile := protocol.FileInfo{ newFile := protocol.FileInfo{
@ -123,7 +123,7 @@ func TestRemoteAddNew(t *testing.T) {
func TestRemoteUpdateOld(t *testing.T) { func TestRemoteUpdateOld(t *testing.T) {
m := NewModel("foo") m := NewModel("foo")
fs := Walk("testdata", m) fs := Walk("testdata", m, false)
m.ReplaceLocal(fs) m.ReplaceLocal(fs)
oldTimeStamp := int64(1234) oldTimeStamp := int64(1234)
@ -141,7 +141,7 @@ func TestRemoteUpdateOld(t *testing.T) {
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
m := NewModel("foo") m := NewModel("foo")
fs := Walk("testdata", m) fs := Walk("testdata", m, false)
m.ReplaceLocal(fs) m.ReplaceLocal(fs)
if l1, l2 := len(m.local), len(fs); l1 != l2 { if l1, l2 := len(m.local), len(fs); l1 != l2 {
@ -231,7 +231,7 @@ func TestDelete(t *testing.T) {
func TestForgetNode(t *testing.T) { func TestForgetNode(t *testing.T) {
m := NewModel("foo") m := NewModel("foo")
fs := Walk("testdata", m) fs := Walk("testdata", m, false)
m.ReplaceLocal(fs) m.ReplaceLocal(fs)
if l1, l2 := len(m.local), len(fs); l1 != l2 { if l1, l2 := len(m.local), len(fs); l1 != l2 {

49
protocol/common_test.go Normal file
View File

@ -0,0 +1,49 @@
package protocol
import "io"
type TestModel struct {
data []byte
name string
offset uint64
size uint32
hash []byte
closed bool
}
func (t *TestModel) Index(nodeID string, files []FileInfo) {
}
func (t *TestModel) Request(nodeID, name string, offset uint64, size uint32, hash []byte) ([]byte, error) {
t.name = name
t.offset = offset
t.size = size
t.hash = hash
return t.data, nil
}
func (t *TestModel) Close(nodeID string) {
t.closed = true
}
type ErrPipe struct {
io.PipeWriter
written int
max int
err error
closed bool
}
func (e *ErrPipe) Write(data []byte) (int, error) {
if e.closed {
return 0, e.err
}
if e.written+len(data) > e.max {
n, _ := e.PipeWriter.Write(data[:e.max-e.written])
e.PipeWriter.CloseWithError(e.err)
e.closed = true
return n, e.err
} else {
return e.PipeWriter.Write(data)
}
}

View File

@ -35,8 +35,8 @@ func (w *marshalWriter) writeBytes(bs []byte) {
return return
} }
_, w.err = w.w.Write(bs) _, w.err = w.w.Write(bs)
if p := pad(len(bs)); p > 0 { if p := pad(len(bs)); w.err == nil && p > 0 {
w.w.Write(padBytes[:p]) _, w.err = w.w.Write(padBytes[:p])
} }
w.tot += len(bs) + pad(len(bs)) w.tot += len(bs) + pad(len(bs))
} }

View File

@ -49,7 +49,6 @@ type Connection struct {
mwriter *marshalWriter mwriter *marshalWriter
wLock sync.RWMutex wLock sync.RWMutex
closed bool closed bool
closedLock sync.RWMutex
awaiting map[int]chan asyncResult awaiting map[int]chan asyncResult
nextId int nextId int
lastReceive time.Time lastReceive time.Time
@ -74,13 +73,14 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M
} }
c := Connection{ c := Connection{
receiver: receiver, receiver: receiver,
reader: flrd, reader: flrd,
mreader: &marshalReader{flrd, 0, nil}, mreader: &marshalReader{flrd, 0, nil},
writer: flwr, writer: flwr,
mwriter: &marshalWriter{flwr, 0, nil}, mwriter: &marshalWriter{flwr, 0, nil},
awaiting: make(map[int]chan asyncResult), awaiting: make(map[int]chan asyncResult),
ID: nodeID, lastReceive: time.Now(),
ID: nodeID,
} }
go c.readerLoop() go c.readerLoop()
@ -92,12 +92,15 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M
// Index writes the list of file information to the connected peer node // Index writes the list of file information to the connected peer node
func (c *Connection) Index(idx []FileInfo) { func (c *Connection) Index(idx []FileInfo) {
c.wLock.Lock() c.wLock.Lock()
defer c.wLock.Unlock()
c.mwriter.writeHeader(header{0, c.nextId, messageTypeIndex}) c.mwriter.writeHeader(header{0, c.nextId, messageTypeIndex})
c.nextId = (c.nextId + 1) & 0xfff
c.mwriter.writeIndex(idx) c.mwriter.writeIndex(idx)
c.flush() err := c.flush()
c.nextId = (c.nextId + 1) & 0xfff
c.wLock.Unlock()
if err != nil || c.mwriter.err != nil {
c.close()
return
}
} }
// Request returns the bytes for the specified block after fetching them from the connected peer. // Request returns the bytes for the specified block after fetching them from the connected peer.
@ -107,7 +110,17 @@ func (c *Connection) Request(name string, offset uint64, size uint32, hash []byt
c.awaiting[c.nextId] = rc c.awaiting[c.nextId] = rc
c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest}) c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest})
c.mwriter.writeRequest(request{name, offset, size, hash}) c.mwriter.writeRequest(request{name, offset, size, hash})
c.flush() if c.mwriter.err != nil {
c.wLock.Unlock()
c.close()
return nil, c.mwriter.err
}
err := c.flush()
if err != nil {
c.wLock.Unlock()
c.close()
return nil, err
}
c.nextId = (c.nextId + 1) & 0xfff c.nextId = (c.nextId + 1) & 0xfff
c.wLock.Unlock() c.wLock.Unlock()
@ -123,7 +136,12 @@ func (c *Connection) Ping() bool {
rc := make(chan asyncResult) rc := make(chan asyncResult)
c.awaiting[c.nextId] = rc c.awaiting[c.nextId] = rc
c.mwriter.writeHeader(header{0, c.nextId, messageTypePing}) c.mwriter.writeHeader(header{0, c.nextId, messageTypePing})
c.flush() err := c.flush()
if err != nil || c.mwriter.err != nil {
c.wLock.Unlock()
c.close()
return false
}
c.nextId = (c.nextId + 1) & 0xfff c.nextId = (c.nextId + 1) & 0xfff
c.wLock.Unlock() c.wLock.Unlock()
@ -138,18 +156,20 @@ type flusher interface {
Flush() error Flush() error
} }
func (c *Connection) flush() { func (c *Connection) flush() error {
if f, ok := c.writer.(flusher); ok { if f, ok := c.writer.(flusher); ok {
f.Flush() return f.Flush()
} }
return nil
} }
func (c *Connection) close() { func (c *Connection) close() {
c.closedLock.Lock()
c.closed = true
c.closedLock.Unlock()
c.wLock.Lock() c.wLock.Lock()
if c.closed {
c.wLock.Unlock()
return
}
c.closed = true
for _, ch := range c.awaiting { for _, ch := range c.awaiting {
close(ch) close(ch)
} }
@ -160,8 +180,8 @@ func (c *Connection) close() {
} }
func (c *Connection) isClosed() bool { func (c *Connection) isClosed() bool {
c.closedLock.RLock() c.wLock.RLock()
defer c.closedLock.RUnlock() defer c.wLock.RUnlock()
return c.closed return c.closed
} }
@ -215,9 +235,9 @@ func (c *Connection) readerLoop() {
case messageTypePing: case messageTypePing:
c.wLock.Lock() c.wLock.Lock()
c.mwriter.writeUint32(encodeHeader(header{0, hdr.msgID, messageTypePong})) c.mwriter.writeUint32(encodeHeader(header{0, hdr.msgID, messageTypePong}))
c.flush() err := c.flush()
c.wLock.Unlock() c.wLock.Unlock()
if c.mwriter.err != nil { if err != nil || c.mwriter.err != nil {
c.close() c.close()
} }
@ -248,9 +268,12 @@ func (c *Connection) processRequest(msgID int) {
c.wLock.Lock() c.wLock.Lock()
c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse})) c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse}))
c.mwriter.writeResponse(data) c.mwriter.writeResponse(data)
buffers.Put(data) err := c.flush()
c.flush()
c.wLock.Unlock() c.wLock.Unlock()
buffers.Put(data)
if c.mwriter.err != nil || err != nil {
c.close()
}
}() }()
} }
} }

View File

@ -1,8 +1,11 @@
package protocol package protocol
import ( import (
"errors"
"io"
"testing" "testing"
"testing/quick" "testing/quick"
"time"
) )
func TestHeaderFunctions(t *testing.T) { func TestHeaderFunctions(t *testing.T) {
@ -35,3 +38,100 @@ func TestPad(t *testing.T) {
} }
} }
} }
func TestPing(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection("c0", ar, bw, nil)
c1 := NewConnection("c1", br, aw, nil)
if !c0.Ping() {
t.Error("c0 ping failed")
}
if !c1.Ping() {
t.Error("c1 ping failed")
}
}
func TestPingErr(t *testing.T) {
e := errors.New("Something broke")
for i := 0; i < 12; i++ {
for j := 0; j < 12; j++ {
m0 := &TestModel{}
m1 := &TestModel{}
ar, aw := io.Pipe()
br, bw := io.Pipe()
eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e}
ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e}
c0 := NewConnection("c0", ar, ebw, m0)
NewConnection("c1", br, eaw, m1)
res := c0.Ping()
if (i < 4 || j < 4) && res {
t.Errorf("Unexpected ping success; i=%d, j=%d", i, j)
} else if (i >= 8 && j >= 8) && !res {
t.Errorf("Unexpected ping fail; i=%d, j=%d", i, j)
}
}
}
}
func TestRequestResponseErr(t *testing.T) {
e := errors.New("Something broke")
var pass bool
for i := 0; i < 36; i++ {
for j := 0; j < 26; j++ {
m0 := &TestModel{data: []byte("response data")}
m1 := &TestModel{}
ar, aw := io.Pipe()
br, bw := io.Pipe()
eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e}
ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e}
NewConnection("c0", ar, ebw, m0)
c1 := NewConnection("c1", br, eaw, m1)
d, err := c1.Request("tn", 1234, 3456, []byte("hashbytes"))
if err == e || err == ErrClosed {
t.Logf("Error at %d+%d bytes", i, j)
if !m1.closed {
t.Error("c1 not closed")
}
time.Sleep(1 * time.Millisecond)
if !m0.closed {
t.Error("c0 not closed")
}
continue
}
if err != nil {
t.Error(err)
}
if string(d) != "response data" {
t.Errorf("Incorrect response data %q", string(d))
}
if m0.name != "tn" {
t.Error("Incorrect name %q", m0.name)
}
if m0.offset != 1234 {
t.Error("Incorrect offset %d", m0.offset)
}
if m0.size != 3456 {
t.Error("Incorrect size %d", m0.size)
}
if string(m0.hash) != "hashbytes" {
t.Error("Incorrect hash %q", m0.hash)
}
t.Logf("Pass at %d+%d bytes", i, j)
pass = true
}
}
if !pass {
t.Error("Never passed")
}
}

View File

@ -18,7 +18,7 @@ var testdata = []struct {
func TestWalk(t *testing.T) { func TestWalk(t *testing.T) {
m := new(Model) m := new(Model)
files := Walk("testdata", m) files := Walk("testdata", m, false)
if l1, l2 := len(files), len(testdata); l1 != l2 { if l1, l2 := len(files), len(testdata); l1 != l2 {
t.Fatalf("Incorrect number of walked files %d != %d", l1, l2) t.Fatalf("Incorrect number of walked files %d != %d", l1, l2)