cmd/ursrv: Refactor to use CLI options, fewer global vars

This commit is contained in:
Jakob Borg 2023-07-10 08:27:12 +02:00
parent b2886f11b1
commit bf61e485a6

View File

@ -25,6 +25,7 @@ import (
"time"
"unicode"
"github.com/alecthomas/kong"
_ "github.com/lib/pq" // PostgreSQL driver
"github.com/oschwald/geoip2-golang"
"golang.org/x/text/cases"
@ -34,14 +35,17 @@ import (
"github.com/syncthing/syncthing/lib/ur/contract"
)
type CLI struct {
UseHTTP bool `env:"UR_USE_HTTP"`
Debug bool `env:"UR_DEBUG"`
KeyFile string `env:"UR_KEY_FILE" default:"key.pem"`
CertFile string `env:"UR_CRT_FILE" default:"crt.pem"`
DBConn string `env:"UR_DB_URL" default:"postgres://user:password@localhost/ur?sslmode=disable"`
Listen string `env:"UR_LISTEN" default:"0.0.0.0:8443"`
GeoIPPath string `env:"UR_GEOIP" default:"GeoLite2-City.mmdb"`
}
var (
useHTTP = os.Getenv("UR_USE_HTTP") != ""
debug = os.Getenv("UR_DEBUG") != ""
keyFile = getEnvDefault("UR_KEY_FILE", "key.pem")
certFile = getEnvDefault("UR_CRT_FILE", "crt.pem")
dbConn = getEnvDefault("UR_DB_URL", "postgres://user:password@localhost/ur?sslmode=disable")
listenAddr = getEnvDefault("UR_LISTEN", "0.0.0.0:8443")
geoIPPath = getEnvDefault("UR_GEOIP", "GeoLite2-City.mmdb")
tpl *template.Template
compilerRe = regexp.MustCompile(`\(([A-Za-z0-9()., -]+) \w+-\w+(?:| android| default)\) ([\w@.-]+)`)
progressBarClass = []string{"", "progress-bar-success", "progress-bar-info", "progress-bar-warning", "progress-bar-danger"}
@ -159,6 +163,9 @@ func main() {
log.SetFlags(log.Ltime | log.Ldate | log.Lshortfile)
log.SetOutput(os.Stdout)
var cli CLI
kong.Parse(&cli)
// Template
fd, err := os.Open("static/index.html")
@ -174,7 +181,7 @@ func main() {
// DB
db, err := sql.Open("postgres", dbConn)
db, err := sql.Open("postgres", cli.DBConn)
if err != nil {
log.Fatalln("database:", err)
}
@ -186,11 +193,11 @@ func main() {
// TLS & Listening
var listener net.Listener
if useHTTP {
listener, err = net.Listen("tcp", listenAddr)
if cli.UseHTTP {
listener, err = net.Listen("tcp", cli.Listen)
} else {
var cert tls.Certificate
cert, err = tls.LoadX509KeyPair(certFile, keyFile)
cert, err = tls.LoadX509KeyPair(cli.CertFile, cli.KeyFile)
if err != nil {
log.Fatalln("tls:", err)
}
@ -199,81 +206,89 @@ func main() {
Certificates: []tls.Certificate{cert},
SessionTicketsDisabled: true,
}
listener, err = tls.Listen("tcp", listenAddr, cfg)
listener, err = tls.Listen("tcp", cli.Listen, cfg)
}
if err != nil {
log.Fatalln("listen:", err)
}
srv := http.Server{
srv := &server{
db: db,
debug: cli.Debug,
geoIPPath: cli.GeoIPPath,
}
http.HandleFunc("/", srv.rootHandler)
http.HandleFunc("/newdata", srv.newDataHandler)
http.HandleFunc("/summary.json", srv.summaryHandler)
http.HandleFunc("/movement.json", srv.movementHandler)
http.HandleFunc("/performance.json", srv.performanceHandler)
http.HandleFunc("/blockstats.json", srv.blockStatsHandler)
http.HandleFunc("/locations.json", srv.locationsHandler)
http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
go srv.cacheRefresher()
httpSrv := http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 15 * time.Second,
}
http.HandleFunc("/", withDB(db, rootHandler))
http.HandleFunc("/newdata", withDB(db, newDataHandler))
http.HandleFunc("/summary.json", withDB(db, summaryHandler))
http.HandleFunc("/movement.json", withDB(db, movementHandler))
http.HandleFunc("/performance.json", withDB(db, performanceHandler))
http.HandleFunc("/blockstats.json", withDB(db, blockStatsHandler))
http.HandleFunc("/locations.json", withDB(db, locationsHandler))
http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
go cacheRefresher(db)
err = srv.Serve(listener)
err = httpSrv.Serve(listener)
if err != nil {
log.Fatalln("https:", err)
}
}
var (
type server struct {
debug bool
db *sql.DB
geoIPPath string
cacheMut sync.Mutex
cachedIndex []byte
cachedLocations []byte
cacheTime time.Time
cacheMut sync.Mutex
)
}
const maxCacheTime = 15 * time.Minute
func cacheRefresher(db *sql.DB) {
func (s *server) cacheRefresher() {
ticker := time.NewTicker(maxCacheTime - time.Minute)
defer ticker.Stop()
for ; true; <-ticker.C {
cacheMut.Lock()
if err := refreshCacheLocked(db); err != nil {
s.cacheMut.Lock()
if err := s.refreshCacheLocked(); err != nil {
log.Println(err)
}
cacheMut.Unlock()
s.cacheMut.Unlock()
}
}
func refreshCacheLocked(db *sql.DB) error {
rep := getReport(db)
func (s *server) refreshCacheLocked() error {
rep := getReport(s.db, s.geoIPPath)
buf := new(bytes.Buffer)
err := tpl.Execute(buf, rep)
if err != nil {
return err
}
cachedIndex = buf.Bytes()
cacheTime = time.Now()
s.cachedIndex = buf.Bytes()
s.cacheTime = time.Now()
locs := rep["locations"].(map[location]int)
wlocs := make([]weightedLocation, 0, len(locs))
for loc, w := range locs {
wlocs = append(wlocs, weightedLocation{loc, w})
}
cachedLocations, _ = json.Marshal(wlocs)
s.cachedLocations, _ = json.Marshal(wlocs)
return nil
}
func rootHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
func (s *server) rootHandler(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
cacheMut.Lock()
defer cacheMut.Unlock()
s.cacheMut.Lock()
defer s.cacheMut.Unlock()
if time.Since(cacheTime) > maxCacheTime {
if err := refreshCacheLocked(db); err != nil {
if time.Since(s.cacheTime) > maxCacheTime {
if err := s.refreshCacheLocked(); err != nil {
log.Println(err)
http.Error(w, "Template Error", http.StatusInternalServerError)
return
@ -281,19 +296,19 @@ func rootHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Write(cachedIndex)
w.Write(s.cachedIndex)
} else {
http.Error(w, "Not found", 404)
return
}
}
func locationsHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
cacheMut.Lock()
defer cacheMut.Unlock()
func (s *server) locationsHandler(w http.ResponseWriter, _ *http.Request) {
s.cacheMut.Lock()
defer s.cacheMut.Unlock()
if time.Since(cacheTime) > maxCacheTime {
if err := refreshCacheLocked(db); err != nil {
if time.Since(s.cacheTime) > maxCacheTime {
if err := s.refreshCacheLocked(); err != nil {
log.Println(err)
http.Error(w, "Template Error", http.StatusInternalServerError)
return
@ -301,10 +316,10 @@ func locationsHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(cachedLocations)
w.Write(s.cachedLocations)
}
func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
func (s *server) newDataHandler(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
addr := r.Header.Get("X-Forwarded-For")
@ -330,7 +345,7 @@ func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
bs, _ := io.ReadAll(lr)
if err := json.Unmarshal(bs, &rep); err != nil {
log.Println("decode:", err)
if debug {
if s.debug {
log.Printf("%s", bs)
}
http.Error(w, "JSON Decode Error", http.StatusInternalServerError)
@ -339,21 +354,21 @@ func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
if err := rep.Validate(); err != nil {
log.Println("validate:", err)
if debug {
if s.debug {
log.Printf("%#v", rep)
}
http.Error(w, "Validation Error", http.StatusInternalServerError)
return
}
if err := insertReport(db, rep); err != nil {
if err := insertReport(s.db, rep); err != nil {
if err.Error() == `pq: duplicate key value violates unique constraint "uniqueidjsonindex"` {
// We already have a report today for the same unique ID; drop
// this one without complaining.
return
}
log.Println("insert:", err)
if debug {
if s.debug {
log.Printf("%#v", rep)
}
http.Error(w, "Database Error", http.StatusInternalServerError)
@ -361,16 +376,16 @@ func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
}
}
func summaryHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
func (s *server) summaryHandler(w http.ResponseWriter, r *http.Request) {
min, _ := strconv.Atoi(r.URL.Query().Get("min"))
s, err := getSummary(db, min)
sum, err := getSummary(s.db, min)
if err != nil {
log.Println("summaryHandler:", err)
http.Error(w, "Database Error", http.StatusInternalServerError)
return
}
bs, err := s.MarshalJSON()
bs, err := sum.MarshalJSON()
if err != nil {
log.Println("summaryHandler:", err)
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@ -381,15 +396,15 @@ func summaryHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
w.Write(bs)
}
func movementHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
s, err := getMovement(db)
func (s *server) movementHandler(w http.ResponseWriter, _ *http.Request) {
mov, err := getMovement(s.db)
if err != nil {
log.Println("movementHandler:", err)
http.Error(w, "Database Error", http.StatusInternalServerError)
return
}
bs, err := json.Marshal(s)
bs, err := json.Marshal(mov)
if err != nil {
log.Println("movementHandler:", err)
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@ -400,15 +415,15 @@ func movementHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
w.Write(bs)
}
func performanceHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
s, err := getPerformance(db)
func (s *server) performanceHandler(w http.ResponseWriter, _ *http.Request) {
perf, err := getPerformance(s.db)
if err != nil {
log.Println("performanceHandler:", err)
http.Error(w, "Database Error", http.StatusInternalServerError)
return
}
bs, err := json.Marshal(s)
bs, err := json.Marshal(perf)
if err != nil {
log.Println("performanceHandler:", err)
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@ -419,15 +434,15 @@ func performanceHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
w.Write(bs)
}
func blockStatsHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
s, err := getBlockStats(db)
func (s *server) blockStatsHandler(w http.ResponseWriter, _ *http.Request) {
blocks, err := getBlockStats(s.db)
if err != nil {
log.Println("blockStatsHandler:", err)
http.Error(w, "Database Error", http.StatusInternalServerError)
return
}
bs, err := json.Marshal(s)
bs, err := json.Marshal(blocks)
if err != nil {
log.Println("blockStatsHandler:", err)
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@ -513,7 +528,7 @@ type weightedLocation struct {
Weight int `json:"weight"`
}
func getReport(db *sql.DB) map[string]interface{} {
func getReport(db *sql.DB, geoIPPath string) map[string]interface{} {
geoip, err := geoip2.Open(geoIPPath)
if err != nil {
log.Println("opening geoip db", err)