chore(stdiscosrv): clean up s3 handling

This commit is contained in:
Jakob Borg 2024-09-11 11:31:09 +02:00
parent 63e4659282
commit 6505e123bb
No known key found for this signature in database
5 changed files with 164 additions and 92 deletions

View File

@ -107,7 +107,7 @@ func addr(host string, port int) *net.TCPAddr {
} }
func BenchmarkAPIRequests(b *testing.B) { func BenchmarkAPIRequests(b *testing.B) {
db := newInMemoryStore(b.TempDir(), 0) db := newInMemoryStore(b.TempDir(), 0, nil)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
go db.Serve(ctx) go db.Serve(ctx)

View File

@ -24,10 +24,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/protocol"
) )
@ -52,25 +48,27 @@ type inMemoryStore struct {
m *xsync.MapOf[protocol.DeviceID, DatabaseRecord] m *xsync.MapOf[protocol.DeviceID, DatabaseRecord]
dir string dir string
flushInterval time.Duration flushInterval time.Duration
s3 *s3Copier
clock clock clock clock
} }
func newInMemoryStore(dir string, flushInterval time.Duration) *inMemoryStore { func newInMemoryStore(dir string, flushInterval time.Duration, s3 *s3Copier) *inMemoryStore {
s := &inMemoryStore{ s := &inMemoryStore{
m: xsync.NewMapOf[protocol.DeviceID, DatabaseRecord](), m: xsync.NewMapOf[protocol.DeviceID, DatabaseRecord](),
dir: dir, dir: dir,
flushInterval: flushInterval, flushInterval: flushInterval,
s3: s3,
clock: defaultClock{}, clock: defaultClock{},
} }
nr, err := s.read() nr, err := s.read()
if os.IsNotExist(err) { if os.IsNotExist(err) && s3 != nil {
// Try to read from AWS // Try to read from AWS
fd, cerr := os.Create(path.Join(s.dir, "records.db")) fd, cerr := os.Create(path.Join(s.dir, "records.db"))
if cerr != nil { if cerr != nil {
log.Println("Error creating database file:", err) log.Println("Error creating database file:", err)
return s return s
} }
if err := s3Download(fd); err != nil { if err := s3.downloadLatest(fd); err != nil {
log.Printf("Error reading database from S3: %v", err) log.Printf("Error reading database from S3: %v", err)
} }
_ = fd.Close() _ = fd.Close()
@ -278,16 +276,15 @@ func (s *inMemoryStore) write() (err error) {
return err return err
} }
if os.Getenv("PODINDEX") == "0" { // Upload to S3
// Upload to S3 if s.s3 != nil {
log.Println("Uploading database")
fd, err = os.Open(dbf) fd, err = os.Open(dbf)
if err != nil { if err != nil {
log.Printf("Error uploading database to S3: %v", err) log.Printf("Error uploading database to S3: %v", err)
return nil return nil
} }
defer fd.Close() defer fd.Close()
if err := s3Upload(fd); err != nil { if err := s.s3.upload(fd); err != nil {
log.Printf("Error uploading database to S3: %v", err) log.Printf("Error uploading database to S3: %v", err)
} }
log.Println("Finished uploading database") log.Println("Finished uploading database")
@ -424,39 +421,6 @@ func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress {
return naddrs return naddrs
} }
func s3Upload(r io.Reader) error {
sess, err := session.NewSession(&aws.Config{
Region: aws.String("fr-par"),
Endpoint: aws.String("s3.fr-par.scw.cloud"),
})
if err != nil {
return err
}
uploader := s3manager.NewUploader(sess)
_, err = uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String("syncthing-discovery"),
Key: aws.String("discovery.db"),
Body: r,
})
return err
}
func s3Download(w io.WriterAt) error {
sess, err := session.NewSession(&aws.Config{
Region: aws.String("fr-par"),
Endpoint: aws.String("s3.fr-par.scw.cloud"),
})
if err != nil {
return err
}
downloader := s3manager.NewDownloader(sess)
_, err = downloader.Download(w, &s3.GetObjectInput{
Bucket: aws.String("syncthing-discovery"),
Key: aws.String("discovery.db"),
})
return err
}
func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) { func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) {
if c := cmp.Compare(d.Address, other.Address); c != 0 { if c := cmp.Compare(d.Address, other.Address); c != 0 {
return c return c

View File

@ -16,7 +16,7 @@ import (
) )
func TestDatabaseGetSet(t *testing.T) { func TestDatabaseGetSet(t *testing.T) {
db := newInMemoryStore(t.TempDir(), 0) db := newInMemoryStore(t.TempDir(), 0, nil)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go db.Serve(ctx) go db.Serve(ctx)
defer cancel() defer cancel()

View File

@ -9,7 +9,6 @@ package main
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"flag"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -21,6 +20,7 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"github.com/alecthomas/kong"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
_ "github.com/syncthing/syncthing/lib/automaxprocs" _ "github.com/syncthing/syncthing/lib/automaxprocs"
"github.com/syncthing/syncthing/lib/build" "github.com/syncthing/syncthing/lib/build"
@ -58,52 +58,52 @@ const (
var debug = false var debug = false
type CLI struct {
Cert string `group:"Listen" help:"Certificate file" default:"./cert.pem" env:"DISCOVERY_CERT_FILE"`
Key string `group:"Listen" help:"Key file" default:"./key.pem" env:"DISCOVERY_KEY_FILE"`
HTTP bool `group:"Listen" help:"Listen on HTTP (behind an HTTPS proxy)" env:"DISCOVERY_HTTP"`
Compression bool `group:"Listen" help:"Enable GZIP compression of responses" env:"DISCOVERY_COMPRESSION"`
Listen string `group:"Listen" help:"Listen address" default:":8443" env:"DISCOVERY_LISTEN"`
MetricsListen string `group:"Listen" help:"Metrics listen address" env:"DISCOVERY_METRICS_LISTEN"`
Replicate []string `group:"Legacy replication" help:"Replication peers, id@address, comma separated" env:"DISCOVERY_REPLICATE"`
ReplicationListen string `group:"Legacy replication" help:"Replication listen address" default:":19200" env:"DISCOVERY_REPLICATION_LISTEN"`
ReplicationCert string `group:"Legacy replication" help:"Certificate file for replication" env:"DISCOVERY_REPLICATION_CERT_FILE"`
ReplicationKey string `group:"Legacy replication" help:"Key file for replication" env:"DISCOVERY_REPLICATION_KEY_FILE"`
AMQPAddress string `group:"AMQP replication" help:"Address to AMQP broker" env:"DISCOVERY_AMQP_ADDRESS"`
DBDir string `group:"Database" help:"Database directory" default:"." env:"DISCOVERY_DB_DIR"`
DBFlushInterval time.Duration `group:"Database" help:"Interval between database flushes" default:"5m" env:"DISCOVERY_DB_FLUSH_INTERVAL"`
DBS3Endpoint string `name:"db-s3-endpoint" group:"Database (S3 backup)" help:"S3 endpoint for database" env:"DISCOVERY_DB_S3_ENDPOINT"`
DBS3Region string `name:"db-s3-region" group:"Database (S3 backup)" help:"S3 region for database" env:"DISCOVERY_DB_S3_REGION"`
DBS3Bucket string `name:"db-s3-bucket" group:"Database (S3 backup)" help:"S3 bucket for database" env:"DISCOVERY_DB_S3_BUCKET"`
DBS3AccessKeyID string `name:"db-s3-access-key-id" group:"Database (S3 backup)" help:"S3 access key ID for database" env:"DISCOVERY_DB_S3_ACCESS_KEY_ID"`
DBS3SecretKey string `name:"db-s3-secret-key" group:"Database (S3 backup)" help:"S3 secret key for database" env:"DISCOVERY_DB_S3_SECRET_KEY"`
Debug bool `short:"d" help:"Print debug output" env:"DISCOVERY_DEBUG"`
Version bool `short:"v" help:"Print version and exit"`
}
func main() { func main() {
var listen string
var dir string
var metricsListen string
var replicationListen string
var replicationPeers string
var certFile string
var keyFile string
var replCertFile string
var replKeyFile string
var useHTTP bool
var compression bool
var amqpAddress string
var flushInterval time.Duration
log.SetOutput(os.Stdout) log.SetOutput(os.Stdout)
// log.SetFlags(0)
flag.StringVar(&certFile, "cert", "./cert.pem", "Certificate file") var cli CLI
flag.StringVar(&keyFile, "key", "./key.pem", "Key file") kong.Parse(&cli)
flag.StringVar(&dir, "db-dir", ".", "Database directory") debug = cli.Debug
flag.BoolVar(&debug, "debug", false, "Print debug output")
flag.BoolVar(&useHTTP, "http", false, "Listen on HTTP (behind an HTTPS proxy)")
flag.BoolVar(&compression, "compression", true, "Enable GZIP compression of responses")
flag.StringVar(&listen, "listen", ":8443", "Listen address")
flag.StringVar(&metricsListen, "metrics-listen", "", "Metrics listen address")
flag.StringVar(&replicationPeers, "replicate", "", "Replication peers, id@address, comma separated")
flag.StringVar(&replicationListen, "replication-listen", ":19200", "Replication listen address")
flag.StringVar(&replCertFile, "replication-cert", "", "Certificate file for replication")
flag.StringVar(&replKeyFile, "replication-key", "", "Key file for replication")
flag.StringVar(&amqpAddress, "amqp-address", "", "Address to AMQP broker")
flag.DurationVar(&flushInterval, "flush-interval", 5*time.Minute, "Interval between database flushes")
showVersion := flag.Bool("version", false, "Show version")
flag.Parse()
log.Println(build.LongVersionFor("stdiscosrv")) log.Println(build.LongVersionFor("stdiscosrv"))
if *showVersion { if cli.Version {
return return
} }
buildInfo.WithLabelValues(build.Version, runtime.Version(), build.User, build.Date.UTC().Format("2006-01-02T15:04:05Z")).Set(1) buildInfo.WithLabelValues(build.Version, runtime.Version(), build.User, build.Date.UTC().Format("2006-01-02T15:04:05Z")).Set(1)
cert, err := tls.LoadX509KeyPair(certFile, keyFile) cert, err := tls.LoadX509KeyPair(cli.Cert, cli.Key)
if os.IsNotExist(err) { if os.IsNotExist(err) {
log.Println("Failed to load keypair. Generating one, this might take a while...") log.Println("Failed to load keypair. Generating one, this might take a while...")
cert, err = tlsutil.NewCertificate(certFile, keyFile, "stdiscosrv", 20*365) cert, err = tlsutil.NewCertificate(cli.Cert, cli.Key, "stdiscosrv", 20*365)
if err != nil { if err != nil {
log.Fatalln("Failed to generate X509 key pair:", err) log.Fatalln("Failed to generate X509 key pair:", err)
} }
@ -114,8 +114,8 @@ func main() {
log.Println("Server device ID is", devID) log.Println("Server device ID is", devID)
replCert := cert replCert := cert
if replCertFile != "" && replKeyFile != "" { if cli.ReplicationCert != "" && cli.ReplicationKey != "" {
replCert, err = tls.LoadX509KeyPair(replCertFile, replKeyFile) replCert, err = tls.LoadX509KeyPair(cli.ReplicationCert, cli.ReplicationKey)
if err != nil { if err != nil {
log.Fatalln("Failed to load replication keypair:", err) log.Fatalln("Failed to load replication keypair:", err)
} }
@ -126,8 +126,7 @@ func main() {
// Parse the replication specs, if any. // Parse the replication specs, if any.
var allowedReplicationPeers []protocol.DeviceID var allowedReplicationPeers []protocol.DeviceID
var replicationDestinations []string var replicationDestinations []string
parts := strings.Split(replicationPeers, ",") for _, part := range cli.Replicate {
for _, part := range parts {
if part == "" { if part == "" {
continue continue
} }
@ -165,10 +164,22 @@ func main() {
// Root of the service tree. // Root of the service tree.
main := suture.New("main", suture.Spec{ main := suture.New("main", suture.Spec{
PassThroughPanics: true, PassThroughPanics: true,
Timeout: 2 * time.Minute,
}) })
// If configured, use S3 for database backups.
var s3c *s3Copier
if cli.DBS3Endpoint != "" {
hostname, err := os.Hostname()
if err != nil {
log.Fatalf("Failed to get hostname: %v", err)
}
key := hostname + ".db"
s3c = newS3Copier(cli.DBS3Endpoint, cli.DBS3Region, cli.DBS3Bucket, key, cli.DBS3AccessKeyID, cli.DBS3SecretKey)
}
// Start the database. // Start the database.
db := newInMemoryStore(dir, flushInterval) db := newInMemoryStore(cli.DBDir, cli.DBFlushInterval, s3c)
main.Add(db) main.Add(db)
// Start any replication senders. // Start any replication senders.
@ -181,28 +192,28 @@ func main() {
// If we have replication configured, start the replication listener. // If we have replication configured, start the replication listener.
if len(allowedReplicationPeers) > 0 { if len(allowedReplicationPeers) > 0 {
rl := newReplicationListener(replicationListen, replCert, allowedReplicationPeers, db) rl := newReplicationListener(cli.ReplicationListen, replCert, allowedReplicationPeers, db)
main.Add(rl) main.Add(rl)
} }
// If we have an AMQP broker, start that // If we have an AMQP broker, start that
if amqpAddress != "" { if cli.AMQPAddress != "" {
clientID := rand.String(10) clientID := rand.String(10)
kr := newAMQPReplicator(amqpAddress, clientID, db) kr := newAMQPReplicator(cli.AMQPAddress, clientID, db)
repl = append(repl, kr) repl = append(repl, kr)
main.Add(kr) main.Add(kr)
} }
// Start the main API server. // Start the main API server.
qs := newAPISrv(listen, cert, db, repl, useHTTP, compression) qs := newAPISrv(cli.Listen, cert, db, repl, cli.HTTP, cli.Compression)
main.Add(qs) main.Add(qs)
// If we have a metrics port configured, start a metrics handler. // If we have a metrics port configured, start a metrics handler.
if metricsListen != "" { if cli.MetricsListen != "" {
go func() { go func() {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.Handler()) mux.Handle("/metrics", promhttp.Handler())
log.Fatal(http.ListenAndServe(metricsListen, mux)) log.Fatal(http.ListenAndServe(cli.MetricsListen, mux))
}() }()
} }

97
cmd/stdiscosrv/s3.go Normal file
View File

@ -0,0 +1,97 @@
// Copyright (C) 2024 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
package main
import (
"io"
"log"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
type s3Copier struct {
endpoint string
region string
bucket string
key string
accessKeyID string
secretKey string
}
func newS3Copier(endpoint, region, bucket, key, accessKeyID, secretKey string) *s3Copier {
return &s3Copier{
endpoint: endpoint,
region: region,
bucket: bucket,
key: key,
accessKeyID: accessKeyID,
secretKey: secretKey,
}
}
func (s *s3Copier) upload(r io.Reader) error {
sess, err := session.NewSession(&aws.Config{
Region: aws.String(s.region),
Endpoint: aws.String(s.endpoint),
Credentials: credentials.NewStaticCredentials(s.accessKeyID, s.secretKey, ""),
})
if err != nil {
return err
}
uploader := s3manager.NewUploader(sess)
_, err = uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s.key),
Body: r,
})
return err
}
func (s *s3Copier) downloadLatest(w io.WriterAt) error {
sess, err := session.NewSession(&aws.Config{
Region: aws.String(s.region),
Endpoint: aws.String(s.endpoint),
Credentials: credentials.NewStaticCredentials(s.accessKeyID, s.secretKey, ""),
})
if err != nil {
return err
}
svc := s3.New(sess)
resp, err := svc.ListObjectsV2(&s3.ListObjectsV2Input{Bucket: aws.String(s.bucket)})
if err != nil {
return err
}
var lastKey string
var lastModified time.Time
var lastSize int64
for _, item := range resp.Contents {
if item.LastModified.After(lastModified) && *item.Size > lastSize {
lastKey = *item.Key
lastModified = *item.LastModified
lastSize = *item.Size
} else if lastModified.Sub(*item.LastModified) < 5*time.Minute && *item.Size > lastSize {
lastKey = *item.Key
lastSize = *item.Size
}
}
log.Println("Downloading database from", lastKey)
downloader := s3manager.NewDownloader(sess)
_, err = downloader.Download(w, &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(lastKey),
})
return err
}