mirror of
https://github.com/octoleo/syncthing.git
synced 2025-01-05 16:12:20 +00:00
Factor out XDR en/decoding
This commit is contained in:
parent
21a7f3960a
commit
f89fa6caed
@ -1,142 +0,0 @@
|
|||||||
package protocol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/calmh/syncthing/buffers"
|
|
||||||
)
|
|
||||||
|
|
||||||
func pad(l int) int {
|
|
||||||
d := l % 4
|
|
||||||
if d == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return 4 - d
|
|
||||||
}
|
|
||||||
|
|
||||||
var padBytes = []byte{0, 0, 0}
|
|
||||||
|
|
||||||
type marshalWriter struct {
|
|
||||||
w io.Writer
|
|
||||||
tot uint64
|
|
||||||
err error
|
|
||||||
b [8]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// We will never encode nor expect to decode blobs larger than 10 MB. Check
|
|
||||||
// inserted to protect against attempting to allocate arbitrary amounts of
|
|
||||||
// memory when reading a corrupt message.
|
|
||||||
const maxBytesFieldLength = 10 * 1 << 20
|
|
||||||
|
|
||||||
var ErrFieldLengthExceeded = errors.New("Protocol error: raw bytes field size exceeds limit")
|
|
||||||
|
|
||||||
func (w *marshalWriter) writeString(s string) {
|
|
||||||
w.writeBytes([]byte(s))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *marshalWriter) writeBytes(bs []byte) {
|
|
||||||
if w.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(bs) > maxBytesFieldLength {
|
|
||||||
w.err = ErrFieldLengthExceeded
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.writeUint32(uint32(len(bs)))
|
|
||||||
if w.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, w.err = w.w.Write(bs)
|
|
||||||
if p := pad(len(bs)); w.err == nil && p > 0 {
|
|
||||||
_, w.err = w.w.Write(padBytes[:p])
|
|
||||||
}
|
|
||||||
atomic.AddUint64(&w.tot, uint64(len(bs)+pad(len(bs))))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *marshalWriter) writeUint32(v uint32) {
|
|
||||||
if w.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.b[0] = byte(v >> 24)
|
|
||||||
w.b[1] = byte(v >> 16)
|
|
||||||
w.b[2] = byte(v >> 8)
|
|
||||||
w.b[3] = byte(v)
|
|
||||||
_, w.err = w.w.Write(w.b[:4])
|
|
||||||
atomic.AddUint64(&w.tot, 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *marshalWriter) writeUint64(v uint64) {
|
|
||||||
if w.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.b[0] = byte(v >> 56)
|
|
||||||
w.b[1] = byte(v >> 48)
|
|
||||||
w.b[2] = byte(v >> 40)
|
|
||||||
w.b[3] = byte(v >> 32)
|
|
||||||
w.b[4] = byte(v >> 24)
|
|
||||||
w.b[5] = byte(v >> 16)
|
|
||||||
w.b[6] = byte(v >> 8)
|
|
||||||
w.b[7] = byte(v)
|
|
||||||
_, w.err = w.w.Write(w.b[:8])
|
|
||||||
atomic.AddUint64(&w.tot, 8)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *marshalWriter) getTot() uint64 {
|
|
||||||
return atomic.LoadUint64(&w.tot)
|
|
||||||
}
|
|
||||||
|
|
||||||
type marshalReader struct {
|
|
||||||
r io.Reader
|
|
||||||
tot uint64
|
|
||||||
err error
|
|
||||||
b [8]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *marshalReader) readString() string {
|
|
||||||
bs := r.readBytes()
|
|
||||||
defer buffers.Put(bs)
|
|
||||||
return string(bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *marshalReader) readBytes() []byte {
|
|
||||||
if r.err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
l := int(r.readUint32())
|
|
||||||
if r.err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if l > maxBytesFieldLength {
|
|
||||||
r.err = ErrFieldLengthExceeded
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
b := buffers.Get(l + pad(l))
|
|
||||||
_, r.err = io.ReadFull(r.r, b)
|
|
||||||
atomic.AddUint64(&r.tot, uint64(l+pad(l)))
|
|
||||||
return b[:l]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *marshalReader) readUint32() uint32 {
|
|
||||||
if r.err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
_, r.err = io.ReadFull(r.r, r.b[:4])
|
|
||||||
atomic.AddUint64(&r.tot, 8)
|
|
||||||
return uint32(r.b[3]) | uint32(r.b[2])<<8 | uint32(r.b[1])<<16 | uint32(r.b[0])<<24
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *marshalReader) readUint64() uint64 {
|
|
||||||
if r.err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
_, r.err = io.ReadFull(r.r, r.b[:8])
|
|
||||||
atomic.AddUint64(&r.tot, 8)
|
|
||||||
return uint64(r.b[7]) | uint64(r.b[6])<<8 | uint64(r.b[5])<<16 | uint64(r.b[4])<<24 |
|
|
||||||
uint64(r.b[3])<<32 | uint64(r.b[2])<<40 | uint64(r.b[1])<<48 | uint64(r.b[0])<<56
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *marshalReader) getTot() uint64 {
|
|
||||||
return atomic.LoadUint64(&r.tot)
|
|
||||||
}
|
|
@ -3,6 +3,9 @@ package protocol
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/calmh/syncthing/buffers"
|
||||||
|
"github.com/calmh/syncthing/xdr"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -43,60 +46,93 @@ func decodeHeader(u uint32) header {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WriteIndex(w io.Writer, repo string, idx []FileInfo) (int, error) {
|
||||||
|
mw := newMarshalWriter(w)
|
||||||
|
mw.writeIndex(repo, idx)
|
||||||
|
return int(mw.Tot()), mw.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
type marshalWriter struct {
|
||||||
|
*xdr.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMarshalWriter(w io.Writer) marshalWriter {
|
||||||
|
return marshalWriter{xdr.NewWriter(w)}
|
||||||
|
}
|
||||||
|
|
||||||
func (w *marshalWriter) writeHeader(h header) {
|
func (w *marshalWriter) writeHeader(h header) {
|
||||||
w.writeUint32(encodeHeader(h))
|
w.WriteUint32(encodeHeader(h))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *marshalWriter) writeIndex(repo string, idx []FileInfo) {
|
func (w *marshalWriter) writeIndex(repo string, idx []FileInfo) {
|
||||||
w.writeString(repo)
|
w.WriteString(repo)
|
||||||
w.writeUint32(uint32(len(idx)))
|
w.WriteUint32(uint32(len(idx)))
|
||||||
for _, f := range idx {
|
for _, f := range idx {
|
||||||
w.writeString(f.Name)
|
w.WriteString(f.Name)
|
||||||
w.writeUint32(f.Flags)
|
w.WriteUint32(f.Flags)
|
||||||
w.writeUint64(uint64(f.Modified))
|
w.WriteUint64(uint64(f.Modified))
|
||||||
w.writeUint32(f.Version)
|
w.WriteUint32(f.Version)
|
||||||
w.writeUint32(uint32(len(f.Blocks)))
|
w.WriteUint32(uint32(len(f.Blocks)))
|
||||||
for _, b := range f.Blocks {
|
for _, b := range f.Blocks {
|
||||||
w.writeUint32(b.Size)
|
w.WriteUint32(b.Size)
|
||||||
w.writeBytes(b.Hash)
|
w.WriteBytes(b.Hash)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteIndex(w io.Writer, repo string, idx []FileInfo) (int, error) {
|
|
||||||
mw := marshalWriter{w: w}
|
|
||||||
mw.writeIndex(repo, idx)
|
|
||||||
return int(mw.getTot()), mw.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *marshalWriter) writeRequest(r request) {
|
func (w *marshalWriter) writeRequest(r request) {
|
||||||
w.writeString(r.repo)
|
w.WriteString(r.repo)
|
||||||
w.writeString(r.name)
|
w.WriteString(r.name)
|
||||||
w.writeUint64(uint64(r.offset))
|
w.WriteUint64(uint64(r.offset))
|
||||||
w.writeUint32(r.size)
|
w.WriteUint32(r.size)
|
||||||
w.writeBytes(r.hash)
|
w.WriteBytes(r.hash)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *marshalWriter) writeResponse(data []byte) {
|
func (w *marshalWriter) writeResponse(data []byte) {
|
||||||
w.writeBytes(data)
|
w.WriteBytes(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *marshalWriter) writeOptions(opts map[string]string) {
|
func (w *marshalWriter) writeOptions(opts map[string]string) {
|
||||||
w.writeUint32(uint32(len(opts)))
|
w.WriteUint32(uint32(len(opts)))
|
||||||
for k, v := range opts {
|
for k, v := range opts {
|
||||||
w.writeString(k)
|
w.WriteString(k)
|
||||||
w.writeString(v)
|
w.WriteString(v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *marshalReader) readHeader() header {
|
func ReadIndex(r io.Reader) (string, []FileInfo, error) {
|
||||||
return decodeHeader(r.readUint32())
|
mr := newMarshalReader(r)
|
||||||
|
repo, idx := mr.readIndex()
|
||||||
|
return repo, idx, mr.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *marshalReader) readIndex() (string, []FileInfo) {
|
type marshalReader struct {
|
||||||
|
*xdr.Reader
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMarshalReader(r io.Reader) marshalReader {
|
||||||
|
return marshalReader{
|
||||||
|
Reader: xdr.NewReader(r),
|
||||||
|
err: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r marshalReader) Err() error {
|
||||||
|
if r.err != nil {
|
||||||
|
return r.err
|
||||||
|
}
|
||||||
|
return r.Reader.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r marshalReader) readHeader() header {
|
||||||
|
return decodeHeader(r.ReadUint32())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r marshalReader) readIndex() (string, []FileInfo) {
|
||||||
var files []FileInfo
|
var files []FileInfo
|
||||||
repo := r.readString()
|
repo := r.ReadString()
|
||||||
nfiles := r.readUint32()
|
nfiles := r.ReadUint32()
|
||||||
if nfiles > maxNumFiles {
|
if nfiles > maxNumFiles {
|
||||||
r.err = ErrMaxFilesExceeded
|
r.err = ErrMaxFilesExceeded
|
||||||
return "", nil
|
return "", nil
|
||||||
@ -104,19 +140,19 @@ func (r *marshalReader) readIndex() (string, []FileInfo) {
|
|||||||
if nfiles > 0 {
|
if nfiles > 0 {
|
||||||
files = make([]FileInfo, nfiles)
|
files = make([]FileInfo, nfiles)
|
||||||
for i := range files {
|
for i := range files {
|
||||||
files[i].Name = r.readString()
|
files[i].Name = r.ReadString()
|
||||||
files[i].Flags = r.readUint32()
|
files[i].Flags = r.ReadUint32()
|
||||||
files[i].Modified = int64(r.readUint64())
|
files[i].Modified = int64(r.ReadUint64())
|
||||||
files[i].Version = r.readUint32()
|
files[i].Version = r.ReadUint32()
|
||||||
nblocks := r.readUint32()
|
nblocks := r.ReadUint32()
|
||||||
if nblocks > maxNumBlocks {
|
if nblocks > maxNumBlocks {
|
||||||
r.err = ErrMaxBlocksExceeded
|
r.err = ErrMaxBlocksExceeded
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
blocks := make([]BlockInfo, nblocks)
|
blocks := make([]BlockInfo, nblocks)
|
||||||
for j := range blocks {
|
for j := range blocks {
|
||||||
blocks[j].Size = r.readUint32()
|
blocks[j].Size = r.ReadUint32()
|
||||||
blocks[j].Hash = r.readBytes()
|
blocks[j].Hash = r.ReadBytes(buffers.Get(32))
|
||||||
}
|
}
|
||||||
files[i].Blocks = blocks
|
files[i].Blocks = blocks
|
||||||
}
|
}
|
||||||
@ -124,32 +160,26 @@ func (r *marshalReader) readIndex() (string, []FileInfo) {
|
|||||||
return repo, files
|
return repo, files
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadIndex(r io.Reader) (string, []FileInfo, error) {
|
func (r marshalReader) readRequest() request {
|
||||||
mr := marshalReader{r: r}
|
|
||||||
repo, idx := mr.readIndex()
|
|
||||||
return repo, idx, mr.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *marshalReader) readRequest() request {
|
|
||||||
var req request
|
var req request
|
||||||
req.repo = r.readString()
|
req.repo = r.ReadString()
|
||||||
req.name = r.readString()
|
req.name = r.ReadString()
|
||||||
req.offset = int64(r.readUint64())
|
req.offset = int64(r.ReadUint64())
|
||||||
req.size = r.readUint32()
|
req.size = r.ReadUint32()
|
||||||
req.hash = r.readBytes()
|
req.hash = r.ReadBytes(buffers.Get(32))
|
||||||
return req
|
return req
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *marshalReader) readResponse() []byte {
|
func (r marshalReader) readResponse() []byte {
|
||||||
return r.readBytes()
|
return r.ReadBytes(buffers.Get(128 * 1024))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *marshalReader) readOptions() map[string]string {
|
func (r marshalReader) readOptions() map[string]string {
|
||||||
n := r.readUint32()
|
n := r.ReadUint32()
|
||||||
opts := make(map[string]string, n)
|
opts := make(map[string]string, n)
|
||||||
for i := 0; i < int(n); i++ {
|
for i := 0; i < int(n); i++ {
|
||||||
k := r.readString()
|
k := r.ReadString()
|
||||||
v := r.readString()
|
v := r.ReadString()
|
||||||
opts[k] = v
|
opts[k] = v
|
||||||
}
|
}
|
||||||
return opts
|
return opts
|
||||||
|
@ -34,10 +34,10 @@ func TestIndex(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var buf = new(bytes.Buffer)
|
var buf = new(bytes.Buffer)
|
||||||
var wr = marshalWriter{w: buf}
|
var wr = newMarshalWriter(buf)
|
||||||
wr.writeIndex("default", idx)
|
wr.writeIndex("default", idx)
|
||||||
|
|
||||||
var rd = marshalReader{r: buf}
|
var rd = newMarshalReader(buf)
|
||||||
var repo, idx2 = rd.readIndex()
|
var repo, idx2 = rd.readIndex()
|
||||||
|
|
||||||
if repo != "default" {
|
if repo != "default" {
|
||||||
@ -53,9 +53,9 @@ func TestRequest(t *testing.T) {
|
|||||||
f := func(repo, name string, offset int64, size uint32, hash []byte) bool {
|
f := func(repo, name string, offset int64, size uint32, hash []byte) bool {
|
||||||
var buf = new(bytes.Buffer)
|
var buf = new(bytes.Buffer)
|
||||||
var req = request{repo, name, offset, size, hash}
|
var req = request{repo, name, offset, size, hash}
|
||||||
var wr = marshalWriter{w: buf}
|
var wr = newMarshalWriter(buf)
|
||||||
wr.writeRequest(req)
|
wr.writeRequest(req)
|
||||||
var rd = marshalReader{r: buf}
|
var rd = newMarshalReader(buf)
|
||||||
var req2 = rd.readRequest()
|
var req2 = rd.readRequest()
|
||||||
return req.name == req2.name &&
|
return req.name == req2.name &&
|
||||||
req.offset == req2.offset &&
|
req.offset == req2.offset &&
|
||||||
@ -70,9 +70,9 @@ func TestRequest(t *testing.T) {
|
|||||||
func TestResponse(t *testing.T) {
|
func TestResponse(t *testing.T) {
|
||||||
f := func(data []byte) bool {
|
f := func(data []byte) bool {
|
||||||
var buf = new(bytes.Buffer)
|
var buf = new(bytes.Buffer)
|
||||||
var wr = marshalWriter{w: buf}
|
var wr = newMarshalWriter(buf)
|
||||||
wr.writeResponse(data)
|
wr.writeResponse(data)
|
||||||
var rd = marshalReader{r: buf}
|
var rd = newMarshalReader(buf)
|
||||||
var read = rd.readResponse()
|
var read = rd.readResponse()
|
||||||
return bytes.Compare(read, data) == 0
|
return bytes.Compare(read, data) == 0
|
||||||
}
|
}
|
||||||
@ -106,7 +106,7 @@ func BenchmarkWriteIndex(b *testing.B) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var wr = marshalWriter{w: ioutil.Discard}
|
var wr = newMarshalWriter(ioutil.Discard)
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
wr.writeIndex("default", idx)
|
wr.writeIndex("default", idx)
|
||||||
@ -115,7 +115,7 @@ func BenchmarkWriteIndex(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkWriteRequest(b *testing.B) {
|
func BenchmarkWriteRequest(b *testing.B) {
|
||||||
var req = request{"default", "blah blah", 1231323, 13123123, []byte("hash hash hash")}
|
var req = request{"default", "blah blah", 1231323, 13123123, []byte("hash hash hash")}
|
||||||
var wr = marshalWriter{w: ioutil.Discard}
|
var wr = newMarshalWriter(ioutil.Discard)
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
wr.writeRequest(req)
|
wr.writeRequest(req)
|
||||||
@ -131,10 +131,10 @@ func TestOptions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var buf = new(bytes.Buffer)
|
var buf = new(bytes.Buffer)
|
||||||
var wr = marshalWriter{w: buf}
|
var wr = newMarshalWriter(buf)
|
||||||
wr.writeOptions(opts)
|
wr.writeOptions(opts)
|
||||||
|
|
||||||
var rd = marshalReader{r: buf}
|
var rd = newMarshalReader(buf)
|
||||||
var ropts = rd.readOptions()
|
var ropts = rd.readOptions()
|
||||||
|
|
||||||
if !reflect.DeepEqual(opts, ropts) {
|
if !reflect.DeepEqual(opts, ropts) {
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/calmh/syncthing/buffers"
|
"github.com/calmh/syncthing/buffers"
|
||||||
|
"github.com/calmh/syncthing/xdr"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -61,9 +62,9 @@ type Connection struct {
|
|||||||
id string
|
id string
|
||||||
receiver Model
|
receiver Model
|
||||||
reader io.Reader
|
reader io.Reader
|
||||||
mreader *marshalReader
|
mreader marshalReader
|
||||||
writer io.Writer
|
writer io.Writer
|
||||||
mwriter *marshalWriter
|
mwriter marshalWriter
|
||||||
closed bool
|
closed bool
|
||||||
awaiting map[int]chan asyncResult
|
awaiting map[int]chan asyncResult
|
||||||
nextId int
|
nextId int
|
||||||
@ -101,9 +102,9 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M
|
|||||||
id: nodeID,
|
id: nodeID,
|
||||||
receiver: receiver,
|
receiver: receiver,
|
||||||
reader: flrd,
|
reader: flrd,
|
||||||
mreader: &marshalReader{r: flrd},
|
mreader: marshalReader{Reader: xdr.NewReader(flrd)},
|
||||||
writer: flwr,
|
writer: flwr,
|
||||||
mwriter: &marshalWriter{w: flwr},
|
mwriter: marshalWriter{Writer: xdr.NewWriter(flwr)},
|
||||||
awaiting: make(map[int]chan asyncResult),
|
awaiting: make(map[int]chan asyncResult),
|
||||||
indexSent: make(map[string]map[string][2]int64),
|
indexSent: make(map[string]map[string][2]int64),
|
||||||
}
|
}
|
||||||
@ -168,8 +169,8 @@ func (c *Connection) Index(repo string, idx []FileInfo) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
c.close(err)
|
c.close(err)
|
||||||
return
|
return
|
||||||
} else if c.mwriter.err != nil {
|
} else if c.mwriter.Err() != nil {
|
||||||
c.close(c.mwriter.err)
|
c.close(c.mwriter.Err())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -185,10 +186,10 @@ func (c *Connection) Request(repo string, name string, offset int64, size uint32
|
|||||||
c.awaiting[c.nextId] = rc
|
c.awaiting[c.nextId] = rc
|
||||||
c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest})
|
c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest})
|
||||||
c.mwriter.writeRequest(request{repo, name, offset, size, hash})
|
c.mwriter.writeRequest(request{repo, name, offset, size, hash})
|
||||||
if c.mwriter.err != nil {
|
if c.mwriter.Err() != nil {
|
||||||
c.Unlock()
|
c.Unlock()
|
||||||
c.close(c.mwriter.err)
|
c.close(c.mwriter.Err())
|
||||||
return nil, c.mwriter.err
|
return nil, c.mwriter.Err()
|
||||||
}
|
}
|
||||||
err := c.flush()
|
err := c.flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -220,9 +221,9 @@ func (c *Connection) ping() bool {
|
|||||||
c.Unlock()
|
c.Unlock()
|
||||||
c.close(err)
|
c.close(err)
|
||||||
return false
|
return false
|
||||||
} else if c.mwriter.err != nil {
|
} else if c.mwriter.Err() != nil {
|
||||||
c.Unlock()
|
c.Unlock()
|
||||||
c.close(c.mwriter.err)
|
c.close(c.mwriter.Err())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
c.nextId = (c.nextId + 1) & 0xfff
|
c.nextId = (c.nextId + 1) & 0xfff
|
||||||
@ -269,8 +270,8 @@ func (c *Connection) readerLoop() {
|
|||||||
loop:
|
loop:
|
||||||
for {
|
for {
|
||||||
hdr := c.mreader.readHeader()
|
hdr := c.mreader.readHeader()
|
||||||
if c.mreader.err != nil {
|
if c.mreader.Err() != nil {
|
||||||
c.close(c.mreader.err)
|
c.close(c.mreader.Err())
|
||||||
break loop
|
break loop
|
||||||
}
|
}
|
||||||
if hdr.version != 0 {
|
if hdr.version != 0 {
|
||||||
@ -282,8 +283,8 @@ loop:
|
|||||||
case messageTypeIndex:
|
case messageTypeIndex:
|
||||||
repo, files := c.mreader.readIndex()
|
repo, files := c.mreader.readIndex()
|
||||||
_ = repo
|
_ = repo
|
||||||
if c.mreader.err != nil {
|
if c.mreader.Err() != nil {
|
||||||
c.close(c.mreader.err)
|
c.close(c.mreader.Err())
|
||||||
break loop
|
break loop
|
||||||
} else {
|
} else {
|
||||||
c.receiver.Index(c.id, files)
|
c.receiver.Index(c.id, files)
|
||||||
@ -295,8 +296,8 @@ loop:
|
|||||||
case messageTypeIndexUpdate:
|
case messageTypeIndexUpdate:
|
||||||
repo, files := c.mreader.readIndex()
|
repo, files := c.mreader.readIndex()
|
||||||
_ = repo
|
_ = repo
|
||||||
if c.mreader.err != nil {
|
if c.mreader.Err() != nil {
|
||||||
c.close(c.mreader.err)
|
c.close(c.mreader.Err())
|
||||||
break loop
|
break loop
|
||||||
} else {
|
} else {
|
||||||
c.receiver.IndexUpdate(c.id, files)
|
c.receiver.IndexUpdate(c.id, files)
|
||||||
@ -304,8 +305,8 @@ loop:
|
|||||||
|
|
||||||
case messageTypeRequest:
|
case messageTypeRequest:
|
||||||
req := c.mreader.readRequest()
|
req := c.mreader.readRequest()
|
||||||
if c.mreader.err != nil {
|
if c.mreader.Err() != nil {
|
||||||
c.close(c.mreader.err)
|
c.close(c.mreader.Err())
|
||||||
break loop
|
break loop
|
||||||
}
|
}
|
||||||
go c.processRequest(hdr.msgID, req)
|
go c.processRequest(hdr.msgID, req)
|
||||||
@ -313,8 +314,8 @@ loop:
|
|||||||
case messageTypeResponse:
|
case messageTypeResponse:
|
||||||
data := c.mreader.readResponse()
|
data := c.mreader.readResponse()
|
||||||
|
|
||||||
if c.mreader.err != nil {
|
if c.mreader.Err() != nil {
|
||||||
c.close(c.mreader.err)
|
c.close(c.mreader.Err())
|
||||||
break loop
|
break loop
|
||||||
} else {
|
} else {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
@ -323,21 +324,21 @@ loop:
|
|||||||
c.Unlock()
|
c.Unlock()
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
rc <- asyncResult{data, c.mreader.err}
|
rc <- asyncResult{data, c.mreader.Err()}
|
||||||
close(rc)
|
close(rc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case messageTypePing:
|
case messageTypePing:
|
||||||
c.Lock()
|
c.Lock()
|
||||||
c.mwriter.writeUint32(encodeHeader(header{0, hdr.msgID, messageTypePong}))
|
c.mwriter.WriteUint32(encodeHeader(header{0, hdr.msgID, messageTypePong}))
|
||||||
err := c.flush()
|
err := c.flush()
|
||||||
c.Unlock()
|
c.Unlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.close(err)
|
c.close(err)
|
||||||
break loop
|
break loop
|
||||||
} else if c.mwriter.err != nil {
|
} else if c.mwriter.Err() != nil {
|
||||||
c.close(c.mwriter.err)
|
c.close(c.mwriter.Err())
|
||||||
break loop
|
break loop
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -376,9 +377,9 @@ func (c *Connection) processRequest(msgID int, req request) {
|
|||||||
data, _ := c.receiver.Request(c.id, req.repo, req.name, req.offset, req.size, req.hash)
|
data, _ := c.receiver.Request(c.id, req.repo, req.name, req.offset, req.size, req.hash)
|
||||||
|
|
||||||
c.Lock()
|
c.Lock()
|
||||||
c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse}))
|
c.mwriter.WriteUint32(encodeHeader(header{0, msgID, messageTypeResponse}))
|
||||||
c.mwriter.writeResponse(data)
|
c.mwriter.writeResponse(data)
|
||||||
err := c.mwriter.err
|
err := c.mwriter.Err()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = c.flush()
|
err = c.flush()
|
||||||
}
|
}
|
||||||
@ -427,8 +428,8 @@ func (c *Connection) Statistics() Statistics {
|
|||||||
|
|
||||||
stats := Statistics{
|
stats := Statistics{
|
||||||
At: time.Now(),
|
At: time.Now(),
|
||||||
InBytesTotal: int(c.mreader.getTot()),
|
InBytesTotal: int(c.mreader.Tot()),
|
||||||
OutBytesTotal: int(c.mwriter.getTot()),
|
OutBytesTotal: int(c.mwriter.Tot()),
|
||||||
}
|
}
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
@ -22,23 +22,6 @@ func TestHeaderFunctions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPad(t *testing.T) {
|
|
||||||
tests := [][]int{
|
|
||||||
{0, 0},
|
|
||||||
{1, 3},
|
|
||||||
{2, 2},
|
|
||||||
{3, 1},
|
|
||||||
{4, 0},
|
|
||||||
{32, 0},
|
|
||||||
{33, 3},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
if p := pad(tc[0]); p != tc[1] {
|
|
||||||
t.Errorf("Incorrect padding for %d bytes, %d != %d", tc[0], p, tc[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPing(t *testing.T) {
|
func TestPing(t *testing.T) {
|
||||||
ar, aw := io.Pipe()
|
ar, aw := io.Pipe()
|
||||||
br, bw := io.Pipe()
|
br, bw := io.Pipe()
|
||||||
|
65
xdr/reader.go
Normal file
65
xdr/reader.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package xdr
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
type Reader struct {
|
||||||
|
r io.Reader
|
||||||
|
tot uint64
|
||||||
|
err error
|
||||||
|
b [8]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewReader(r io.Reader) *Reader {
|
||||||
|
return &Reader{
|
||||||
|
r: r,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reader) ReadString() string {
|
||||||
|
return string(r.ReadBytes(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reader) ReadBytes(dst []byte) []byte {
|
||||||
|
if r.err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
l := int(r.ReadUint32())
|
||||||
|
if r.err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if l+pad(l) > len(dst) {
|
||||||
|
dst = make([]byte, l+pad(l))
|
||||||
|
} else {
|
||||||
|
dst = dst[:l+pad(l)]
|
||||||
|
}
|
||||||
|
_, r.err = io.ReadFull(r.r, dst)
|
||||||
|
r.tot += uint64(l + pad(l))
|
||||||
|
return dst[:l]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reader) ReadUint32() uint32 {
|
||||||
|
if r.err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
_, r.err = io.ReadFull(r.r, r.b[:4])
|
||||||
|
r.tot += 8
|
||||||
|
return uint32(r.b[3]) | uint32(r.b[2])<<8 | uint32(r.b[1])<<16 | uint32(r.b[0])<<24
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reader) ReadUint64() uint64 {
|
||||||
|
if r.err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
_, r.err = io.ReadFull(r.r, r.b[:8])
|
||||||
|
r.tot += 8
|
||||||
|
return uint64(r.b[7]) | uint64(r.b[6])<<8 | uint64(r.b[5])<<16 | uint64(r.b[4])<<24 |
|
||||||
|
uint64(r.b[3])<<32 | uint64(r.b[2])<<40 | uint64(r.b[1])<<48 | uint64(r.b[0])<<56
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reader) Tot() uint64 {
|
||||||
|
return r.tot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reader) Err() error {
|
||||||
|
return r.err
|
||||||
|
}
|
95
xdr/writer.go
Normal file
95
xdr/writer.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package xdr
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
func pad(l int) int {
|
||||||
|
d := l % 4
|
||||||
|
if d == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return 4 - d
|
||||||
|
}
|
||||||
|
|
||||||
|
var padBytes = []byte{0, 0, 0}
|
||||||
|
|
||||||
|
type Writer struct {
|
||||||
|
w io.Writer
|
||||||
|
tot uint64
|
||||||
|
err error
|
||||||
|
b [8]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWriter(w io.Writer) *Writer {
|
||||||
|
return &Writer{
|
||||||
|
w: w,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Writer) WriteString(s string) (int, error) {
|
||||||
|
return w.WriteBytes([]byte(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Writer) WriteBytes(bs []byte) (int, error) {
|
||||||
|
if w.err != nil {
|
||||||
|
return 0, w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteUint32(uint32(len(bs)))
|
||||||
|
if w.err != nil {
|
||||||
|
return 0, w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
var l, n int
|
||||||
|
n, w.err = w.w.Write(bs)
|
||||||
|
l += n
|
||||||
|
|
||||||
|
if p := pad(len(bs)); w.err == nil && p > 0 {
|
||||||
|
n, w.err = w.w.Write(padBytes[:p])
|
||||||
|
l += n
|
||||||
|
}
|
||||||
|
|
||||||
|
w.tot += uint64(l)
|
||||||
|
return l, w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Writer) WriteUint32(v uint32) (int, error) {
|
||||||
|
if w.err != nil {
|
||||||
|
return 0, w.err
|
||||||
|
}
|
||||||
|
w.b[0] = byte(v >> 24)
|
||||||
|
w.b[1] = byte(v >> 16)
|
||||||
|
w.b[2] = byte(v >> 8)
|
||||||
|
w.b[3] = byte(v)
|
||||||
|
|
||||||
|
var l int
|
||||||
|
l, w.err = w.w.Write(w.b[:4])
|
||||||
|
w.tot += uint64(l)
|
||||||
|
return l, w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Writer) WriteUint64(v uint64) (int, error) {
|
||||||
|
if w.err != nil {
|
||||||
|
return 0, w.err
|
||||||
|
}
|
||||||
|
w.b[0] = byte(v >> 56)
|
||||||
|
w.b[1] = byte(v >> 48)
|
||||||
|
w.b[2] = byte(v >> 40)
|
||||||
|
w.b[3] = byte(v >> 32)
|
||||||
|
w.b[4] = byte(v >> 24)
|
||||||
|
w.b[5] = byte(v >> 16)
|
||||||
|
w.b[6] = byte(v >> 8)
|
||||||
|
w.b[7] = byte(v)
|
||||||
|
|
||||||
|
var l int
|
||||||
|
l, w.err = w.w.Write(w.b[:8])
|
||||||
|
w.tot += uint64(l)
|
||||||
|
return l, w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Writer) Tot() uint64 {
|
||||||
|
return w.tot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Writer) Err() error {
|
||||||
|
return w.err
|
||||||
|
}
|
57
xdr/xdr_test.go
Normal file
57
xdr/xdr_test.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package xdr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
"testing/quick"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPad(t *testing.T) {
|
||||||
|
tests := [][]int{
|
||||||
|
{0, 0},
|
||||||
|
{1, 3},
|
||||||
|
{2, 2},
|
||||||
|
{3, 1},
|
||||||
|
{4, 0},
|
||||||
|
{32, 0},
|
||||||
|
{33, 3},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
if p := pad(tc[0]); p != tc[1] {
|
||||||
|
t.Errorf("Incorrect padding for %d bytes, %d != %d", tc[0], p, tc[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBytesNil(t *testing.T) {
|
||||||
|
fn := func(bs []byte) bool {
|
||||||
|
var b = new(bytes.Buffer)
|
||||||
|
var w = NewWriter(b)
|
||||||
|
var r = NewReader(b)
|
||||||
|
w.WriteBytes(bs)
|
||||||
|
w.WriteBytes(bs)
|
||||||
|
r.ReadBytes(nil)
|
||||||
|
res := r.ReadBytes(nil)
|
||||||
|
return bytes.Compare(bs, res) == 0
|
||||||
|
}
|
||||||
|
if err := quick.Check(fn, nil); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBytesGiven(t *testing.T) {
|
||||||
|
fn := func(bs []byte) bool {
|
||||||
|
var b = new(bytes.Buffer)
|
||||||
|
var w = NewWriter(b)
|
||||||
|
var r = NewReader(b)
|
||||||
|
w.WriteBytes(bs)
|
||||||
|
w.WriteBytes(bs)
|
||||||
|
res := make([]byte, 12)
|
||||||
|
res = r.ReadBytes(res)
|
||||||
|
res = r.ReadBytes(res)
|
||||||
|
return bytes.Compare(bs, res) == 0
|
||||||
|
}
|
||||||
|
if err := quick.Check(fn, nil); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user