From bf61e485a6e3b51b87b97c183325302696508241 Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Mon, 10 Jul 2023 08:27:12 +0200 Subject: [PATCH] cmd/ursrv: Refactor to use CLI options, fewer global vars --- cmd/ursrv/main.go | 151 +++++++++++++++++++++++++--------------------- 1 file changed, 83 insertions(+), 68 deletions(-) diff --git a/cmd/ursrv/main.go b/cmd/ursrv/main.go index c69559464..76014d268 100644 --- a/cmd/ursrv/main.go +++ b/cmd/ursrv/main.go @@ -25,6 +25,7 @@ import ( "time" "unicode" + "github.com/alecthomas/kong" _ "github.com/lib/pq" // PostgreSQL driver "github.com/oschwald/geoip2-golang" "golang.org/x/text/cases" @@ -34,14 +35,17 @@ import ( "github.com/syncthing/syncthing/lib/ur/contract" ) +type CLI struct { + UseHTTP bool `env:"UR_USE_HTTP"` + Debug bool `env:"UR_DEBUG"` + KeyFile string `env:"UR_KEY_FILE" default:"key.pem"` + CertFile string `env:"UR_CRT_FILE" default:"crt.pem"` + DBConn string `env:"UR_DB_URL" default:"postgres://user:password@localhost/ur?sslmode=disable"` + Listen string `env:"UR_LISTEN" default:"0.0.0.0:8443"` + GeoIPPath string `env:"UR_GEOIP" default:"GeoLite2-City.mmdb"` +} + var ( - useHTTP = os.Getenv("UR_USE_HTTP") != "" - debug = os.Getenv("UR_DEBUG") != "" - keyFile = getEnvDefault("UR_KEY_FILE", "key.pem") - certFile = getEnvDefault("UR_CRT_FILE", "crt.pem") - dbConn = getEnvDefault("UR_DB_URL", "postgres://user:password@localhost/ur?sslmode=disable") - listenAddr = getEnvDefault("UR_LISTEN", "0.0.0.0:8443") - geoIPPath = getEnvDefault("UR_GEOIP", "GeoLite2-City.mmdb") tpl *template.Template compilerRe = regexp.MustCompile(`\(([A-Za-z0-9()., -]+) \w+-\w+(?:| android| default)\) ([\w@.-]+)`) progressBarClass = []string{"", "progress-bar-success", "progress-bar-info", "progress-bar-warning", "progress-bar-danger"} @@ -159,6 +163,9 @@ func main() { log.SetFlags(log.Ltime | log.Ldate | log.Lshortfile) log.SetOutput(os.Stdout) + var cli CLI + kong.Parse(&cli) + // Template fd, err := os.Open("static/index.html") @@ -174,7 +181,7 @@ func main() { // DB - db, err := sql.Open("postgres", dbConn) + db, err := sql.Open("postgres", cli.DBConn) if err != nil { log.Fatalln("database:", err) } @@ -186,11 +193,11 @@ func main() { // TLS & Listening var listener net.Listener - if useHTTP { - listener, err = net.Listen("tcp", listenAddr) + if cli.UseHTTP { + listener, err = net.Listen("tcp", cli.Listen) } else { var cert tls.Certificate - cert, err = tls.LoadX509KeyPair(certFile, keyFile) + cert, err = tls.LoadX509KeyPair(cli.CertFile, cli.KeyFile) if err != nil { log.Fatalln("tls:", err) } @@ -199,81 +206,89 @@ func main() { Certificates: []tls.Certificate{cert}, SessionTicketsDisabled: true, } - listener, err = tls.Listen("tcp", listenAddr, cfg) + listener, err = tls.Listen("tcp", cli.Listen, cfg) } if err != nil { log.Fatalln("listen:", err) } - srv := http.Server{ + srv := &server{ + db: db, + debug: cli.Debug, + geoIPPath: cli.GeoIPPath, + } + http.HandleFunc("/", srv.rootHandler) + http.HandleFunc("/newdata", srv.newDataHandler) + http.HandleFunc("/summary.json", srv.summaryHandler) + http.HandleFunc("/movement.json", srv.movementHandler) + http.HandleFunc("/performance.json", srv.performanceHandler) + http.HandleFunc("/blockstats.json", srv.blockStatsHandler) + http.HandleFunc("/locations.json", srv.locationsHandler) + http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) + + go srv.cacheRefresher() + + httpSrv := http.Server{ ReadTimeout: 5 * time.Second, WriteTimeout: 15 * time.Second, } - - http.HandleFunc("/", withDB(db, rootHandler)) - http.HandleFunc("/newdata", withDB(db, newDataHandler)) - http.HandleFunc("/summary.json", withDB(db, summaryHandler)) - http.HandleFunc("/movement.json", withDB(db, movementHandler)) - http.HandleFunc("/performance.json", withDB(db, performanceHandler)) - http.HandleFunc("/blockstats.json", withDB(db, blockStatsHandler)) - http.HandleFunc("/locations.json", withDB(db, locationsHandler)) - http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) - - go cacheRefresher(db) - - err = srv.Serve(listener) + err = httpSrv.Serve(listener) if err != nil { log.Fatalln("https:", err) } } -var ( +type server struct { + debug bool + db *sql.DB + geoIPPath string + + cacheMut sync.Mutex cachedIndex []byte cachedLocations []byte cacheTime time.Time - cacheMut sync.Mutex -) +} const maxCacheTime = 15 * time.Minute -func cacheRefresher(db *sql.DB) { +func (s *server) cacheRefresher() { ticker := time.NewTicker(maxCacheTime - time.Minute) defer ticker.Stop() for ; true; <-ticker.C { - cacheMut.Lock() - if err := refreshCacheLocked(db); err != nil { + s.cacheMut.Lock() + if err := s.refreshCacheLocked(); err != nil { log.Println(err) } - cacheMut.Unlock() + s.cacheMut.Unlock() } } -func refreshCacheLocked(db *sql.DB) error { - rep := getReport(db) +func (s *server) refreshCacheLocked() error { + rep := getReport(s.db, s.geoIPPath) buf := new(bytes.Buffer) err := tpl.Execute(buf, rep) if err != nil { return err } - cachedIndex = buf.Bytes() - cacheTime = time.Now() + s.cachedIndex = buf.Bytes() + s.cacheTime = time.Now() locs := rep["locations"].(map[location]int) wlocs := make([]weightedLocation, 0, len(locs)) for loc, w := range locs { wlocs = append(wlocs, weightedLocation{loc, w}) } - cachedLocations, _ = json.Marshal(wlocs) + s.cachedLocations, _ = json.Marshal(wlocs) return nil } -func rootHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) { +func (s *server) rootHandler(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/" || r.URL.Path == "/index.html" { - cacheMut.Lock() - defer cacheMut.Unlock() + s.cacheMut.Lock() + defer s.cacheMut.Unlock() - if time.Since(cacheTime) > maxCacheTime { - if err := refreshCacheLocked(db); err != nil { + if time.Since(s.cacheTime) > maxCacheTime { + if err := s.refreshCacheLocked(); err != nil { log.Println(err) http.Error(w, "Template Error", http.StatusInternalServerError) return @@ -281,19 +296,19 @@ func rootHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Write(cachedIndex) + w.Write(s.cachedIndex) } else { http.Error(w, "Not found", 404) return } } -func locationsHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) { - cacheMut.Lock() - defer cacheMut.Unlock() +func (s *server) locationsHandler(w http.ResponseWriter, _ *http.Request) { + s.cacheMut.Lock() + defer s.cacheMut.Unlock() - if time.Since(cacheTime) > maxCacheTime { - if err := refreshCacheLocked(db); err != nil { + if time.Since(s.cacheTime) > maxCacheTime { + if err := s.refreshCacheLocked(); err != nil { log.Println(err) http.Error(w, "Template Error", http.StatusInternalServerError) return @@ -301,10 +316,10 @@ func locationsHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) { } w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write(cachedLocations) + w.Write(s.cachedLocations) } -func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) { +func (s *server) newDataHandler(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() addr := r.Header.Get("X-Forwarded-For") @@ -330,7 +345,7 @@ func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) { bs, _ := io.ReadAll(lr) if err := json.Unmarshal(bs, &rep); err != nil { log.Println("decode:", err) - if debug { + if s.debug { log.Printf("%s", bs) } http.Error(w, "JSON Decode Error", http.StatusInternalServerError) @@ -339,21 +354,21 @@ func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) { if err := rep.Validate(); err != nil { log.Println("validate:", err) - if debug { + if s.debug { log.Printf("%#v", rep) } http.Error(w, "Validation Error", http.StatusInternalServerError) return } - if err := insertReport(db, rep); err != nil { + if err := insertReport(s.db, rep); err != nil { if err.Error() == `pq: duplicate key value violates unique constraint "uniqueidjsonindex"` { // We already have a report today for the same unique ID; drop // this one without complaining. return } log.Println("insert:", err) - if debug { + if s.debug { log.Printf("%#v", rep) } http.Error(w, "Database Error", http.StatusInternalServerError) @@ -361,16 +376,16 @@ func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) { } } -func summaryHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) { +func (s *server) summaryHandler(w http.ResponseWriter, r *http.Request) { min, _ := strconv.Atoi(r.URL.Query().Get("min")) - s, err := getSummary(db, min) + sum, err := getSummary(s.db, min) if err != nil { log.Println("summaryHandler:", err) http.Error(w, "Database Error", http.StatusInternalServerError) return } - bs, err := s.MarshalJSON() + bs, err := sum.MarshalJSON() if err != nil { log.Println("summaryHandler:", err) http.Error(w, "JSON Encode Error", http.StatusInternalServerError) @@ -381,15 +396,15 @@ func summaryHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) { w.Write(bs) } -func movementHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) { - s, err := getMovement(db) +func (s *server) movementHandler(w http.ResponseWriter, _ *http.Request) { + mov, err := getMovement(s.db) if err != nil { log.Println("movementHandler:", err) http.Error(w, "Database Error", http.StatusInternalServerError) return } - bs, err := json.Marshal(s) + bs, err := json.Marshal(mov) if err != nil { log.Println("movementHandler:", err) http.Error(w, "JSON Encode Error", http.StatusInternalServerError) @@ -400,15 +415,15 @@ func movementHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) { w.Write(bs) } -func performanceHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) { - s, err := getPerformance(db) +func (s *server) performanceHandler(w http.ResponseWriter, _ *http.Request) { + perf, err := getPerformance(s.db) if err != nil { log.Println("performanceHandler:", err) http.Error(w, "Database Error", http.StatusInternalServerError) return } - bs, err := json.Marshal(s) + bs, err := json.Marshal(perf) if err != nil { log.Println("performanceHandler:", err) http.Error(w, "JSON Encode Error", http.StatusInternalServerError) @@ -419,15 +434,15 @@ func performanceHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) { w.Write(bs) } -func blockStatsHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) { - s, err := getBlockStats(db) +func (s *server) blockStatsHandler(w http.ResponseWriter, _ *http.Request) { + blocks, err := getBlockStats(s.db) if err != nil { log.Println("blockStatsHandler:", err) http.Error(w, "Database Error", http.StatusInternalServerError) return } - bs, err := json.Marshal(s) + bs, err := json.Marshal(blocks) if err != nil { log.Println("blockStatsHandler:", err) http.Error(w, "JSON Encode Error", http.StatusInternalServerError) @@ -513,7 +528,7 @@ type weightedLocation struct { Weight int `json:"weight"` } -func getReport(db *sql.DB) map[string]interface{} { +func getReport(db *sql.DB, geoIPPath string) map[string]interface{} { geoip, err := geoip2.Open(geoIPPath) if err != nil { log.Println("opening geoip db", err)