diff --git a/lib/connections/limiter.go b/lib/connections/limiter.go
index b2ff22f56..e4528d97a 100644
--- a/lib/connections/limiter.go
+++ b/lib/connections/limiter.go
@@ -32,9 +32,13 @@ type limiter struct {
 type waiter interface {
 	// This is the rate limiting operation
 	WaitN(ctx context.Context, n int) error
+	Limit() rate.Limit
 }
 
-const limiterBurstSize = 4 * 128 << 10
+const (
+	limiterBurstSize   = 4 * 128 << 10
+	maxSingleWriteSize = 8 << 10
+)
 
 func newLimiter(cfg config.Wrapper) *limiter {
 	l := &limiter{
@@ -186,19 +190,23 @@ func (lim *limiter) getLimiters(remoteID protocol.DeviceID, rw io.ReadWriter, is
 
 func (lim *limiter) newLimitedReaderLocked(remoteID protocol.DeviceID, r io.Reader, isLAN bool) io.Reader {
 	return &limitedReader{
-		reader:    r,
-		limitsLAN: &lim.limitsLAN,
-		waiter:    totalWaiter{lim.getReadLimiterLocked(remoteID), lim.read},
-		isLAN:     isLAN,
+		reader: r,
+		waiterHolder: waiterHolder{
+			waiter:    totalWaiter{lim.getReadLimiterLocked(remoteID), lim.read},
+			limitsLAN: &lim.limitsLAN,
+			isLAN:     isLAN,
+		},
 	}
 }
 
 func (lim *limiter) newLimitedWriterLocked(remoteID protocol.DeviceID, w io.Writer, isLAN bool) io.Writer {
 	return &limitedWriter{
-		writer:    w,
-		limitsLAN: &lim.limitsLAN,
-		waiter:    totalWaiter{lim.getWriteLimiterLocked(remoteID), lim.write},
-		isLAN:     isLAN,
+		writer: w,
+		waiterHolder: waiterHolder{
+			waiter:    totalWaiter{lim.getWriteLimiterLocked(remoteID), lim.write},
+			limitsLAN: &lim.limitsLAN,
+			isLAN:     isLAN,
+		},
 	}
 }
 
@@ -221,53 +229,87 @@ func getRateLimiter(m map[protocol.DeviceID]*rate.Limiter, deviceID protocol.Dev
 
 // limitedReader is a rate limited io.Reader
 type limitedReader struct {
-	reader    io.Reader
-	limitsLAN *atomicBool
-	waiter    waiter
-	isLAN     bool
+	reader io.Reader
+	waiterHolder
 }
 
 func (r *limitedReader) Read(buf []byte) (int, error) {
 	n, err := r.reader.Read(buf)
-	if !r.isLAN || r.limitsLAN.get() {
-		take(r.waiter, n)
+	if !r.unlimited() {
+		r.take(n)
 	}
 	return n, err
 }
 
 // limitedWriter is a rate limited io.Writer
 type limitedWriter struct {
-	writer    io.Writer
-	limitsLAN *atomicBool
-	waiter    waiter
-	isLAN     bool
+	writer io.Writer
+	waiterHolder
 }
 
 func (w *limitedWriter) Write(buf []byte) (int, error) {
-	if !w.isLAN || w.limitsLAN.get() {
-		take(w.waiter, len(buf))
+	if w.unlimited() {
+		return w.writer.Write(buf)
 	}
-	return w.writer.Write(buf)
+
+	// This does (potentially) multiple smaller writes in order to be less
+	// bursty with large writes and slow rates.
+	written := 0
+	for written < len(buf) {
+		toWrite := maxSingleWriteSize
+		if toWrite > len(buf)-written {
+			toWrite = len(buf) - written
+		}
+		w.take(toWrite)
+		n, err := w.writer.Write(buf[written : written+toWrite])
+		written += n
+		if err != nil {
+			return written, err
+		}
+	}
+
+	return written, nil
 }
 
-// take is a utility function to consume tokens from a overall rate.Limiter and deviceLimiter.
-// No call to WaitN can be larger than the limiter burst size so we split it up into
-// several calls when necessary.
-func take(waiter waiter, tokens int) {
+// waiterHolder is the common functionality around having and evaluating a
+// waiter, valid for both writers and readers
+type waiterHolder struct {
+	waiter    waiter
+	limitsLAN *atomicBool
+	isLAN     bool
+}
+
+// unlimited returns true if the waiter is not limiting the rate
+func (w waiterHolder) unlimited() bool {
+	if w.isLAN && !w.limitsLAN.get() {
+		return true
+	}
+	return w.waiter.Limit() == rate.Inf
+}
+
+// take is a utility function to consume tokens, because no call to WaitN
+// may be larger than the limiter burst size or it will hang.
+func (w waiterHolder) take(tokens int) {
+	// For writes we already split the buffer into smaller operations so those
+	// will always end up in the fast path below. For reads, however, we don't
+	// control the size of the incoming buffer and don't split the calls
+	// into the lower level reads so we might get a large amount of data and
+	// end up in the loop further down.
+
 	if tokens < limiterBurstSize {
-		// This is the by far more common case so we get it out of the way
-		// early.
-		waiter.WaitN(context.TODO(), tokens)
+		// Fast path. We won't get an error from WaitN as we don't pass a
+		// context with a deadline.
+		_ = w.waiter.WaitN(context.TODO(), tokens)
 		return
 	}
 
 	for tokens > 0 {
 		// Consume limiterBurstSize tokens at a time until we're done.
 		if tokens > limiterBurstSize {
-			waiter.WaitN(context.TODO(), limiterBurstSize)
+			_ = w.waiter.WaitN(context.TODO(), limiterBurstSize)
 			tokens -= limiterBurstSize
 		} else {
-			waiter.WaitN(context.TODO(), tokens)
+			_ = w.waiter.WaitN(context.TODO(), tokens)
 			tokens = 0
 		}
 	}
@@ -300,3 +342,13 @@ func (tw totalWaiter) WaitN(ctx context.Context, n int) error {
 	}
 	return nil
 }
+
+func (tw totalWaiter) Limit() rate.Limit {
+	min := rate.Inf
+	for _, w := range tw {
+		if l := w.Limit(); l < min {
+			min = l
+		}
+	}
+	return min
+}
diff --git a/lib/connections/limiter_test.go b/lib/connections/limiter_test.go
index e7b3a0436..784452109 100644
--- a/lib/connections/limiter_test.go
+++ b/lib/connections/limiter_test.go
@@ -7,12 +7,16 @@
 package connections
 
 import (
+	"bytes"
+	crand "crypto/rand"
+	"io"
+	"math/rand"
+	"testing"
+
 	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/events"
 	"github.com/syncthing/syncthing/lib/protocol"
 	"golang.org/x/time/rate"
-	"math/rand"
-	"testing"
 )
 
 var device1, device2, device3, device4 protocol.DeviceID
@@ -185,6 +189,151 @@ func TestAddAndRemove(t *testing.T) {
 	checkActualAndExpected(t, actualR, actualW, expectedR, expectedW)
 }
 
+func TestLimitedWriterWrite(t *testing.T) {
+	// Check that the limited writer writes the correct data in the correct manner.
+
+	// A buffer with random data that is larger than the write size and not
+	// a precise multiple either.
+	src := make([]byte, int(12.5*maxSingleWriteSize))
+	if _, err := crand.Reader.Read(src); err != nil {
+		t.Fatal(err)
+	}
+
+	// Write it to the destination using a limited writer, with a wrapper to
+	// count the write calls. The defaults on the limited writer should mean
+	// it is used (and doesn't take the fast path). In practice the limiter
+	// won't delay the test as the burst size is large enough to accommodate it
+	// regardless of the rate.
+	dst := new(bytes.Buffer)
+	cw := &countingWriter{w: dst}
+	lw := &limitedWriter{
+		writer: cw,
+		waiterHolder: waiterHolder{
+			waiter:    rate.NewLimiter(rate.Limit(42), limiterBurstSize),
+			limitsLAN: new(atomicBool),
+			isLAN:     false, // enables limiting
+		},
+	}
+	if _, err := io.Copy(lw, bytes.NewReader(src)); err != nil {
+		t.Fatal(err)
+	}
+
+	// Verify there were lots of writes and that the end result is identical.
+	if cw.writeCount != 13 {
+		t.Error("expected lots of smaller writes, but not too many")
+	}
+	if !bytes.Equal(src, dst.Bytes()) {
+		t.Error("results should be equal")
+	}
+
+	// Write it to the destination using a limited writer, with a wrapper to
+	// count the write calls. Now we make sure the fast path is used.
+	dst = new(bytes.Buffer)
+	cw = &countingWriter{w: dst}
+	lw = &limitedWriter{
+		writer: cw,
+		waiterHolder: waiterHolder{
+			waiter:    rate.NewLimiter(rate.Limit(42), limiterBurstSize),
+			limitsLAN: new(atomicBool),
+			isLAN:     true, // disables limiting
+		},
+	}
+	if _, err := io.Copy(lw, bytes.NewReader(src)); err != nil {
+		t.Fatal(err)
+	}
+
+	// Verify there was a single write and that the end result is identical.
+	if cw.writeCount != 1 {
+		t.Error("expected just the one write")
+	}
+	if !bytes.Equal(src, dst.Bytes()) {
+		t.Error("results should be equal")
+	}
+
+	// Once more, but making sure the fast path is used for an unlimited
+	// rate, with multiple unlimited rate limiters even (global and per-device).
+	dst = new(bytes.Buffer)
+	cw = &countingWriter{w: dst}
+	lw = &limitedWriter{
+		writer: cw,
+		waiterHolder: waiterHolder{
+			waiter:    totalWaiter{rate.NewLimiter(rate.Inf, limiterBurstSize), rate.NewLimiter(rate.Inf, limiterBurstSize)},
+			limitsLAN: new(atomicBool),
+			isLAN:     false, // enables limiting
+		},
+	}
+	if _, err := io.Copy(lw, bytes.NewReader(src)); err != nil {
+		t.Fatal(err)
+	}
+
+	// Verify there was a single write and that the end result is identical.
+	if cw.writeCount != 1 {
+		t.Error("expected just the one write")
+	}
+	if !bytes.Equal(src, dst.Bytes()) {
+		t.Error("results should be equal")
+	}
+
+	// Once more, but making sure we *don't* take the fast path when there
+	// is a combo of limited and unlimited writers.
+	dst = new(bytes.Buffer)
+	cw = &countingWriter{w: dst}
+	lw = &limitedWriter{
+		writer: cw,
+		waiterHolder: waiterHolder{
+			waiter: totalWaiter{
+				rate.NewLimiter(rate.Inf, limiterBurstSize),
+				rate.NewLimiter(rate.Limit(42), limiterBurstSize),
+				rate.NewLimiter(rate.Inf, limiterBurstSize),
+			},
+			limitsLAN: new(atomicBool),
+			isLAN:     false, // enables limiting
+		},
+	}
+	if _, err := io.Copy(lw, bytes.NewReader(src)); err != nil {
+		t.Fatal(err)
+	}
+
+	// Verify there were lots of writes and that the end result is identical.
+	if cw.writeCount != 13 {
+		t.Error("expected lots of smaller writes, but not too many")
+	}
+	if !bytes.Equal(src, dst.Bytes()) {
+		t.Error("results should be equal")
+	}
+}
+
+func TestTotalWaiterLimit(t *testing.T) {
+	cases := []struct {
+		w waiter
+		r rate.Limit
+	}{
+		{
+			totalWaiter{},
+			rate.Inf,
+		},
+		{
+			totalWaiter{rate.NewLimiter(rate.Inf, 42)},
+			rate.Inf,
+		},
+		{
+			totalWaiter{rate.NewLimiter(rate.Inf, 42), rate.NewLimiter(rate.Inf, 42)},
+			rate.Inf,
+		},
+		{
+			totalWaiter{rate.NewLimiter(rate.Inf, 42), rate.NewLimiter(rate.Limit(12), 42), rate.NewLimiter(rate.Limit(15), 42)},
+			rate.Limit(12),
+		},
+	}
+
+	for _, tc := range cases {
+		l := tc.w.Limit()
+		if l != tc.r {
+			t.Error("incorrect limit returned")
+		}
+	}
+}
+
 func checkActualAndExpected(t *testing.T, actualR, actualW, expectedR, expectedW map[protocol.DeviceID]*rate.Limiter) {
 	t.Helper()
 	if len(expectedW) != len(actualW) || len(expectedR) != len(actualR) {
@@ -204,3 +353,13 @@ func checkActualAndExpected(t *testing.T, actualR, actualW, expectedR, expectedW
 		}
 	}
 }
+
+type countingWriter struct {
+	w          io.Writer
+	writeCount int
+}
+
+func (w *countingWriter) Write(data []byte) (int, error) {
+	w.writeCount++
+	return w.w.Write(data)
+}
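
For reference, the chunked-write pattern that the new limitedWriter.Write uses can be shown in isolation. The sketch below is not part of the patch and the names chunkedWrite and chunkSize are illustrative only; it relies solely on the public golang.org/x/time/rate API. It demonstrates why waiting per chunk rather than per buffer smooths the traffic: each WaitN call requests at most one chunk's worth of tokens, so a slow rate produces many short pauses instead of one long stall, and no single call can exceed the limiter's burst size.

package main

import (
	"bytes"
	"context"
	"fmt"
	"io"

	"golang.org/x/time/rate"
)

// chunkSize mirrors maxSingleWriteSize from the patch (8 KiB).
const chunkSize = 8 << 10

// chunkedWrite writes buf to w in pieces of at most chunkSize bytes,
// consuming that many tokens from lim before each piece.
func chunkedWrite(ctx context.Context, lim *rate.Limiter, w io.Writer, buf []byte) (int, error) {
	written := 0
	for written < len(buf) {
		n := chunkSize
		if n > len(buf)-written {
			n = len(buf) - written
		}
		// n never exceeds the limiter's burst size, so WaitN only returns
		// an error if ctx is cancelled or its deadline passes.
		if err := lim.WaitN(ctx, n); err != nil {
			return written, err
		}
		m, err := w.Write(buf[written : written+n])
		written += m
		if err != nil {
			return written, err
		}
	}
	return written, nil
}

func main() {
	// An unlimited rate so the example finishes instantly; swap rate.Inf
	// for a real rate to observe the per-chunk pacing.
	lim := rate.NewLimiter(rate.Inf, 4*128<<10)
	var dst bytes.Buffer
	n, err := chunkedWrite(context.Background(), lim, &dst, make([]byte, 100<<10))
	fmt.Println(n, err) // 102400 <nil>
}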