cmd/strelaypoolsrv: More compact response, improved metrics

This commit is contained in:
Jakob Borg 2024-06-03 07:14:45 +02:00
parent f283215fce
commit 18a58a2ddc
4 changed files with 106 additions and 58 deletions

View File

@ -259,7 +259,7 @@
return a.value > b.value ? 1 : -1;
}
$http.get("/endpoint").then(function(response) {
$http.get("/endpoint/full").then(function(response) {
$scope.relays = response.data.relays;
angular.forEach($scope.relays, function(relay) {

View File

@ -27,9 +27,7 @@ import (
"github.com/syncthing/syncthing/lib/assets"
_ "github.com/syncthing/syncthing/lib/automaxprocs"
"github.com/syncthing/syncthing/lib/geoip"
"github.com/syncthing/syncthing/lib/httpcache"
"github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/rand"
"github.com/syncthing/syncthing/lib/relay/client"
"github.com/syncthing/syncthing/lib/sync"
"github.com/syncthing/syncthing/lib/tlsutil"
@ -51,6 +49,10 @@ type relay struct {
StatsRetrieved time.Time `json:"statsRetrieved"`
}
type relayShort struct {
URL string `json:"url"`
}
type stats struct {
StartTime time.Time `json:"startTime"`
UptimeSeconds int `json:"uptimeSeconds"`
@ -95,6 +97,7 @@ var (
testCert tls.Certificate
knownRelaysFile = filepath.Join(os.TempDir(), "strelaypoolsrv_known_relays")
listen = ":80"
metricsListen = ":8081"
dir string
evictionTime = time.Hour
debug bool
@ -125,6 +128,7 @@ func main() {
log.SetFlags(log.Lshortfile)
flag.StringVar(&listen, "listen", listen, "Listen address")
flag.StringVar(&metricsListen, "metrics-listen", metricsListen, "Metrics listen address")
flag.StringVar(&dir, "keys", dir, "Directory where http-cert.pem and http-key.pem is stored for TLS listening")
flag.BoolVar(&debug, "debug", debug, "Enable debug output")
flag.DurationVar(&evictionTime, "eviction", evictionTime, "After how long the relay is evicted")
@ -218,15 +222,40 @@ func main() {
log.Fatalln("listen:", err)
}
handler := http.NewServeMux()
handler.HandleFunc("/", handleAssets)
handler.Handle("/endpoint", httpcache.SinglePath(http.HandlerFunc(handleRequest), 15*time.Second))
handler.HandleFunc("/metrics", handleMetrics)
if metricsListen != "" {
mmux := http.NewServeMux()
mmux.HandleFunc("/metrics", handleMetrics)
go func() {
if err := http.ListenAndServe(metricsListen, mmux); err != nil {
log.Fatalln("HTTP serve metrics:", err)
}
}()
}
getMux := http.NewServeMux()
getMux.HandleFunc("/", handleAssets)
getMux.HandleFunc("/endpoint", withAPIMetrics(handleEndpointShort))
getMux.HandleFunc("/endpoint/full", withAPIMetrics(handleEndpointFull))
postMux := http.NewServeMux()
postMux.HandleFunc("/endpoint", withAPIMetrics(handleRegister))
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet, http.MethodHead, http.MethodOptions:
getMux.ServeHTTP(w, r)
case http.MethodPost:
postMux.ServeHTTP(w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
srv := http.Server{
Handler: handler,
ReadTimeout: 10 * time.Second,
}
srv.SetKeepAlivesEnabled(false)
err = srv.Serve(listener)
if err != nil {
@ -260,39 +289,24 @@ func handleAssets(w http.ResponseWriter, r *http.Request) {
assets.Serve(w, r, as)
}
func handleRequest(w http.ResponseWriter, r *http.Request) {
func withAPIMetrics(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
timer := prometheus.NewTimer(apiRequestsSeconds.WithLabelValues(r.Method))
w = NewLoggingResponseWriter(w)
defer func() {
timer.ObserveDuration()
lw := w.(*loggingResponseWriter)
apiRequestsTotal.WithLabelValues(r.Method, strconv.Itoa(lw.statusCode)).Inc()
}()
if ipHeader != "" {
hdr := r.Header.Get(ipHeader)
fields := strings.Split(hdr, ",")
if len(fields) > 0 {
r.RemoteAddr = strings.TrimSpace(fields[len(fields)-1])
}
}
w.Header().Set("Access-Control-Allow-Origin", "*")
switch r.Method {
case "GET":
handleGetRequest(w, r)
case "POST":
handlePostRequest(w, r)
default:
if debug {
log.Println("Unhandled HTTP method", r.Method)
}
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
next(w, r)
}
}
func handleGetRequest(rw http.ResponseWriter, r *http.Request) {
// handleEndpointFull returns the relay list with full metadata and
// statistics. Large, and expensive.
func handleEndpointFull(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json; charset=utf-8")
rw.Header().Set("Access-Control-Allow-Origin", "*")
mut.RLock()
relays := make([]*relay, len(permanentRelays)+len(knownRelays))
@ -300,17 +314,38 @@ func handleGetRequest(rw http.ResponseWriter, r *http.Request) {
copy(relays[n:], knownRelays)
mut.RUnlock()
// Shuffle
rand.Shuffle(relays)
_ = json.NewEncoder(rw).Encode(map[string][]*relay{
"relays": relays,
})
}
func handlePostRequest(w http.ResponseWriter, r *http.Request) {
// handleEndpointShort returns the relay list with only the URL.
func handleEndpointShort(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json; charset=utf-8")
rw.Header().Set("Access-Control-Allow-Origin", "*")
mut.RLock()
relays := make([]relayShort, 0, len(permanentRelays)+len(knownRelays))
for _, r := range append(permanentRelays, knownRelays...) {
relays = append(relays, relayShort{URL: slimURL(r.URL)})
}
mut.RUnlock()
_ = json.NewEncoder(rw).Encode(map[string][]relayShort{
"relays": relays,
})
}
func handleRegister(w http.ResponseWriter, r *http.Request) {
// Get the IP address of the client
rhost := r.RemoteAddr
if ipHeader != "" {
hdr := r.Header.Get(ipHeader)
fields := strings.Split(hdr, ",")
if len(fields) > 0 {
rhost = strings.TrimSpace(fields[len(fields)-1])
}
}
if host, _, err := net.SplitHostPort(rhost); err == nil {
rhost = host
}
@ -660,3 +695,16 @@ func (b *errorTracker) IsBlocked(host string) bool {
}
return false
}
func slimURL(u string) string {
p, err := url.Parse(u)
if err != nil {
return u
}
newQuery := url.Values{}
if id := p.Query().Get("id"); id != "" {
newQuery.Set("id", id)
}
p.RawQuery = newQuery.Encode()
return p.String()
}

View File

@ -42,7 +42,7 @@ func TestHandleGetRequest(t *testing.T) {
w := httptest.NewRecorder()
w.Body = new(bytes.Buffer)
handleGetRequest(w, httptest.NewRequest("GET", "/", nil))
handleEndpointFull(w, httptest.NewRequest("GET", "/", nil))
result := make(map[string][]*relay)
err := json.NewDecoder(w.Body).Decode(&result)
@ -92,3 +92,18 @@ func TestCanonicalizeQueryValues(t *testing.T) {
t.Errorf("expected %q, got %q", exp, str)
}
}
func TestSlimURL(t *testing.T) {
cases := []struct {
in, out string
}{
{"http://example.com/", "http://example.com/"},
{"relay://192.0.2.42:22067/?globalLimitBps=0&id=EIC6B3M-EIC6B3M-EIC6B3M-EIC6B3M-EIC6B3M-EIC6B3M-EIC6B3M-EIC6B3M&networkTimeout=2m0s&pingInterval=1m0s&providedBy=Test&sessionLimitBps=0&statusAddr=%3A22070", "relay://192.0.2.42:22067/?id=EIC6B3M-EIC6B3M-EIC6B3M-EIC6B3M-EIC6B3M-EIC6B3M-EIC6B3M-EIC6B3M"},
}
for _, c := range cases {
if got := slimURL(c.in); got != c.out {
t.Errorf("expected %q, got %q", c.out, got)
}
}
}

View File

@ -6,27 +6,12 @@ import (
"encoding/json"
"net"
"net/http"
"os"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/syncthing/syncthing/lib/sync"
)
func init() {
processCollectorOpts := collectors.ProcessCollectorOpts{
Namespace: "syncthing_relaypoolsrv",
PidFn: func() (int, error) {
return os.Getpid(), nil
},
}
prometheus.MustRegister(
collectors.NewProcessCollector(processCollectorOpts),
)
}
var (
statusClient = http.Client{
Timeout: 5 * time.Second,