syncthing/cmd/stdiscosrv/database.go
2024-09-30 14:16:27 -05:00

441 lines
10 KiB
Go

// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
//go:generate go run ../../proto/scripts/protofmt.go database.proto
//go:generate protoc -I ../../ -I . --gogofast_out=. database.proto
package main
import (
"bufio"
"cmp"
"context"
"encoding/binary"
"errors"
"io"
"log"
"os"
"path"
"runtime"
"slices"
"strings"
"time"
"github.com/puzpuzpuz/xsync/v3"
"github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/rand"
"github.com/syncthing/syncthing/lib/s3"
)
// clock is the source of the current time used by inMemoryStore. Being an
// interface, it allows a different time source to be substituted for the
// default one.
type clock interface {
// Now returns the current time.
Now() time.Time
}
// defaultClock is the clock implementation backed by the system clock.
type defaultClock struct{}
// Now returns time.Now().
func (defaultClock) Now() time.Time {
return time.Now()
}
// database is the record storage interface, keyed by device ID. In this
// file it is implemented by inMemoryStore.
type database interface {
// put stores rec as the record for key.
put(key *protocol.DeviceID, rec DatabaseRecord) error
// merge combines the given addresses and seen time with the record for key.
merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error
// get returns the record for key.
get(key *protocol.DeviceID) (DatabaseRecord, error)
}
// inMemoryStore keeps all records in an in-memory map, which is read from
// and periodically written to the file "records.db" in dir, and optionally
// uploaded to S3 after each write.
type inMemoryStore struct {
// m holds the record for each device ID.
m *xsync.MapOf[protocol.DeviceID, DatabaseRecord]
// dir is the directory containing the "records.db" database file.
dir string
// flushInterval is the time between periodic flushes in Serve; a value
// of zero or less disables flushing there.
flushInterval time.Duration
// s3 is the session used for downloading and uploading the database;
// nil disables S3 usage.
s3 *s3.Session
// objKey is the object key the database is uploaded under.
objKey string
// clock provides the current time for expiry and cutoff calculations.
clock clock
}
// newInMemoryStore creates a store and populates it from the database file
// "records.db" in dir. When that file does not exist and s3sess is non-nil,
// the latest database object is first downloaded from S3 into that file and
// then read. Load failures are logged and never fatal: the returned store
// then holds whatever records were read before the failure, possibly none.
// flushInterval is retained for use by Serve.
func newInMemoryStore(dir string, flushInterval time.Duration, s3sess *s3.Session) *inMemoryStore {
	// The host name identifies this instance's object in S3; fall back to
	// rand.String(8) when it cannot be determined.
	hn, err := os.Hostname()
	if err != nil {
		hn = rand.String(8)
	}
	s := &inMemoryStore{
		m:             xsync.NewMapOf[protocol.DeviceID, DatabaseRecord](),
		dir:           dir,
		flushInterval: flushInterval,
		s3:            s3sess,
		objKey:        hn + ".db",
		clock:         defaultClock{},
	}
	nr, err := s.read()
	if os.IsNotExist(err) && s3sess != nil {
		// Try to read from AWS. Each failure below is reported with its own
		// error (cerr), not the "not exist" error from the read above.
		latestKey, cerr := s3sess.LatestKey()
		if cerr != nil {
			log.Println("Error reading database from S3:", cerr)
			return s
		}
		fd, cerr := os.Create(path.Join(s.dir, "records.db"))
		if cerr != nil {
			log.Println("Error creating database file:", cerr)
			return s
		}
		if cerr := s3sess.Download(fd, latestKey); cerr != nil {
			log.Printf("Error reading database from S3: %v", cerr)
		}
		// A close error is not acted upon here; any resulting problem with
		// the file contents surfaces in the read below.
		_ = fd.Close()
		nr, err = s.read()
	}
	if err != nil {
		log.Println("Error reading database:", err)
	}
	log.Printf("Read %d records from database", nr)
	s.expireAndCalculateStatistics()
	return s
}
// put replaces whatever is stored for key with rec and updates the
// operation counter and duration metrics. It always returns nil.
func (s *inMemoryStore) put(key *protocol.DeviceID, rec DatabaseRecord) error {
	begin := time.Now()
	defer func() {
		databaseOperationSeconds.WithLabelValues(dbOpPut).Observe(time.Since(begin).Seconds())
	}()

	s.m.Store(*key, rec)
	databaseOperations.WithLabelValues(dbOpPut, dbResSuccess).Inc()
	return nil
}
// merge folds addrs and the seen time into the record currently stored for
// key (the zero record if there is none) using the package-level merge
// function, and stores the result. It always returns nil.
//
// NOTE(review): the load, merge and store are three separate map
// operations, not one atomic update.
func (s *inMemoryStore) merge(key *protocol.DeviceID, addrs []DatabaseAddress, seen int64) error {
	begin := time.Now()
	defer func() {
		databaseOperationSeconds.WithLabelValues(dbOpMerge).Observe(time.Since(begin).Seconds())
	}()

	incoming := DatabaseRecord{Addresses: addrs, Seen: seen}
	existing, _ := s.m.Load(*key)
	s.m.Store(*key, merge(existing, incoming))
	databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc()
	return nil
}
// get looks up the record for key and returns it with expired addresses
// removed. An unknown key yields the zero DatabaseRecord; the error is
// always nil. The outcome (found or not found) and the duration are
// recorded in the operation metrics.
func (s *inMemoryStore) get(key *protocol.DeviceID) (DatabaseRecord, error) {
	begin := time.Now()
	defer func() {
		databaseOperationSeconds.WithLabelValues(dbOpGet).Observe(time.Since(begin).Seconds())
	}()

	if rec, found := s.m.Load(*key); found {
		// expire filters in place, reusing the address slice of the
		// loaded record.
		rec.Addresses = expire(rec.Addresses, s.clock.Now())
		databaseOperations.WithLabelValues(dbOpGet, dbResSuccess).Inc()
		return rec, nil
	}

	databaseOperations.WithLabelValues(dbOpGet, dbResNotFound).Inc()
	return DatabaseRecord{}, nil
}
// Serve runs the periodic maintenance loop until ctx is done. Every
// flushInterval it expires old entries, recalculates statistics and writes
// the database; write errors in the loop are logged only. When ctx is done
// a final write is performed and its error returned. With a flushInterval
// of zero or less nothing is flushed at all: Serve simply waits for ctx to
// be done and returns nil.
func (s *inMemoryStore) Serve(ctx context.Context) error {
	if s.flushInterval <= 0 {
		<-ctx.Done()
		return nil
	}

	timer := time.NewTimer(s.flushInterval)
	defer timer.Stop()

	for {
		select {
		case <-ctx.Done():
			// We're done; flush one last time.
			return s.write()

		case <-timer.C:
			log.Println("Calculating statistics")
			s.expireAndCalculateStatistics()
			log.Println("Flushing database")
			if werr := s.write(); werr != nil {
				log.Println("Error writing database:", werr)
			}
			log.Println("Finished flushing database")
			// Re-arm only after the flush is complete, so the interval is
			// the pause between flushes.
			timer.Reset(s.flushInterval)
		}
	}
}
// expireAndCalculateStatistics walks all records once. For each record it
// removes expired addresses (storing the record back when that changed
// it), and then either counts it as current (it has addresses), counts it
// as seen within the last 24 hours or the last week (it has none), or
// deletes it (no addresses and not seen for a week). The resulting counts
// and the time taken are published to the database metrics.
func (s *inMemoryStore) expireAndCalculateStatistics() {
	started := s.clock.Now()
	dayAgo := started.Add(-24 * time.Hour).UnixNano()
	weekAgo := started.Add(-7 * 24 * time.Hour).UnixNano()

	var current, withIPv4, withIPv6, withIPv6GUA, seenLastDay, seenLastWeek int
	visited := 0
	s.m.Range(func(id protocol.DeviceID, rec DatabaseRecord) bool {
		// Yield to other goroutines every thousand records.
		if visited%1000 == 0 {
			runtime.Gosched()
		}
		visited++

		switch live := expire(rec.Addresses, started); {
		case len(live) == 0:
			rec.Addresses = nil
			s.m.Store(id, rec)
		case len(live) != len(rec.Addresses):
			rec.Addresses = live
			s.m.Store(id, rec)
		}

		if len(rec.Addresses) == 0 {
			switch {
			case rec.Seen > dayAgo:
				seenLastDay++
			case rec.Seen > weekAgo:
				seenLastWeek++
			default:
				// drop the record if it's older than a week
				s.m.Delete(id)
			}
			return true
		}

		current++
		var hasIPv4, hasIPv6, hasIPv6GUA bool
		for i := range rec.Addresses {
			// Deliberately cheap substring matching rather than parsing
			// the address: a "[" marks an IPv6 literal, "[2" is taken as a
			// global unicast address, anything else counts as IPv4.
			addr := rec.Addresses[i].Address
			if !strings.Contains(addr, "[") {
				hasIPv4 = true
			} else {
				hasIPv6 = true
				if strings.Contains(addr, "[2") {
					hasIPv6GUA = true
				}
			}
			if hasIPv4 && hasIPv6 && hasIPv6GUA {
				// Nothing more to learn from the remaining addresses.
				break
			}
		}
		if hasIPv4 {
			withIPv4++
		}
		if hasIPv6 {
			withIPv6++
		}
		if hasIPv6GUA {
			withIPv6GUA++
		}
		return true
	})

	databaseKeys.WithLabelValues("current").Set(float64(current))
	databaseKeys.WithLabelValues("currentIPv4").Set(float64(withIPv4))
	databaseKeys.WithLabelValues("currentIPv6").Set(float64(withIPv6))
	databaseKeys.WithLabelValues("currentIPv6GUA").Set(float64(withIPv6GUA))
	databaseKeys.WithLabelValues("last24h").Set(float64(seenLastDay))
	databaseKeys.WithLabelValues("last1w").Set(float64(seenLastWeek))
	databaseStatisticsSeconds.Set(time.Since(started).Seconds())
}
// write persists all records seen within the last week to "records.db" in
// the store directory. Records are written to a temporary file as
// length-prefixed (4 byte big endian) marshalled ReplicationRecords, and
// the temporary file is then renamed over the database file. If an S3
// session is configured the new database file is uploaded afterwards;
// upload problems are logged but do not make write fail. The write
// duration and time metrics are updated only when nil is returned.
func (s *inMemoryStore) write() (err error) {
	t0 := time.Now()
	defer func() {
		if err == nil {
			databaseWriteSeconds.Set(time.Since(t0).Seconds())
			databaseLastWritten.Set(float64(t0.Unix()))
		}
	}()

	dbf := path.Join(s.dir, "records.db")
	fd, err := os.Create(dbf + ".tmp")
	if err != nil {
		return err
	}
	bw := bufio.NewWriter(fd)

	var buf []byte
	var rangeErr error
	now := s.clock.Now()
	cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
	visited := 0
	s.m.Range(func(key protocol.DeviceID, value DatabaseRecord) bool {
		// Yield to other goroutines every thousand records.
		if visited%1000 == 0 {
			runtime.Gosched()
		}
		visited++

		if value.Seen < cutoff1w {
			// drop the record if it's older than a week
			return true
		}

		rec := ReplicationRecord{
			Key:       key[:],
			Addresses: value.Addresses,
			Seen:      value.Seen,
		}
		// The buffer holds the 4 byte length prefix followed by the
		// marshalled record, and is grown only when needed.
		size := rec.Size()
		if size+4 > len(buf) {
			buf = make([]byte, size+4)
		}
		written, merr := rec.MarshalTo(buf[4:])
		if merr != nil {
			rangeErr = merr
			return false
		}
		binary.BigEndian.PutUint32(buf, uint32(written))
		if _, werr := bw.Write(buf[:written+4]); werr != nil {
			rangeErr = werr
			return false
		}
		return true
	})
	if rangeErr != nil {
		_ = fd.Close()
		return rangeErr
	}

	if err := bw.Flush(); err != nil {
		// Release the file handle; the flush error is the one to report.
		_ = fd.Close()
		return err
	}
	if err := fd.Close(); err != nil {
		return err
	}
	if err := os.Rename(dbf+".tmp", dbf); err != nil {
		return err
	}

	if s.s3 == nil {
		return nil
	}

	// Upload to S3. The local write has succeeded at this point, so
	// failures from here on are only logged.
	upfd, oerr := os.Open(dbf)
	if oerr != nil {
		log.Printf("Error uploading database to S3: %v", oerr)
		return nil
	}
	defer upfd.Close()
	if uerr := s.s3.Upload(upfd, s.objKey); uerr != nil {
		log.Printf("Error uploading database to S3: %v", uerr)
		return nil
	}
	log.Println("Finished uploading database")
	return nil
}
// read loads "records.db" from the store directory into the map and
// returns the number of records stored. The file is a sequence of records,
// each a 4 byte big endian length followed by that many bytes of
// marshalled ReplicationRecord. Records whose key is not a valid device ID
// are logged and skipped. On a read or unmarshal error the count of
// records stored so far is returned together with the error.
func (s *inMemoryStore) read() (int, error) {
	fd, err := os.Open(path.Join(s.dir, "records.db"))
	if err != nil {
		return 0, err
	}
	defer fd.Close()

	var (
		br      = bufio.NewReader(fd)
		header  [4]byte
		payload []byte
		count   int
	)
	for {
		if _, err := io.ReadFull(br, header[:]); err != nil {
			if errors.Is(err, io.EOF) {
				// Clean end of file, exactly on a record boundary.
				return count, nil
			}
			return count, err
		}
		size := binary.BigEndian.Uint32(header[:])

		// Reuse the payload buffer, growing it only for larger records.
		if int(size) > len(payload) {
			payload = make([]byte, size)
		}
		if _, err := io.ReadFull(br, payload[:size]); err != nil {
			return count, err
		}

		var rec ReplicationRecord
		if err := rec.Unmarshal(payload[:size]); err != nil {
			return count, err
		}

		// The key is tried as raw bytes first, then as a string.
		id, err := protocol.DeviceIDFromBytes(rec.Key)
		if err != nil {
			id, err = protocol.DeviceIDFromString(string(rec.Key))
			if err != nil {
				log.Println("Bad device ID:", err)
				continue
			}
		}

		// Sort and deduplicate the addresses before storing.
		slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp)
		rec.Addresses = slices.CompactFunc(rec.Addresses, DatabaseAddress.Equal)
		s.m.Store(id, DatabaseRecord{
			Addresses: expire(rec.Addresses, s.clock.Now()),
			Seen:      rec.Seen,
		})
		count++
	}
}
// merge returns the merged result of the two database records a and b. The
// result is the union of the two address sets, with the newer expiry time
// chosen for any address present in both, and the larger of the two Seen
// values. Both address lists must be sorted by address. The address list
// in a is modified and reused for the result; b is only read.
func merge(a, b DatabaseRecord) DatabaseRecord {
	a.Seen = max(a.Seen, b.Seen)

	aIdx := 0
	bIdx := 0
	for aIdx < len(a.Addresses) && bIdx < len(b.Addresses) {
		switch cmp.Compare(a.Addresses[aIdx].Address, b.Addresses[bIdx].Address) {
		case 0:
			// a == b, choose the newer expiry time
			a.Addresses[aIdx].Expires = max(a.Addresses[aIdx].Expires, b.Addresses[bIdx].Expires)
			aIdx++
			bIdx++
		case -1:
			// a < b, keep a and move on
			aIdx++
		case 1:
			// a > b, insert b before a. The cursor then steps past the
			// inserted element so that it points at the same a element
			// as before.
			a.Addresses = slices.Insert(a.Addresses, aIdx, b.Addresses[bIdx])
			aIdx++
			bIdx++
		}
	}
	// Whatever remains in b sorts after everything in a.
	if bIdx < len(b.Addresses) {
		a.Addresses = append(a.Addresses, b.Addresses[bIdx:]...)
	}
	return a
}
// expire returns the list of addresses after removing expired entries,
// i.e. those with an expiry time before now. Of adjacent entries with the
// same address only the first non-expired one is kept. Expiration happens
// in place, so the slice given as the parameter is destroyed. Internal
// order is preserved. The result is nil when no address remains.
func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress {
	cutoff := now.UnixNano()
	naddrs := addrs[:0]
	for i := range addrs {
		if addrs[i].Expires < cutoff {
			continue
		}
		// Skip duplicates, but only when an entry from the same run of
		// equal addresses was actually kept; an expired duplicate must not
		// cause a still valid entry to be dropped. Reading addrs[i-1] is
		// safe as the in-place writes never reach past index i-1 with
		// anything but that element itself.
		if n := len(naddrs); n > 0 && i > 0 &&
			addrs[i].Address == addrs[i-1].Address &&
			naddrs[n-1].Address == addrs[i].Address {
			continue
		}
		naddrs = append(naddrs, addrs[i])
	}
	if len(naddrs) == 0 {
		return nil
	}
	return naddrs
}
// Cmp orders addresses by their Address string, and by expiry time when
// the Address strings are equal. The result is negative, zero or positive
// as d sorts before, equal to or after other.
func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) {
	n = cmp.Compare(d.Address, other.Address)
	if n == 0 {
		n = cmp.Compare(d.Expires, other.Expires)
	}
	return n
}

// Equal reports whether d and other have the same Address string. The
// expiry time is not considered.
func (d DatabaseAddress) Equal(other DatabaseAddress) bool {
	sameAddress := d.Address == other.Address
	return sameAddress
}