diff --git a/lib/config/deviceconfiguration.go b/lib/config/deviceconfiguration.go index 1315c863a..802571109 100644 --- a/lib/config/deviceconfiguration.go +++ b/lib/config/deviceconfiguration.go @@ -28,6 +28,7 @@ type DeviceConfiguration struct { MaxRecvKbps int `xml:"maxRecvKbps" json:"maxRecvKbps"` IgnoredFolders []ObservedFolder `xml:"ignoredFolder" json:"ignoredFolders"` PendingFolders []ObservedFolder `xml:"pendingFolder" json:"pendingFolders"` + MaxRequestKiB int `xml:"maxRequestKiB" json:"maxRequestKiB"` } func NewDeviceConfiguration(id protocol.DeviceID, name string) DeviceConfiguration { diff --git a/lib/model/bytesemaphore.go b/lib/model/bytesemaphore.go new file mode 100644 index 000000000..04a00e76f --- /dev/null +++ b/lib/model/bytesemaphore.go @@ -0,0 +1,50 @@ +// Copyright (C) 2018 The Syncthing Authors. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +package model + +import "sync" + +type byteSemaphore struct { + max int + available int + mut sync.Mutex + cond *sync.Cond +} + +func newByteSemaphore(max int) *byteSemaphore { + s := byteSemaphore{ + max: max, + available: max, + } + s.cond = sync.NewCond(&s.mut) + return &s +} + +func (s *byteSemaphore) take(bytes int) { + if bytes > s.max { + bytes = s.max + } + s.mut.Lock() + for bytes > s.available { + s.cond.Wait() + } + s.available -= bytes + s.mut.Unlock() +} + +func (s *byteSemaphore) give(bytes int) { + if bytes > s.max { + bytes = s.max + } + s.mut.Lock() + if s.available+bytes > s.max { + panic("bug: can never give more than max") + } + s.available += bytes + s.cond.Broadcast() + s.mut.Unlock() +} diff --git a/lib/model/folder_sendrecv.go b/lib/model/folder_sendrecv.go index fe6988691..1e75b0319 100644 --- a/lib/model/folder_sendrecv.go +++ b/lib/model/folder_sendrecv.go @@ -15,7 +15,6 @@ import ( "runtime" "sort" "strings" - stdsync "sync" "time" "github.com/syncthing/syncthing/lib/config" @@ -1147,7 +1146,10 @@ func (f *sendReceiveFolder) shortcutFile(file, curFile protocol.FileInfo, dbUpda // copierRoutine reads copierStates until the in channel closes and performs // the relevant copies when possible, or passes it to the puller routine. func (f *sendReceiveFolder) copierRoutine(in <-chan copyBlocksState, pullChan chan<- pullBlockState, out chan<- *sharedPullerState) { - buf := make([]byte, protocol.MinBlockSize) + buf := protocol.BufferPool.Get(protocol.MinBlockSize) + defer func() { + protocol.BufferPool.Put(buf) + }() for state := range in { dstFd, err := state.tempFile() @@ -1223,11 +1225,7 @@ func (f *sendReceiveFolder) copierRoutine(in <-chan copyBlocksState, pullChan ch continue } - if s := int(block.Size); s > cap(buf) { - buf = make([]byte, s) - } else { - buf = buf[:s] - } + buf = protocol.BufferPool.Upgrade(buf, int(block.Size)) found, err := weakHashFinder.Iterate(block.WeakHash, buf, func(offset int64) bool { if verifyBuffer(buf, block) != nil { @@ -1935,41 +1933,3 @@ func componentCount(name string) int { } return count } - -type byteSemaphore struct { - max int - available int - mut stdsync.Mutex - cond *stdsync.Cond -} - -func newByteSemaphore(max int) *byteSemaphore { - s := byteSemaphore{ - max: max, - available: max, - } - s.cond = stdsync.NewCond(&s.mut) - return &s -} - -func (s *byteSemaphore) take(bytes int) { - if bytes > s.max { - panic("bug: more than max bytes will never be available") - } - s.mut.Lock() - for bytes > s.available { - s.cond.Wait() - } - s.available -= bytes - s.mut.Unlock() -} - -func (s *byteSemaphore) give(bytes int) { - s.mut.Lock() - if s.available+bytes > s.max { - panic("bug: can never give more than max") - } - s.available += bytes - s.cond.Broadcast() - s.mut.Unlock() -} diff --git a/lib/model/model.go b/lib/model/model.go index a92d833f7..7502c64dc 100644 --- a/lib/model/model.go +++ b/lib/model/model.go @@ -105,6 +105,7 @@ type Model struct { pmut sync.RWMutex // protects the below conn map[protocol.DeviceID]connections.Connection + connRequestLimiters map[protocol.DeviceID]*byteSemaphore closed map[protocol.DeviceID]chan struct{} helloMessages map[protocol.DeviceID]protocol.HelloResult deviceDownloads map[protocol.DeviceID]*deviceDownloadState @@ -158,6 +159,7 @@ func NewModel(cfg *config.Wrapper, id protocol.DeviceID, clientName, clientVersi folderRunnerTokens: make(map[string][]suture.ServiceToken), folderStatRefs: make(map[string]*stats.FolderStatisticsReference), conn: make(map[protocol.DeviceID]connections.Connection), + connRequestLimiters: make(map[protocol.DeviceID]*byteSemaphore), closed: make(map[protocol.DeviceID]chan struct{}), helloMessages: make(map[protocol.DeviceID]protocol.HelloResult), deviceDownloads: make(map[protocol.DeviceID]*deviceDownloadState), @@ -1281,6 +1283,7 @@ func (m *Model) Closed(conn protocol.Connection, err error) { m.progressEmitter.temporaryIndexUnsubscribe(conn) } delete(m.conn, device) + delete(m.connRequestLimiters, device) delete(m.helloMessages, device) delete(m.deviceDownloads, device) delete(m.remotePausedFolders, device) @@ -1314,19 +1317,40 @@ func (m *Model) closeLocked(device protocol.DeviceID) { closeRawConn(conn) } +// Implements protocol.RequestResponse +type requestResponse struct { + data []byte + closed chan struct{} + once stdsync.Once +} + +func newRequestResponse(size int) *requestResponse { + return &requestResponse{ + data: protocol.BufferPool.Get(size), + closed: make(chan struct{}), + } +} + +func (r *requestResponse) Data() []byte { + return r.data +} + +func (r *requestResponse) Close() { + r.once.Do(func() { + protocol.BufferPool.Put(r.data) + close(r.closed) + }) +} + +func (r *requestResponse) Wait() { + <-r.closed +} + // Request returns the specified data segment by reading it from local disk. // Implements the protocol.Model interface. -func (m *Model) Request(deviceID protocol.DeviceID, folder, name string, offset int64, hash []byte, weakHash uint32, fromTemporary bool, buf []byte) error { - if offset < 0 { - return protocol.ErrInvalid - } - - if cfg, ok := m.cfg.Folder(folder); !ok || !cfg.SharedWith(deviceID) { - l.Warnf("Request from %s for file %s in unshared folder %q", deviceID, name, folder) - return protocol.ErrNoSuchFile - } else if cfg.Paused { - l.Debugf("Request from %s for file %s in paused folder %q", deviceID, name, folder) - return protocol.ErrInvalid +func (m *Model) Request(deviceID protocol.DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (out protocol.RequestResponse, err error) { + if size < 0 || offset < 0 { + return nil, protocol.ErrInvalid } m.fmut.RLock() @@ -1337,35 +1361,69 @@ func (m *Model) Request(deviceID protocol.DeviceID, folder, name string, offset // The folder might be already unpaused in the config, but not yet // in the model. l.Debugf("Request from %s for file %s in unstarted folder %q", deviceID, name, folder) - return protocol.ErrInvalid + return nil, protocol.ErrInvalid + } + + if !folderCfg.SharedWith(deviceID) { + l.Warnf("Request from %s for file %s in unshared folder %q", deviceID, name, folder) + return nil, protocol.ErrNoSuchFile + } + if folderCfg.Paused { + l.Debugf("Request from %s for file %s in paused folder %q", deviceID, name, folder) + return nil, protocol.ErrInvalid } // Make sure the path is valid and in canonical form - var err error if name, err = fs.Canonicalize(name); err != nil { l.Debugf("Request from %s in folder %q for invalid filename %s", deviceID, folder, name) - return protocol.ErrInvalid + return nil, protocol.ErrInvalid } if deviceID != protocol.LocalDeviceID { - l.Debugf("%v REQ(in): %s: %q / %q o=%d s=%d t=%v", m, deviceID, folder, name, offset, len(buf), fromTemporary) + l.Debugf("%v REQ(in): %s: %q / %q o=%d s=%d t=%v", m, deviceID, folder, name, offset, size, fromTemporary) + } + + if fs.IsInternal(name) { + l.Debugf("%v REQ(in) for internal file: %s: %q / %q o=%d s=%d", m, deviceID, folder, name, offset, size) + return nil, protocol.ErrNoSuchFile + } + + if folderIgnores.Match(name).IsIgnored() { + l.Debugf("%v REQ(in) for ignored file: %s: %q / %q o=%d s=%d", m, deviceID, folder, name, offset, size) + return nil, protocol.ErrNoSuchFile } folderFs := folderCfg.Filesystem() - if fs.IsInternal(name) { - l.Debugf("%v REQ(in) for internal file: %s: %q / %q o=%d s=%d", m, deviceID, folder, name, offset, len(buf)) - return protocol.ErrNoSuchFile - } - - if folderIgnores.Match(name).IsIgnored() { - l.Debugf("%v REQ(in) for ignored file: %s: %q / %q o=%d s=%d", m, deviceID, folder, name, offset, len(buf)) - return protocol.ErrNoSuchFile - } - if err := osutil.TraversesSymlink(folderFs, filepath.Dir(name)); err != nil { - l.Debugf("%v REQ(in) traversal check: %s - %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, len(buf)) - return protocol.ErrNoSuchFile + l.Debugf("%v REQ(in) traversal check: %s - %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, size) + return nil, protocol.ErrNoSuchFile + } + + // Restrict parallel requests by connection/device + + m.pmut.RLock() + limiter := m.connRequestLimiters[deviceID] + m.pmut.RUnlock() + + if limiter != nil { + limiter.take(int(size)) + } + + // The requestResponse releases the bytes to the limiter when its Close method is called. + res := newRequestResponse(int(size)) + defer func() { + // Close it ourselves if it isn't returned due to an error + if err != nil { + res.Close() + } + }() + + if limiter != nil { + go func() { + res.Wait() + limiter.give(int(size)) + }() } // Only check temp files if the flag is set, and if we are set to advertise @@ -1376,11 +1434,12 @@ func (m *Model) Request(deviceID protocol.DeviceID, folder, name string, offset if info, err := folderFs.Lstat(tempFn); err != nil || !info.IsRegular() { // Reject reads for anything that doesn't exist or is something // other than a regular file. - return protocol.ErrNoSuchFile + l.Debugf("%v REQ(in) failed stating temp file (%v): %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, size) + return nil, protocol.ErrNoSuchFile } - err := readOffsetIntoBuf(folderFs, tempFn, offset, buf) - if err == nil && scanner.Validate(buf, hash, weakHash) { - return nil + err := readOffsetIntoBuf(folderFs, tempFn, offset, res.data) + if err == nil && scanner.Validate(res.data, hash, weakHash) { + return res, nil } // Fall through to reading from a non-temp file, just incase the temp // file has finished downloading. @@ -1389,21 +1448,25 @@ func (m *Model) Request(deviceID protocol.DeviceID, folder, name string, offset if info, err := folderFs.Lstat(name); err != nil || !info.IsRegular() { // Reject reads for anything that doesn't exist or is something // other than a regular file. - return protocol.ErrNoSuchFile + l.Debugf("%v REQ(in) failed stating file (%v): %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, size) + return nil, protocol.ErrNoSuchFile } - if err = readOffsetIntoBuf(folderFs, name, offset, buf); fs.IsNotExist(err) { - return protocol.ErrNoSuchFile + if err := readOffsetIntoBuf(folderFs, name, offset, res.data); fs.IsNotExist(err) { + l.Debugf("%v REQ(in) file doesn't exist: %s: %q / %q o=%d s=%d", m, deviceID, folder, name, offset, size) + return nil, protocol.ErrNoSuchFile } else if err != nil { - return protocol.ErrGeneric + l.Debugf("%v REQ(in) failed reading file (%v): %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, size) + return nil, protocol.ErrGeneric } - if !scanner.Validate(buf, hash, weakHash) { - m.recheckFile(deviceID, folderFs, folder, name, int(offset)/len(buf), hash) - return protocol.ErrNoSuchFile + if !scanner.Validate(res.data, hash, weakHash) { + m.recheckFile(deviceID, folderFs, folder, name, int(offset)/int(size), hash) + l.Debugf("%v REQ(in) failed validating data (%v): %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, size) + return nil, protocol.ErrNoSuchFile } - return nil + return res, nil } func (m *Model) recheckFile(deviceID protocol.DeviceID, folderFs fs.Filesystem, folder, name string, blockIndex int, hash []byte) { @@ -1598,6 +1661,11 @@ func (m *Model) GetHello(id protocol.DeviceID) protocol.HelloIntf { // folder changes. func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloResult) { deviceID := conn.ID() + device, ok := m.cfg.Device(deviceID) + if !ok { + l.Infoln("Trying to add connection to unknown device") + return + } m.pmut.Lock() if oldConn, ok := m.conn[deviceID]; ok { @@ -1617,6 +1685,13 @@ func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloR m.conn[deviceID] = conn m.closed[deviceID] = make(chan struct{}) m.deviceDownloads[deviceID] = newDeviceDownloadState() + // 0: default, <0: no limiting + switch { + case device.MaxRequestKiB > 0: + m.connRequestLimiters[deviceID] = newByteSemaphore(1024 * device.MaxRequestKiB) + case device.MaxRequestKiB == 0: + m.connRequestLimiters[deviceID] = newByteSemaphore(1024 * defaultPullerPendingKiB) + } m.helloMessages[deviceID] = hello @@ -1644,8 +1719,7 @@ func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloR cm := m.generateClusterConfig(deviceID) conn.ClusterConfig(cm) - device, ok := m.cfg.Devices()[deviceID] - if ok && (device.Name == "" || m.cfg.Options().OverwriteRemoteDevNames) && hello.DeviceName != "" { + if (device.Name == "" || m.cfg.Options().OverwriteRemoteDevNames) && hello.DeviceName != "" { device.Name = hello.DeviceName m.cfg.SetDevice(device) m.cfg.Save() diff --git a/lib/model/model_test.go b/lib/model/model_test.go index b5e381b74..c094612f6 100644 --- a/lib/model/model_test.go +++ b/lib/model/model_test.go @@ -183,45 +183,42 @@ func TestRequest(t *testing.T) { defer m.Stop() m.ScanFolder("default") - bs := make([]byte, protocol.MinBlockSize) - // Existing, shared file - bs = bs[:6] - err := m.Request(device1, "default", "foo", 0, nil, 0, false, bs) + res, err := m.Request(device1, "default", "foo", 6, 0, nil, 0, false) if err != nil { t.Error(err) } + bs := res.Data() if !bytes.Equal(bs, []byte("foobar")) { t.Errorf("Incorrect data from request: %q", string(bs)) } // Existing, nonshared file - err = m.Request(device2, "default", "foo", 0, nil, 0, false, bs) + _, err = m.Request(device2, "default", "foo", 6, 0, nil, 0, false) if err == nil { t.Error("Unexpected nil error on insecure file read") } // Nonexistent file - err = m.Request(device1, "default", "nonexistent", 0, nil, 0, false, bs) + _, err = m.Request(device1, "default", "nonexistent", 6, 0, nil, 0, false) if err == nil { t.Error("Unexpected nil error on insecure file read") } // Shared folder, but disallowed file name - err = m.Request(device1, "default", "../walk.go", 0, nil, 0, false, bs) + _, err = m.Request(device1, "default", "../walk.go", 6, 0, nil, 0, false) if err == nil { t.Error("Unexpected nil error on insecure file read") } // Negative offset - err = m.Request(device1, "default", "foo", -4, nil, 0, false, bs[:0]) + _, err = m.Request(device1, "default", "foo", -4, 0, nil, 0, false) if err == nil { t.Error("Unexpected nil error on insecure file read") } // Larger block than available - bs = bs[:42] - err = m.Request(device1, "default", "foo", 0, nil, 0, false, bs) + _, err = m.Request(device1, "default", "foo", 42, 0, nil, 0, false) if err == nil { t.Error("Unexpected nil error on insecure file read") } @@ -536,7 +533,7 @@ func BenchmarkRequestInSingleFile(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if err := m.Request(device1, "default", "request/for/a/file/in/a/couple/of/dirs/128k", 0, nil, 0, false, buf); err != nil { + if _, err := m.Request(device1, "default", "request/for/a/file/in/a/couple/of/dirs/128k", 128<<10, 0, nil, 0, false); err != nil { b.Error(err) } } @@ -3667,6 +3664,7 @@ func TestFolderRestartZombies(t *testing.T) { // would leave more than one folder runner alive. wrapper := createTmpWrapper(defaultCfg.Copy()) + defer os.Remove(wrapper.ConfigPath()) folderCfg, _ := wrapper.Folder("default") folderCfg.FilesystemType = fs.FilesystemTypeFake wrapper.SetFolder(folderCfg) @@ -3759,3 +3757,45 @@ func (c *alwaysChanged) Seen(fs fs.Filesystem, name string) bool { func (c *alwaysChanged) Changed() bool { return true } + +func TestRequestLimit(t *testing.T) { + cfg := defaultCfg.Copy() + cfg.Devices = append(cfg.Devices, config.NewDeviceConfiguration(device2, "device2")) + cfg.Devices[1].MaxRequestKiB = 1 + cfg.Folders[0].Devices = []config.FolderDeviceConfiguration{ + {DeviceID: device1}, + {DeviceID: device2}, + } + m, _, wrapper := setupModelWithConnectionManual(cfg) + defer m.Stop() + defer os.Remove(wrapper.ConfigPath()) + + file := "tmpfile" + befReq := time.Now() + first, err := m.Request(device2, "default", file, 2000, 0, nil, 0, false) + if err != nil { + t.Fatalf("First request failed: %v", err) + } + reqDur := time.Since(befReq) + returned := make(chan struct{}) + go func() { + second, err := m.Request(device2, "default", file, 2000, 0, nil, 0, false) + if err != nil { + t.Fatalf("Second request failed: %v", err) + } + close(returned) + second.Close() + }() + time.Sleep(10 * reqDur) + select { + case <-returned: + t.Fatalf("Second request returned before first was done") + default: + } + first.Close() + select { + case <-returned: + case <-time.After(time.Second): + t.Fatalf("Second request did not return after first was done") + } +} diff --git a/lib/model/requests_test.go b/lib/model/requests_test.go index 0d378f2fb..1a48ec37f 100644 --- a/lib/model/requests_test.go +++ b/lib/model/requests_test.go @@ -98,9 +98,8 @@ func TestSymlinkTraversalRead(t *testing.T) { <-done // Request a file by traversing the symlink - buf := make([]byte, 10) - err := m.Request(device1, "default", "symlink/requests_test.go", 0, nil, 0, false, buf) - if err == nil || !bytes.Equal(buf, make([]byte, 10)) { + res, err := m.Request(device1, "default", "symlink/requests_test.go", 10, 0, nil, 0, false) + if err == nil || res != nil { t.Error("Managed to traverse symlink") } } @@ -225,6 +224,7 @@ func TestRequestVersioningSymlinkAttack(t *testing.T) { defer os.RemoveAll(tmpDir) cfg := defaultCfgWrapper.RawCopy() + cfg.Devices = append(cfg.Devices, config.NewDeviceConfiguration(device2, "device2")) cfg.Folders[0] = config.NewFolderConfiguration(protocol.LocalDeviceID, "default", "default", fs.FilesystemTypeBasic, tmpDir) cfg.Folders[0].Devices = []config.FolderDeviceConfiguration{ {DeviceID: device1}, @@ -519,12 +519,11 @@ func TestRescanIfHaveInvalidContent(t *testing.T) { t.Fatalf("unexpected weak hash: %d != 103547413", f.Blocks[0].WeakHash) } - buf := make([]byte, len(payload)) - - err := m.Request(device2, "default", "foo", 0, f.Blocks[0].Hash, f.Blocks[0].WeakHash, false, buf) + res, err := m.Request(device2, "default", "foo", int32(len(payload)), 0, f.Blocks[0].Hash, f.Blocks[0].WeakHash, false) if err != nil { t.Fatal(err) } + buf := res.Data() if !bytes.Equal(buf, payload) { t.Errorf("%s != %s", buf, payload) } @@ -536,7 +535,7 @@ func TestRescanIfHaveInvalidContent(t *testing.T) { t.Fatal(err) } - err = m.Request(device2, "default", "foo", 0, f.Blocks[0].Hash, f.Blocks[0].WeakHash, false, buf) + res, err = m.Request(device2, "default", "foo", int32(len(payload)), 0, f.Blocks[0].Hash, f.Blocks[0].WeakHash, false) if err == nil { t.Fatalf("expected failure") } diff --git a/lib/protocol/benchmark_test.go b/lib/protocol/benchmark_test.go index f840adaa3..308d7cc13 100644 --- a/lib/protocol/benchmark_test.go +++ b/lib/protocol/benchmark_test.go @@ -171,12 +171,13 @@ func (m *fakeModel) Index(deviceID DeviceID, folder string, files []FileInfo) { func (m *fakeModel) IndexUpdate(deviceID DeviceID, folder string, files []FileInfo) { } -func (m *fakeModel) Request(deviceID DeviceID, folder string, name string, offset int64, hash []byte, weakHAsh uint32, fromTemporary bool, buf []byte) error { +func (m *fakeModel) Request(deviceID DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) { // We write the offset to the end of the buffer, so the receiver // can verify that it did in fact get some data back over the // connection. + buf := make([]byte, size) binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(offset)) - return nil + return &fakeRequestResponse{buf}, nil } func (m *fakeModel) ClusterConfig(deviceID DeviceID, config ClusterConfig) { diff --git a/lib/protocol/bufferpool.go b/lib/protocol/bufferpool.go index af7e8d2d6..17ad2f386 100644 --- a/lib/protocol/bufferpool.go +++ b/lib/protocol/bufferpool.go @@ -4,32 +4,59 @@ package protocol import "sync" +// Global pool to get buffers from. Requires Blocksizes to be initialised, +// therefore it is initialized in the same init() as BlockSizes +var BufferPool bufferPool + type bufferPool struct { - minSize int - pool sync.Pool + pools []sync.Pool } -// get returns a new buffer of the requested size -func (p *bufferPool) get(size int) []byte { - intf := p.pool.Get() - if intf == nil { - // Pool is empty, must allocate. - return p.new(size) - } - - bs := *intf.(*[]byte) - if cap(bs) < size { - // Buffer was too small, leave it for someone else and allocate. - p.pool.Put(intf) - return p.new(size) - } - - return bs[:size] +func newBufferPool() bufferPool { + return bufferPool{make([]sync.Pool, len(BlockSizes))} } -// upgrade grows the buffer to the requested size, while attempting to reuse +func (p *bufferPool) Get(size int) []byte { + // Too big, isn't pooled + if size > MaxBlockSize { + return make([]byte, size) + } + var i int + for i = range BlockSizes { + if size <= BlockSizes[i] { + break + } + } + var bs []byte + // Try the fitting and all bigger pools + for j := i; j < len(BlockSizes); j++ { + if intf := p.pools[j].Get(); intf != nil { + bs = *intf.(*[]byte) + return bs[:size] + } + } + // All pools are empty, must allocate. + return make([]byte, BlockSizes[i])[:size] +} + +// Put makes the given byte slice availabe again in the global pool +func (p *bufferPool) Put(bs []byte) { + c := cap(bs) + // Don't buffer huge byte slices + if c > 2*MaxBlockSize { + return + } + for i := range BlockSizes { + if c >= BlockSizes[i] { + p.pools[i].Put(&bs) + return + } + } +} + +// Upgrade grows the buffer to the requested size, while attempting to reuse // it if possible. -func (p *bufferPool) upgrade(bs []byte, size int) []byte { +func (p *bufferPool) Upgrade(bs []byte, size int) []byte { if cap(bs) >= size { // Reslicing is enough, lets go! return bs[:size] @@ -37,23 +64,6 @@ func (p *bufferPool) upgrade(bs []byte, size int) []byte { // It was too small. But it pack into the pool and try to get another // buffer. - p.put(bs) - return p.get(size) -} - -// put returns the buffer to the pool -func (p *bufferPool) put(bs []byte) { - p.pool.Put(&bs) -} - -// new creates a new buffer of the requested size, taking the minimum -// allocation count into account. For internal use only. -func (p *bufferPool) new(size int) []byte { - allocSize := size - if allocSize < p.minSize { - // Avoid allocating tiny buffers that we won't be able to reuse for - // anything useful. - allocSize = p.minSize - } - return make([]byte, allocSize)[:size] + p.Put(bs) + return p.Get(size) } diff --git a/lib/protocol/common_test.go b/lib/protocol/common_test.go index a12e36104..8faf9a3f0 100644 --- a/lib/protocol/common_test.go +++ b/lib/protocol/common_test.go @@ -9,7 +9,7 @@ type TestModel struct { folder string name string offset int64 - size int + size int32 hash []byte weakHash uint32 fromTemporary bool @@ -29,16 +29,17 @@ func (t *TestModel) Index(deviceID DeviceID, folder string, files []FileInfo) { func (t *TestModel) IndexUpdate(deviceID DeviceID, folder string, files []FileInfo) { } -func (t *TestModel) Request(deviceID DeviceID, folder, name string, offset int64, hash []byte, weakHash uint32, fromTemporary bool, buf []byte) error { +func (t *TestModel) Request(deviceID DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) { t.folder = folder t.name = name t.offset = offset - t.size = len(buf) + t.size = size t.hash = hash t.weakHash = weakHash t.fromTemporary = fromTemporary + buf := make([]byte, len(t.data)) copy(buf, t.data) - return nil + return &fakeRequestResponse{buf}, nil } func (t *TestModel) Closed(conn Connection, err error) { @@ -60,3 +61,15 @@ func (t *TestModel) closedError() error { return nil // Timeout } } + +type fakeRequestResponse struct { + data []byte +} + +func (r *fakeRequestResponse) Data() []byte { + return r.data +} + +func (r *fakeRequestResponse) Close() {} + +func (r *fakeRequestResponse) Wait() {} diff --git a/lib/protocol/nativemodel_darwin.go b/lib/protocol/nativemodel_darwin.go index 8beca8d20..b4a20fe0b 100644 --- a/lib/protocol/nativemodel_darwin.go +++ b/lib/protocol/nativemodel_darwin.go @@ -26,7 +26,7 @@ func (m nativeModel) IndexUpdate(deviceID DeviceID, folder string, files []FileI m.Model.IndexUpdate(deviceID, folder, files) } -func (m nativeModel) Request(deviceID DeviceID, folder string, name string, offset int64, hash []byte, weakHash uint32, fromTemporary bool, buf []byte) error { +func (m nativeModel) Request(deviceID DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) { name = norm.NFD.String(name) - return m.Model.Request(deviceID, folder, name, offset, hash, weakHash, fromTemporary, buf) + return m.Model.Request(deviceID, folder, name, size, offset, hash, weakHash, fromTemporary) } diff --git a/lib/protocol/nativemodel_windows.go b/lib/protocol/nativemodel_windows.go index 508625bb6..f3076c961 100644 --- a/lib/protocol/nativemodel_windows.go +++ b/lib/protocol/nativemodel_windows.go @@ -25,14 +25,14 @@ func (m nativeModel) IndexUpdate(deviceID DeviceID, folder string, files []FileI m.Model.IndexUpdate(deviceID, folder, files) } -func (m nativeModel) Request(deviceID DeviceID, folder string, name string, offset int64, hash []byte, weakHash uint32, fromTemporary bool, buf []byte) error { +func (m nativeModel) Request(deviceID DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) { if strings.Contains(name, `\`) { l.Warnf("Dropping request for %s, contains invalid path separator", name) - return ErrNoSuchFile + return nil, ErrNoSuchFile } name = filepath.FromSlash(name) - return m.Model.Request(deviceID, folder, name, offset, hash, weakHash, fromTemporary, buf) + return m.Model.Request(deviceID, folder, name, size, offset, hash, weakHash, fromTemporary) } func fixupFiles(files []FileInfo) []FileInfo { diff --git a/lib/protocol/nativemodel_windows_test.go b/lib/protocol/nativemodel_windows_test.go index 2308f4fac..05cff2e59 100644 --- a/lib/protocol/nativemodel_windows_test.go +++ b/lib/protocol/nativemodel_windows_test.go @@ -2,8 +2,10 @@ package protocol -import "testing" -import "reflect" +import ( + "reflect" + "testing" +) func TestFixupFiles(t *testing.T) { files := []FileInfo{ diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go index 2309ab950..384dab499 100644 --- a/lib/protocol/protocol.go +++ b/lib/protocol/protocol.go @@ -48,6 +48,7 @@ func init() { BlockSizes = append(BlockSizes, blockSize) sha256OfEmptyBlock[blockSize] = sha256.Sum256(make([]byte, blockSize)) } + BufferPool = newBufferPool() } // BlockSize returns the block size to use for the given file size @@ -125,7 +126,7 @@ type Model interface { // An index update was received from the peer device IndexUpdate(deviceID DeviceID, folder string, files []FileInfo) // A request was made by the peer device - Request(deviceID DeviceID, folder string, name string, offset int64, hash []byte, weakHash uint32, fromTemporary bool, buf []byte) error + Request(deviceID DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) // A cluster configuration message was received ClusterConfig(deviceID DeviceID, config ClusterConfig) // The peer device closed the connection @@ -134,6 +135,12 @@ type Model interface { DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) } +type RequestResponse interface { + Data() []byte + Close() // Must always be called once the byte slice is no longer in use + Wait() // Blocks until Close is called +} + type Connection interface { Start() ID() DeviceID @@ -166,7 +173,6 @@ type rawConnection struct { outbox chan asyncMessage closed chan struct{} once sync.Once - pool bufferPool compression Compression } @@ -184,7 +190,7 @@ type message interface { type asyncMessage struct { msg message - done chan struct{} // done closes when we're done marshalling the message and its contents can be reused + done chan struct{} // done closes when we're done sending the message } const ( @@ -196,12 +202,6 @@ const ( ReceiveTimeout = 300 * time.Second ) -// A buffer pool for global use. We don't allocate smaller buffers than 64k, -// in the hope of being able to reuse them later. -var buffers = bufferPool{ - minSize: 64 << 10, -} - func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection { cr := &countingReader{Reader: reader} cw := &countingWriter{Writer: writer} @@ -215,7 +215,6 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv awaiting: make(map[int32]chan asyncResult), outbox: make(chan asyncMessage), closed: make(chan struct{}), - pool: bufferPool{minSize: MinBlockSize}, compression: compress, } @@ -338,6 +337,7 @@ func (c *rawConnection) readerLoop() (err error) { c.close(err) }() + fourByteBuf := make([]byte, 4) state := stateInitial for { select { @@ -346,7 +346,7 @@ func (c *rawConnection) readerLoop() (err error) { default: } - msg, err := c.readMessage() + msg, err := c.readMessage(fourByteBuf) if err == errUnknownMessage { // Unknown message types are skipped, for future extensibility. continue @@ -394,7 +394,6 @@ func (c *rawConnection) readerLoop() (err error) { if err := checkFilename(msg.Name); err != nil { return fmt.Errorf("protocol error: request: %q: %v", msg.Name, err) } - // Requests are handled asynchronously go c.handleRequest(*msg) case *Response: @@ -429,30 +428,29 @@ func (c *rawConnection) readerLoop() (err error) { } } -func (c *rawConnection) readMessage() (message, error) { - hdr, err := c.readHeader() +func (c *rawConnection) readMessage(fourByteBuf []byte) (message, error) { + hdr, err := c.readHeader(fourByteBuf) if err != nil { return nil, err } - return c.readMessageAfterHeader(hdr) + return c.readMessageAfterHeader(hdr, fourByteBuf) } -func (c *rawConnection) readMessageAfterHeader(hdr Header) (message, error) { +func (c *rawConnection) readMessageAfterHeader(hdr Header, fourByteBuf []byte) (message, error) { // First comes a 4 byte message length - buf := buffers.get(4) - if _, err := io.ReadFull(c.cr, buf); err != nil { + if _, err := io.ReadFull(c.cr, fourByteBuf[:4]); err != nil { return nil, fmt.Errorf("reading message length: %v", err) } - msgLen := int32(binary.BigEndian.Uint32(buf)) + msgLen := int32(binary.BigEndian.Uint32(fourByteBuf)) if msgLen < 0 { return nil, fmt.Errorf("negative message length %d", msgLen) } // Then comes the message - buf = buffers.upgrade(buf, int(msgLen)) + buf := BufferPool.Get(int(msgLen)) if _, err := io.ReadFull(c.cr, buf); err != nil { return nil, fmt.Errorf("reading message: %v", err) } @@ -465,7 +463,7 @@ func (c *rawConnection) readMessageAfterHeader(hdr Header) (message, error) { case MessageCompressionLZ4: decomp, err := c.lz4Decompress(buf) - buffers.put(buf) + BufferPool.Put(buf) if err != nil { return nil, fmt.Errorf("decompressing message: %v", err) } @@ -484,26 +482,25 @@ func (c *rawConnection) readMessageAfterHeader(hdr Header) (message, error) { if err := msg.Unmarshal(buf); err != nil { return nil, fmt.Errorf("unmarshalling message: %v", err) } - buffers.put(buf) + BufferPool.Put(buf) return msg, nil } -func (c *rawConnection) readHeader() (Header, error) { +func (c *rawConnection) readHeader(fourByteBuf []byte) (Header, error) { // First comes a 2 byte header length - buf := buffers.get(2) - if _, err := io.ReadFull(c.cr, buf); err != nil { + if _, err := io.ReadFull(c.cr, fourByteBuf[:2]); err != nil { return Header{}, fmt.Errorf("reading length: %v", err) } - hdrLen := int16(binary.BigEndian.Uint16(buf)) + hdrLen := int16(binary.BigEndian.Uint16(fourByteBuf)) if hdrLen < 0 { return Header{}, fmt.Errorf("negative header length %d", hdrLen) } // Then comes the header - buf = buffers.upgrade(buf, int(hdrLen)) + buf := BufferPool.Get(int(hdrLen)) if _, err := io.ReadFull(c.cr, buf); err != nil { return Header{}, fmt.Errorf("reading header: %v", err) } @@ -513,7 +510,7 @@ func (c *rawConnection) readHeader() (Header, error) { return Header{}, fmt.Errorf("unmarshalling header: %v", err) } - buffers.put(buf) + BufferPool.Put(buf) return hdr, nil } @@ -590,38 +587,22 @@ func checkFilename(name string) error { } func (c *rawConnection) handleRequest(req Request) { - size := int(req.Size) - usePool := size <= MaxBlockSize - - var buf []byte - var done chan struct{} - - if usePool { - buf = c.pool.get(size) - done = make(chan struct{}) - } else { - buf = make([]byte, size) - } - - err := c.receiver.Request(c.id, req.Folder, req.Name, req.Offset, req.Hash, req.WeakHash, req.FromTemporary, buf) + res, err := c.receiver.Request(c.id, req.Folder, req.Name, req.Size, req.Offset, req.Hash, req.WeakHash, req.FromTemporary) if err != nil { c.send(&Response{ ID: req.ID, - Data: nil, Code: errorToCode(err), - }, done) - } else { - c.send(&Response{ - ID: req.ID, - Data: buf, - Code: errorToCode(err), - }, done) - } - - if usePool { - <-done - c.pool.put(buf) + }, nil) + return } + done := make(chan struct{}) + c.send(&Response{ + ID: req.ID, + Data: res.Data(), + Code: errorToCode(nil), + }, done) + <-done + res.Close() } func (c *rawConnection) handleResponse(resp Response) { @@ -639,6 +620,9 @@ func (c *rawConnection) send(msg message, done chan struct{}) bool { case c.outbox <- asyncMessage{msg, done}: return true case <-c.closed: + if done != nil { + close(done) + } return false } } @@ -647,7 +631,11 @@ func (c *rawConnection) writerLoop() { for { select { case hm := <-c.outbox: - if err := c.writeMessage(hm); err != nil { + err := c.writeMessage(hm) + if hm.done != nil { + close(hm.done) + } + if err != nil { c.close(err) return } @@ -667,13 +655,10 @@ func (c *rawConnection) writeMessage(hm asyncMessage) error { func (c *rawConnection) writeCompressedMessage(hm asyncMessage) error { size := hm.msg.ProtoSize() - buf := buffers.get(size) + buf := BufferPool.Get(size) if _, err := hm.msg.MarshalTo(buf); err != nil { return fmt.Errorf("marshalling message: %v", err) } - if hm.done != nil { - close(hm.done) - } compressed, err := c.lz4Compress(buf) if err != nil { @@ -690,7 +675,7 @@ func (c *rawConnection) writeCompressedMessage(hm asyncMessage) error { } totSize := 2 + hdrSize + 4 + len(compressed) - buf = buffers.upgrade(buf, totSize) + buf = BufferPool.Upgrade(buf, totSize) // Header length binary.BigEndian.PutUint16(buf, uint16(hdrSize)) @@ -702,10 +687,10 @@ func (c *rawConnection) writeCompressedMessage(hm asyncMessage) error { binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(len(compressed))) // Message copy(buf[2+hdrSize+4:], compressed) - buffers.put(compressed) + BufferPool.Put(compressed) n, err := c.cw.Write(buf) - buffers.put(buf) + BufferPool.Put(buf) l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message (%d uncompressed)), err=%v", n, hdrSize, len(compressed), size, err) if err != nil { @@ -726,7 +711,7 @@ func (c *rawConnection) writeUncompressedMessage(hm asyncMessage) error { } totSize := 2 + hdrSize + 4 + size - buf := buffers.get(totSize) + buf := BufferPool.Get(totSize) // Header length binary.BigEndian.PutUint16(buf, uint16(hdrSize)) @@ -740,12 +725,9 @@ func (c *rawConnection) writeUncompressedMessage(hm asyncMessage) error { if _, err := hm.msg.MarshalTo(buf[2+hdrSize+4:]); err != nil { return fmt.Errorf("marshalling message: %v", err) } - if hm.done != nil { - close(hm.done) - } n, err := c.cw.Write(buf[:totSize]) - buffers.put(buf) + BufferPool.Put(buf) l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message), err=%v", n, hdrSize, size, err) if err != nil { @@ -904,7 +886,7 @@ func (c *rawConnection) Statistics() Statistics { func (c *rawConnection) lz4Compress(src []byte) ([]byte, error) { var err error - buf := buffers.get(len(src)) + buf := BufferPool.Get(len(src)) buf, err = lz4.Encode(buf, src) if err != nil { return nil, err @@ -918,7 +900,7 @@ func (c *rawConnection) lz4Decompress(src []byte) ([]byte, error) { size := binary.BigEndian.Uint32(src) binary.LittleEndian.PutUint32(src, size) var err error - buf := buffers.get(int(size)) + buf := BufferPool.Get(int(size)) buf, err = lz4.Decode(buf, src) if err != nil { return nil, err