Merge branch 'infrastructure'

* infrastructure:
  chore(stdiscosrv): ensure incoming addresses are sorted and unique
  chore(stdiscosrv): use zero-allocation merge in the common case
  chore(stdiscosrv): properly clean out old addresses from memory
  chore(stdiscosrv): calculate IPv6 GUA
This commit is contained in:
Jakob Borg 2024-09-16 09:33:15 +02:00
commit 0343bca257
No known key found for this signature in database
3 changed files with 148 additions and 67 deletions

View File

@ -307,16 +307,17 @@ func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string)
now := time.Now() now := time.Now()
expire := now.Add(addressExpiryTime).UnixNano() expire := now.Add(addressExpiryTime).UnixNano()
// The address slice must always be sorted for database merges to work
// properly.
slices.Sort(addresses)
addresses = slices.Compact(addresses)
dbAddrs := make([]DatabaseAddress, len(addresses)) dbAddrs := make([]DatabaseAddress, len(addresses))
for i := range addresses { for i := range addresses {
dbAddrs[i].Address = addresses[i] dbAddrs[i].Address = addresses[i]
dbAddrs[i].Expires = expire dbAddrs[i].Expires = expire
} }
// The address slice must always be sorted for database merges to work
// properly.
slices.SortFunc(dbAddrs, DatabaseAddress.Cmp)
seen := now.UnixNano() seen := now.UnixNano()
if s.repl != nil { if s.repl != nil {
s.repl.send(&deviceID, dbAddrs, seen) s.repl.send(&deviceID, dbAddrs, seen)

View File

@ -78,7 +78,7 @@ func newInMemoryStore(dir string, flushInterval time.Duration, s3 *s3Copier) *in
log.Println("Error reading database:", err) log.Println("Error reading database:", err)
} }
log.Printf("Read %d records from database", nr) log.Printf("Read %d records from database", nr)
s.calculateStatistics() s.expireAndCalculateStatistics()
return s return s
} }
@ -99,7 +99,7 @@ func (s *inMemoryStore) merge(key *protocol.DeviceID, addrs []DatabaseAddress, s
} }
oldRec, _ := s.m.Load(*key) oldRec, _ := s.m.Load(*key)
newRec = merge(newRec, oldRec) newRec = merge(oldRec, newRec)
s.m.Store(*key, newRec) s.m.Store(*key, newRec)
databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc() databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc()
@ -126,19 +126,20 @@ func (s *inMemoryStore) get(key *protocol.DeviceID) (DatabaseRecord, error) {
} }
func (s *inMemoryStore) Serve(ctx context.Context) error { func (s *inMemoryStore) Serve(ctx context.Context) error {
if s.flushInterval <= 0 {
<-ctx.Done()
return nil
}
t := time.NewTimer(s.flushInterval) t := time.NewTimer(s.flushInterval)
defer t.Stop() defer t.Stop()
if s.flushInterval <= 0 {
t.Stop()
}
loop: loop:
for { for {
select { select {
case <-t.C: case <-t.C:
log.Println("Calculating statistics") log.Println("Calculating statistics")
s.calculateStatistics() s.expireAndCalculateStatistics()
log.Println("Flushing database") log.Println("Flushing database")
if err := s.write(); err != nil { if err := s.write(); err != nil {
log.Println("Error writing database:", err) log.Println("Error writing database:", err)
@ -155,11 +156,11 @@ loop:
return s.write() return s.write()
} }
func (s *inMemoryStore) calculateStatistics() { func (s *inMemoryStore) expireAndCalculateStatistics() {
now := s.clock.Now() now := s.clock.Now()
cutoff24h := now.Add(-24 * time.Hour).UnixNano() cutoff24h := now.Add(-24 * time.Hour).UnixNano()
cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano() cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano()
current, currentIPv4, currentIPv6, last24h, last1w := 0, 0, 0, 0, 0 current, currentIPv4, currentIPv6, currentIPv6GUA, last24h, last1w := 0, 0, 0, 0, 0, 0
n := 0 n := 0
s.m.Range(func(key protocol.DeviceID, rec DatabaseRecord) bool { s.m.Range(func(key protocol.DeviceID, rec DatabaseRecord) bool {
@ -169,17 +170,31 @@ func (s *inMemoryStore) calculateStatistics() {
n++ n++
addresses := expire(rec.Addresses, now) addresses := expire(rec.Addresses, now)
if len(addresses) == 0 {
rec.Addresses = nil
s.m.Store(key, rec)
} else if len(addresses) != len(rec.Addresses) {
rec.Addresses = addresses
s.m.Store(key, rec)
}
switch { switch {
case len(addresses) > 0: case len(rec.Addresses) > 0:
current++ current++
seenIPv4, seenIPv6 := false, false seenIPv4, seenIPv6, seenIPv6GUA := false, false, false
for _, addr := range rec.Addresses { for _, addr := range rec.Addresses {
// We do fast and loose matching on strings here instead of
// parsing the address and the IP and doing "proper" checks,
// to keep things fast and generate less garbage.
if strings.Contains(addr.Address, "[") { if strings.Contains(addr.Address, "[") {
seenIPv6 = true seenIPv6 = true
if strings.Contains(addr.Address, "[2") {
seenIPv6GUA = true
}
} else { } else {
seenIPv4 = true seenIPv4 = true
} }
if seenIPv4 && seenIPv6 { if seenIPv4 && seenIPv6 && seenIPv6GUA {
break break
} }
} }
@ -189,6 +204,9 @@ func (s *inMemoryStore) calculateStatistics() {
if seenIPv6 { if seenIPv6 {
currentIPv6++ currentIPv6++
} }
if seenIPv6GUA {
currentIPv6GUA++
}
case rec.Seen > cutoff24h: case rec.Seen > cutoff24h:
last24h++ last24h++
case rec.Seen > cutoff1w: case rec.Seen > cutoff1w:
@ -203,6 +221,7 @@ func (s *inMemoryStore) calculateStatistics() {
databaseKeys.WithLabelValues("current").Set(float64(current)) databaseKeys.WithLabelValues("current").Set(float64(current))
databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4)) databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4))
databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6)) databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6))
databaseKeys.WithLabelValues("currentIPv6GUA").Set(float64(currentIPv6GUA))
databaseKeys.WithLabelValues("last24h").Set(float64(last24h)) databaseKeys.WithLabelValues("last24h").Set(float64(last24h))
databaseKeys.WithLabelValues("last1w").Set(float64(last1w)) databaseKeys.WithLabelValues("last1w").Set(float64(last1w))
databaseStatisticsSeconds.Set(time.Since(now).Seconds()) databaseStatisticsSeconds.Set(time.Since(now).Seconds())
@ -331,6 +350,7 @@ func (s *inMemoryStore) read() (int, error) {
} }
slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp) slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp)
rec.Addresses = slices.CompactFunc(rec.Addresses, DatabaseAddress.Equal)
s.m.Store(key, DatabaseRecord{ s.m.Store(key, DatabaseRecord{
Addresses: expire(rec.Addresses, s.clock.Now()), Addresses: expire(rec.Addresses, s.clock.Now()),
Seen: rec.Seen, Seen: rec.Seen,
@ -342,69 +362,36 @@ func (s *inMemoryStore) read() (int, error) {
// merge returns the merged result of the two database records a and b. The // merge returns the merged result of the two database records a and b. The
// result is the union of the two address sets, with the newer expiry time // result is the union of the two address sets, with the newer expiry time
// chosen for any duplicates. // chosen for any duplicates. The address list in a is overwritten and
// reused for the result.
func merge(a, b DatabaseRecord) DatabaseRecord { func merge(a, b DatabaseRecord) DatabaseRecord {
// Both lists must be sorted for this to work. // Both lists must be sorted for this to work.
res := DatabaseRecord{ a.Seen = max(a.Seen, b.Seen)
Addresses: make([]DatabaseAddress, 0, max(len(a.Addresses), len(b.Addresses))),
Seen: a.Seen,
}
if b.Seen > a.Seen {
res.Seen = b.Seen
}
aIdx := 0 aIdx := 0
bIdx := 0 bIdx := 0
aAddrs := a.Addresses for aIdx < len(a.Addresses) && bIdx < len(b.Addresses) {
bAddrs := b.Addresses switch cmp.Compare(a.Addresses[aIdx].Address, b.Addresses[bIdx].Address) {
loop: case 0:
for { // a == b, choose the newer expiry time
switch { a.Addresses[aIdx].Expires = max(a.Addresses[aIdx].Expires, b.Addresses[bIdx].Expires)
case aIdx == len(aAddrs) && bIdx == len(bAddrs):
// both lists are exhausted, we are done
break loop
case aIdx == len(aAddrs):
// a is exhausted, pick from b and continue
res.Addresses = append(res.Addresses, bAddrs[bIdx])
bIdx++
continue
case bIdx == len(bAddrs):
// b is exhausted, pick from a and continue
res.Addresses = append(res.Addresses, aAddrs[aIdx])
aIdx++
continue
}
// We have values left on both sides.
aVal := aAddrs[aIdx]
bVal := bAddrs[bIdx]
switch {
case aVal.Address == bVal.Address:
// update for same address, pick newer
if aVal.Expires > bVal.Expires {
res.Addresses = append(res.Addresses, aVal)
} else {
res.Addresses = append(res.Addresses, bVal)
}
aIdx++ aIdx++
bIdx++ bIdx++
case -1:
case aVal.Address < bVal.Address: // a < b, keep a and move on
// a is smallest, pick it and continue
res.Addresses = append(res.Addresses, aVal)
aIdx++ aIdx++
case 1:
default: // a > b, insert b before a
// b is smallest, pick it and continue a.Addresses = append(a.Addresses[:aIdx], append([]DatabaseAddress{b.Addresses[bIdx]}, a.Addresses[aIdx:]...)...)
res.Addresses = append(res.Addresses, bVal)
bIdx++ bIdx++
} }
} }
return res if bIdx < len(b.Addresses) {
a.Addresses = append(a.Addresses, b.Addresses[bIdx:]...)
}
return a
} }
// expire returns the list of addresses after removing expired entries. // expire returns the list of addresses after removing expired entries.
@ -414,10 +401,17 @@ func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress {
cutoff := now.UnixNano() cutoff := now.UnixNano()
naddrs := addrs[:0] naddrs := addrs[:0]
for i := range addrs { for i := range addrs {
if i > 0 && addrs[i].Address == addrs[i-1].Address {
// Skip duplicates
continue
}
if addrs[i].Expires >= cutoff { if addrs[i].Expires >= cutoff {
naddrs = append(naddrs, addrs[i]) naddrs = append(naddrs, addrs[i])
} }
} }
if len(naddrs) == 0 {
return nil
}
return naddrs return naddrs
} }
@ -427,3 +421,7 @@ func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) {
} }
return cmp.Compare(d.Expires, other.Expires) return cmp.Compare(d.Expires, other.Expires)
} }
func (d DatabaseAddress) Equal(other DatabaseAddress) bool {
return d.Address == other.Address
}

View File

@ -167,6 +167,88 @@ func TestFilter(t *testing.T) {
} }
} }
func TestMerge(t *testing.T) {
cases := []struct {
a, b, res []DatabaseAddress
}{
{nil, nil, nil},
{
nil,
[]DatabaseAddress{{Address: "a", Expires: 10}},
[]DatabaseAddress{{Address: "a", Expires: 10}},
},
{
nil,
[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 10}, {Address: "c", Expires: 10}},
[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 10}, {Address: "c", Expires: 10}},
},
{
[]DatabaseAddress{{Address: "a", Expires: 10}},
[]DatabaseAddress{{Address: "a", Expires: 15}},
[]DatabaseAddress{{Address: "a", Expires: 15}},
},
{
[]DatabaseAddress{{Address: "a", Expires: 10}},
[]DatabaseAddress{{Address: "b", Expires: 15}},
[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}},
},
{
[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}},
[]DatabaseAddress{{Address: "a", Expires: 15}, {Address: "b", Expires: 15}},
[]DatabaseAddress{{Address: "a", Expires: 15}, {Address: "b", Expires: 15}},
},
{
[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}},
[]DatabaseAddress{{Address: "b", Expires: 15}, {Address: "c", Expires: 20}},
[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "c", Expires: 20}},
},
{
[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}},
[]DatabaseAddress{{Address: "b", Expires: 5}, {Address: "c", Expires: 20}},
[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "c", Expires: 20}},
},
{
[]DatabaseAddress{{Address: "y", Expires: 10}, {Address: "z", Expires: 10}},
[]DatabaseAddress{{Address: "a", Expires: 5}, {Address: "b", Expires: 15}},
[]DatabaseAddress{{Address: "a", Expires: 5}, {Address: "b", Expires: 15}, {Address: "y", Expires: 10}, {Address: "z", Expires: 10}},
},
{
[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "d", Expires: 10}},
[]DatabaseAddress{{Address: "b", Expires: 5}, {Address: "c", Expires: 20}},
[]DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "c", Expires: 20}, {Address: "d", Expires: 10}},
},
}
for _, tc := range cases {
rec := merge(DatabaseRecord{Addresses: tc.a}, DatabaseRecord{Addresses: tc.b})
if fmt.Sprint(rec.Addresses) != fmt.Sprint(tc.res) {
t.Errorf("Incorrect result %v, expected %v", rec.Addresses, tc.res)
}
rec = merge(DatabaseRecord{Addresses: tc.b}, DatabaseRecord{Addresses: tc.a})
if fmt.Sprint(rec.Addresses) != fmt.Sprint(tc.res) {
t.Errorf("Incorrect result %v, expected %v", rec.Addresses, tc.res)
}
}
}
func BenchmarkMergeEqual(b *testing.B) {
for i := 0; i < b.N; i++ {
ar := []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}}
br := []DatabaseAddress{{Address: "a", Expires: 15}, {Address: "b", Expires: 10}}
res := merge(DatabaseRecord{Addresses: ar}, DatabaseRecord{Addresses: br})
if len(res.Addresses) != 2 {
b.Fatal("wrong length")
}
if res.Addresses[0].Address != "a" || res.Addresses[1].Address != "b" {
b.Fatal("wrong address")
}
if res.Addresses[0].Expires != 15 || res.Addresses[1].Expires != 15 {
b.Fatal("wrong expiry")
}
}
b.ReportAllocs() // should be zero per operation
}
type testClock struct { type testClock struct {
now time.Time now time.Time
} }