From 3cacb48f3cdf8c8e3d4c331ad737df69a1aaea65 Mon Sep 17 00:00:00 2001 From: AudriusButkevicius Date: Mon, 7 Sep 2015 18:13:50 +0100 Subject: [PATCH] Add IP based rate limiting, check if client IP matches advertised relay, reorder stuff --- cmd/strelaypoolsrv/main.go | 206 ++++++++++++++++++++++++------------- 1 file changed, 137 insertions(+), 69 deletions(-) diff --git a/cmd/strelaypoolsrv/main.go b/cmd/strelaypoolsrv/main.go index 5bcdd3279..375bdaea7 100644 --- a/cmd/strelaypoolsrv/main.go +++ b/cmd/strelaypoolsrv/main.go @@ -17,7 +17,10 @@ import ( "strings" "time" + "github.com/golang/groupcache/lru" + "github.com/juju/ratelimit" "github.com/kardianos/osext" + "github.com/syncthing/relaysrv/client" "github.com/syncthing/syncthing/lib/sync" "github.com/syncthing/syncthing/lib/tlsutil" @@ -44,12 +47,24 @@ type result struct { } var ( - binDir string - testCert []tls.Certificate - listen string - dir string - evictionTime time.Duration - debug bool + binDir string + testCert []tls.Certificate + listen string = ":80" + dir string = "" + evictionTime time.Duration = time.Hour + debug bool = false + getLRUSize int = 10240 + getLimit time.Duration + getLimitBurst int64 = 10 + postLRUSize int = 128 + postLimit time.Duration + postLimitBurst int64 = 2 + + getMut sync.RWMutex = sync.NewRWMutex() + getLRUCache *lru.Cache + + postMut sync.RWMutex = sync.NewRWMutex() + postLRUCache *lru.Cache requests = make(chan request, 10) @@ -60,13 +75,27 @@ var ( ) func main() { - flag.StringVar(&listen, "listen", ":80", "Listen address") - flag.StringVar(&dir, "keys", "", "Directory where http-cert.pem and http-key.pem is stored for TLS listening") - flag.BoolVar(&debug, "debug", false, "Enable debug output") - flag.DurationVar(&evictionTime, "eviction", time.Hour, "After how long the relay is evicted") + var getLimitAvg, postLimitAvg int + + flag.StringVar(&listen, "listen", listen, "Listen address") + flag.StringVar(&dir, "keys", dir, "Directory where http-cert.pem and http-key.pem is stored for TLS listening") + flag.BoolVar(&debug, "debug", debug, "Enable debug output") + flag.DurationVar(&evictionTime, "eviction", evictionTime, "After how long the relay is evicted") + flag.IntVar(&getLRUSize, "get-limit-cache", getLRUSize, "Get request limiter cache size") + flag.IntVar(&getLimitAvg, "get-limit-avg", getLimitAvg, "Allowed average get request rate, per 10 s") + flag.Int64Var(&getLimitBurst, "get-limit-burst", getLimitBurst, "Allowed burst get requests") + flag.IntVar(&postLRUSize, "post-limit-cache", postLRUSize, "Post request limiter cache size") + flag.IntVar(&postLimitAvg, "post-limit-avg", postLimitAvg, "Allowed average post request rate, per minute") + flag.Int64Var(&postLimitBurst, "post-limit-burst", postLimitBurst, "Allowed burst post requests") flag.Parse() + getLimit = 10 * time.Second / time.Duration(getLimitAvg) + postLimit = time.Minute / time.Duration(postLimitAvg) + + getLRUCache = lru.New(getLRUSize) + postLRUCache = lru.New(postLRUSize) + var listener net.Listener var err error @@ -136,8 +165,16 @@ func main() { func handleRequest(w http.ResponseWriter, r *http.Request) { switch r.Method { case "GET": + if limit(r.RemoteAddr, getLRUCache, getMut, getLimit, int64(getLimitBurst)) { + w.WriteHeader(429) + return + } handleGetRequest(w, r) case "POST": + if limit(r.RemoteAddr, postLRUCache, postMut, postLimit, int64(postLimitBurst)) { + w.WriteHeader(429) + return + } handlePostRequest(w, r) default: if debug { @@ -195,18 +232,26 @@ func handlePostRequest(w http.ResponseWriter, r *http.Request) { return } - // The client did not provide an IP address, work it out. - if host == "" { - rhost, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - if debug { - log.Println("Failed to split remote address", r.RemoteAddr) - } - http.Error(w, err.Error(), 500) - return + // Get the IP address of the client + rhost, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + if debug { + log.Println("Failed to split remote address", r.RemoteAddr) } + http.Error(w, err.Error(), 500) + return + } + + // The client did not provide an IP address, use the IP address of the client. + if host == "" { uri.Host = net.JoinHostPort(rhost, port) newRelay.URL = uri.String() + } else if host != rhost { + if debug { + log.Println("IP address advertised does not match client IP address", r.RemoteAddr, uri) + } + http.Error(w, "IP address does not match client IP", http.StatusUnauthorized) + return } newRelay.uri = uri @@ -242,56 +287,6 @@ func handlePostRequest(w http.ResponseWriter, r *http.Request) { } } -func loadPermanentRelays() { - path, err := osext.ExecutableFolder() - if err != nil { - log.Println("Failed to locate executable directory") - return - } - - content, err := ioutil.ReadFile(filepath.Join(path, "relays")) - if err != nil { - return - } - - for _, line := range strings.Split(string(content), "\n") { - if len(line) == 0 { - continue - } - - uri, err := url.Parse(line) - if err != nil { - if debug { - log.Println("Skipping permanent relay", line, "due to parse error", err) - } - continue - } - - permanentRelays = append(permanentRelays, relay{ - URL: line, - uri: uri, - }) - if debug { - log.Println("Adding permanent relay", line) - } - } -} - -func loadOrCreateTestCertificate() { - certFile, keyFile := filepath.Join(binDir, "cert.pem"), filepath.Join(binDir, "key.pem") - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err == nil { - testCert = []tls.Certificate{cert} - return - } - - cert, err = tlsutil.NewCertificate(certFile, keyFile, "relaypoolsrv", 3072) - if err != nil { - log.Fatalln("Failed to create test X509 key pair:", err) - } - testCert = []tls.Certificate{cert} -} - func requestProcessor() { for request := range requests { if debug { @@ -356,3 +351,76 @@ func evict(relay relay) func() { delete(evictionTimers, relay.uri.Host) } } + +func limit(addr string, cache *lru.Cache, lock sync.RWMutex, rate time.Duration, burst int64) bool { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return false + } + + lock.RLock() + bkt, ok := cache.Get(host) + lock.RUnlock() + if ok { + bkt := bkt.(*ratelimit.Bucket) + if bkt.TakeAvailable(1) != 1 { + // Rate limit + return true + } + } else { + lock.Lock() + cache.Add(host, ratelimit.NewBucket(rate, burst)) + lock.Unlock() + } + return false +} + +func loadPermanentRelays() { + path, err := osext.ExecutableFolder() + if err != nil { + log.Println("Failed to locate executable directory") + return + } + + content, err := ioutil.ReadFile(filepath.Join(path, "relays")) + if err != nil { + return + } + + for _, line := range strings.Split(string(content), "\n") { + if len(line) == 0 { + continue + } + + uri, err := url.Parse(line) + if err != nil { + if debug { + log.Println("Skipping permanent relay", line, "due to parse error", err) + } + continue + } + + permanentRelays = append(permanentRelays, relay{ + URL: line, + uri: uri, + }) + if debug { + log.Println("Adding permanent relay", line) + } + } +} + +func loadOrCreateTestCertificate() { + certFile, keyFile := filepath.Join(binDir, "cert.pem"), filepath.Join(binDir, "key.pem") + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err == nil { + testCert = []tls.Certificate{cert} + return + } + + cert, err = tlsutil.NewCertificate(certFile, keyFile, "relaypoolsrv", 3072) + if err != nil { + log.Fatalln("Failed to create test X509 key pair:", err) + } + testCert = []tls.Certificate{cert} +}