mirror of
https://github.com/octoleo/syncthing.git
synced 2025-01-11 10:38:16 +00:00
Implement global and per session rate limiting
This commit is contained in:
parent
dab1c4cfc9
commit
35d20a19bc
22
main.go
22
main.go
@ -10,6 +10,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juju/ratelimit"
|
||||||
"github.com/syncthing/relaysrv/protocol"
|
"github.com/syncthing/relaysrv/protocol"
|
||||||
|
|
||||||
syncthingprotocol "github.com/syncthing/protocol"
|
syncthingprotocol "github.com/syncthing/protocol"
|
||||||
@ -26,6 +27,11 @@ var (
|
|||||||
networkTimeout time.Duration
|
networkTimeout time.Duration
|
||||||
pingInterval time.Duration
|
pingInterval time.Duration
|
||||||
messageTimeout time.Duration
|
messageTimeout time.Duration
|
||||||
|
|
||||||
|
sessionLimitBps int
|
||||||
|
globalLimitBps int
|
||||||
|
sessionLimiter *ratelimit.Bucket
|
||||||
|
globalLimiter *ratelimit.Bucket
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@ -38,6 +44,11 @@ func main() {
|
|||||||
flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations")
|
flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations")
|
||||||
flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent")
|
flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent")
|
||||||
flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive")
|
flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive")
|
||||||
|
flag.IntVar(&sessionLimitBps, "per-session-rate", sessionLimitBps, "Per session rate limit, in bytes/s")
|
||||||
|
flag.IntVar(&globalLimitBps, "global-rate", globalLimitBps, "Global rate limit, in bytes/s")
|
||||||
|
flag.BoolVar(&debug, "debug", false, "Enable debug output")
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
if extAddress == "" {
|
if extAddress == "" {
|
||||||
extAddress = listenSession
|
extAddress = listenSession
|
||||||
@ -51,10 +62,6 @@ func main() {
|
|||||||
sessionAddress = addr.IP[:]
|
sessionAddress = addr.IP[:]
|
||||||
sessionPort = uint16(addr.Port)
|
sessionPort = uint16(addr.Port)
|
||||||
|
|
||||||
flag.BoolVar(&debug, "debug", false, "Enable debug output")
|
|
||||||
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem")
|
certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem")
|
||||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -83,6 +90,13 @@ func main() {
|
|||||||
log.Println("ID:", id)
|
log.Println("ID:", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if sessionLimitBps > 0 {
|
||||||
|
sessionLimiter = ratelimit.NewBucketWithRate(float64(sessionLimitBps), int64(2*sessionLimitBps))
|
||||||
|
}
|
||||||
|
if globalLimitBps > 0 {
|
||||||
|
globalLimiter = ratelimit.NewBucketWithRate(float64(globalLimitBps), int64(2*globalLimitBps))
|
||||||
|
}
|
||||||
|
|
||||||
go sessionListener(listenSession)
|
go sessionListener(listenSession)
|
||||||
|
|
||||||
protocolListener(listenProtocol, tlsCfg)
|
protocolListener(listenProtocol, tlsCfg)
|
||||||
|
@ -130,7 +130,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ses := newSession()
|
ses := newSession(sessionLimiter, globalLimiter)
|
||||||
|
|
||||||
go ses.Serve()
|
go ses.Serve()
|
||||||
|
|
||||||
|
63
session.go
63
session.go
@ -11,6 +11,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juju/ratelimit"
|
||||||
"github.com/syncthing/relaysrv/protocol"
|
"github.com/syncthing/relaysrv/protocol"
|
||||||
|
|
||||||
syncthingprotocol "github.com/syncthing/protocol"
|
syncthingprotocol "github.com/syncthing/protocol"
|
||||||
@ -25,10 +26,12 @@ type session struct {
|
|||||||
serverkey []byte
|
serverkey []byte
|
||||||
clientkey []byte
|
clientkey []byte
|
||||||
|
|
||||||
|
rateLimit func(bytes int64)
|
||||||
|
|
||||||
conns chan net.Conn
|
conns chan net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSession() *session {
|
func newSession(sessionRateLimit, globalRateLimit *ratelimit.Bucket) *session {
|
||||||
serverkey := make([]byte, 32)
|
serverkey := make([]byte, 32)
|
||||||
_, err := rand.Read(serverkey)
|
_, err := rand.Read(serverkey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -44,6 +47,7 @@ func newSession() *session {
|
|||||||
ses := &session{
|
ses := &session{
|
||||||
serverkey: serverkey,
|
serverkey: serverkey,
|
||||||
clientkey: clientkey,
|
clientkey: clientkey,
|
||||||
|
rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit),
|
||||||
conns: make(chan net.Conn),
|
conns: make(chan net.Conn),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,12 +116,12 @@ func (s *session) Serve() {
|
|||||||
errors := make(chan error, 2)
|
errors := make(chan error, 2)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
errors <- proxy(conns[0], conns[1])
|
errors <- s.proxy(conns[0], conns[1])
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
errors <- proxy(conns[1], conns[0])
|
errors <- s.proxy(conns[1], conns[0])
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -169,14 +173,15 @@ func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) pr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func proxy(c1, c2 net.Conn) error {
|
func (s *session) proxy(c1, c2 net.Conn) error {
|
||||||
if debug {
|
if debug {
|
||||||
log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
|
log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
|
||||||
}
|
}
|
||||||
buf := make([]byte, 1024)
|
|
||||||
|
buf := make([]byte, 65536)
|
||||||
for {
|
for {
|
||||||
c1.SetReadDeadline(time.Now().Add(networkTimeout))
|
c1.SetReadDeadline(time.Now().Add(networkTimeout))
|
||||||
n, err := c1.Read(buf[0:])
|
n, err := c1.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -185,6 +190,10 @@ func proxy(c1, c2 net.Conn) error {
|
|||||||
log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr())
|
log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.rateLimit != nil {
|
||||||
|
s.rateLimit(int64(n))
|
||||||
|
}
|
||||||
|
|
||||||
c2.SetWriteDeadline(time.Now().Add(networkTimeout))
|
c2.SetWriteDeadline(time.Now().Add(networkTimeout))
|
||||||
_, err = c2.Write(buf[:n])
|
_, err = c2.Write(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -196,3 +205,45 @@ func proxy(c1, c2 net.Conn) error {
|
|||||||
func (s *session) String() string {
|
func (s *session) String() string {
|
||||||
return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5])
|
return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeRateLimitFunc(sessionRateLimit, globalRateLimit *ratelimit.Bucket) func(int64) {
|
||||||
|
// This may be a case of super duper premature optimization... We build an
|
||||||
|
// optimized function to do the rate limiting here based on what we need
|
||||||
|
// to do and then use it in the loop.
|
||||||
|
|
||||||
|
if sessionRateLimit == nil && globalRateLimit == nil {
|
||||||
|
// No limiting needed. We could equally well return a func(int64){} and
|
||||||
|
// not do a nil check were we use it, but I think the nil check there
|
||||||
|
// makes it clear that there will be no limiting if none is
|
||||||
|
// configured...
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if sessionRateLimit == nil {
|
||||||
|
// We only have a global limiter
|
||||||
|
return func(bytes int64) {
|
||||||
|
globalRateLimit.Wait(bytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if globalRateLimit == nil {
|
||||||
|
// We only have a session limiter
|
||||||
|
return func(bytes int64) {
|
||||||
|
sessionRateLimit.Wait(bytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We have both. Queue the bytes on both the global and session specific
|
||||||
|
// rate limiters. Wait for both in parallell, so that the actual send
|
||||||
|
// happens when both conditions are satisfied. In practice this just means
|
||||||
|
// wait the longer of the two times.
|
||||||
|
return func(bytes int64) {
|
||||||
|
t0 := sessionRateLimit.Take(bytes)
|
||||||
|
t1 := globalRateLimit.Take(bytes)
|
||||||
|
if t0 > t1 {
|
||||||
|
time.Sleep(t0)
|
||||||
|
} else {
|
||||||
|
time.Sleep(t1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user