diff --git a/protocol/counting.go b/protocol/counting.go new file mode 100644 index 000000000..d7a3f6c0e --- /dev/null +++ b/protocol/counting.go @@ -0,0 +1,36 @@ +package protocol + +import ( + "io" + "sync/atomic" +) + +type countingReader struct { + io.Reader + tot uint64 +} + +func (c *countingReader) Read(bs []byte) (int, error) { + n, err := c.Reader.Read(bs) + atomic.AddUint64(&c.tot, uint64(n)) + return n, err +} + +func (c *countingReader) Tot() uint64 { + return atomic.LoadUint64(&c.tot) +} + +type countingWriter struct { + io.Writer + tot uint64 +} + +func (c *countingWriter) Write(bs []byte) (int, error) { + n, err := c.Writer.Write(bs) + atomic.AddUint64(&c.tot, uint64(n)) + return n, err +} + +func (c *countingWriter) Tot() uint64 { + return atomic.LoadUint64(&c.tot) +} diff --git a/protocol/protocol.go b/protocol/protocol.go index 9b190c02b..889199591 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -69,8 +69,10 @@ type rawConnection struct { id string receiver Model reader io.ReadCloser + cr *countingReader xr *xdr.Reader writer io.WriteCloser + cw *countingWriter wb *bufio.Writer xw *xdr.Writer closed chan struct{} @@ -93,8 +95,11 @@ const ( ) func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver Model) Connection { - flrd := flate.NewReader(reader) - flwr, err := flate.NewWriter(writer, flate.BestSpeed) + cr := &countingReader{Reader: reader} + cw := &countingWriter{Writer: writer} + + flrd := flate.NewReader(cr) + flwr, err := flate.NewWriter(cw, flate.BestSpeed) if err != nil { panic(err) } @@ -104,8 +109,10 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M id: nodeID, receiver: nativeModel{receiver}, reader: flrd, + cr: cr, xr: xdr.NewReader(flrd), writer: flwr, + cw: cw, wb: wb, xw: xdr.NewWriter(wb), closed: make(chan struct{}), @@ -461,7 +468,7 @@ type Statistics struct { func (c *rawConnection) Statistics() Statistics { return Statistics{ At: time.Now(), - InBytesTotal: int(c.xr.Tot()), - OutBytesTotal: int(c.xw.Tot()), + InBytesTotal: int(c.cr.Tot()), + OutBytesTotal: int(c.cw.Tot()), } }