diff --git a/cmd/stdiscosrv/apisrv.go b/cmd/stdiscosrv/apisrv.go index fced38e09..96bfce75c 100644 --- a/cmd/stdiscosrv/apisrv.go +++ b/cmd/stdiscosrv/apisrv.go @@ -307,16 +307,17 @@ func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string) now := time.Now() expire := now.Add(addressExpiryTime).UnixNano() + // The address slice must always be sorted for database merges to work + // properly. + slices.Sort(addresses) + addresses = slices.Compact(addresses) + dbAddrs := make([]DatabaseAddress, len(addresses)) for i := range addresses { dbAddrs[i].Address = addresses[i] dbAddrs[i].Expires = expire } - // The address slice must always be sorted for database merges to work - // properly. - slices.SortFunc(dbAddrs, DatabaseAddress.Cmp) - seen := now.UnixNano() if s.repl != nil { s.repl.send(&deviceID, dbAddrs, seen) diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index 95de82eab..2a796848a 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -78,7 +78,7 @@ func newInMemoryStore(dir string, flushInterval time.Duration, s3 *s3Copier) *in log.Println("Error reading database:", err) } log.Printf("Read %d records from database", nr) - s.calculateStatistics() + s.expireAndCalculateStatistics() return s } @@ -99,7 +99,7 @@ func (s *inMemoryStore) merge(key *protocol.DeviceID, addrs []DatabaseAddress, s } oldRec, _ := s.m.Load(*key) - newRec = merge(newRec, oldRec) + newRec = merge(oldRec, newRec) s.m.Store(*key, newRec) databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc() @@ -126,19 +126,20 @@ func (s *inMemoryStore) get(key *protocol.DeviceID) (DatabaseRecord, error) { } func (s *inMemoryStore) Serve(ctx context.Context) error { + if s.flushInterval <= 0 { + <-ctx.Done() + return nil + } + t := time.NewTimer(s.flushInterval) defer t.Stop() - if s.flushInterval <= 0 { - t.Stop() - } - loop: for { select { case <-t.C: log.Println("Calculating statistics") - s.calculateStatistics() + s.expireAndCalculateStatistics() log.Println("Flushing database") if err := s.write(); err != nil { log.Println("Error writing database:", err) @@ -155,11 +156,11 @@ loop: return s.write() } -func (s *inMemoryStore) calculateStatistics() { +func (s *inMemoryStore) expireAndCalculateStatistics() { now := s.clock.Now() cutoff24h := now.Add(-24 * time.Hour).UnixNano() cutoff1w := now.Add(-7 * 24 * time.Hour).UnixNano() - current, currentIPv4, currentIPv6, last24h, last1w := 0, 0, 0, 0, 0 + current, currentIPv4, currentIPv6, currentIPv6GUA, last24h, last1w := 0, 0, 0, 0, 0, 0 n := 0 s.m.Range(func(key protocol.DeviceID, rec DatabaseRecord) bool { @@ -169,17 +170,31 @@ func (s *inMemoryStore) calculateStatistics() { n++ addresses := expire(rec.Addresses, now) + if len(addresses) == 0 { + rec.Addresses = nil + s.m.Store(key, rec) + } else if len(addresses) != len(rec.Addresses) { + rec.Addresses = addresses + s.m.Store(key, rec) + } + switch { - case len(addresses) > 0: + case len(rec.Addresses) > 0: current++ - seenIPv4, seenIPv6 := false, false + seenIPv4, seenIPv6, seenIPv6GUA := false, false, false for _, addr := range rec.Addresses { + // We do fast and loose matching on strings here instead of + // parsing the address and the IP and doing "proper" checks, + // to keep things fast and generate less garbage. if strings.Contains(addr.Address, "[") { seenIPv6 = true + if strings.Contains(addr.Address, "[2") { + seenIPv6GUA = true + } } else { seenIPv4 = true } - if seenIPv4 && seenIPv6 { + if seenIPv4 && seenIPv6 && seenIPv6GUA { break } } @@ -189,6 +204,9 @@ func (s *inMemoryStore) calculateStatistics() { if seenIPv6 { currentIPv6++ } + if seenIPv6GUA { + currentIPv6GUA++ + } case rec.Seen > cutoff24h: last24h++ case rec.Seen > cutoff1w: @@ -203,6 +221,7 @@ func (s *inMemoryStore) calculateStatistics() { databaseKeys.WithLabelValues("current").Set(float64(current)) databaseKeys.WithLabelValues("currentIPv4").Set(float64(currentIPv4)) databaseKeys.WithLabelValues("currentIPv6").Set(float64(currentIPv6)) + databaseKeys.WithLabelValues("currentIPv6GUA").Set(float64(currentIPv6GUA)) databaseKeys.WithLabelValues("last24h").Set(float64(last24h)) databaseKeys.WithLabelValues("last1w").Set(float64(last1w)) databaseStatisticsSeconds.Set(time.Since(now).Seconds()) @@ -331,6 +350,7 @@ func (s *inMemoryStore) read() (int, error) { } slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp) + rec.Addresses = slices.CompactFunc(rec.Addresses, DatabaseAddress.Equal) s.m.Store(key, DatabaseRecord{ Addresses: expire(rec.Addresses, s.clock.Now()), Seen: rec.Seen, @@ -342,69 +362,36 @@ func (s *inMemoryStore) read() (int, error) { // merge returns the merged result of the two database records a and b. The // result is the union of the two address sets, with the newer expiry time -// chosen for any duplicates. +// chosen for any duplicates. The address list in a is overwritten and +// reused for the result. func merge(a, b DatabaseRecord) DatabaseRecord { // Both lists must be sorted for this to work. - res := DatabaseRecord{ - Addresses: make([]DatabaseAddress, 0, max(len(a.Addresses), len(b.Addresses))), - Seen: a.Seen, - } - if b.Seen > a.Seen { - res.Seen = b.Seen - } + a.Seen = max(a.Seen, b.Seen) aIdx := 0 bIdx := 0 - aAddrs := a.Addresses - bAddrs := b.Addresses -loop: - for { - switch { - case aIdx == len(aAddrs) && bIdx == len(bAddrs): - // both lists are exhausted, we are done - break loop - - case aIdx == len(aAddrs): - // a is exhausted, pick from b and continue - res.Addresses = append(res.Addresses, bAddrs[bIdx]) - bIdx++ - continue - - case bIdx == len(bAddrs): - // b is exhausted, pick from a and continue - res.Addresses = append(res.Addresses, aAddrs[aIdx]) - aIdx++ - continue - } - - // We have values left on both sides. - aVal := aAddrs[aIdx] - bVal := bAddrs[bIdx] - - switch { - case aVal.Address == bVal.Address: - // update for same address, pick newer - if aVal.Expires > bVal.Expires { - res.Addresses = append(res.Addresses, aVal) - } else { - res.Addresses = append(res.Addresses, bVal) - } + for aIdx < len(a.Addresses) && bIdx < len(b.Addresses) { + switch cmp.Compare(a.Addresses[aIdx].Address, b.Addresses[bIdx].Address) { + case 0: + // a == b, choose the newer expiry time + a.Addresses[aIdx].Expires = max(a.Addresses[aIdx].Expires, b.Addresses[bIdx].Expires) aIdx++ bIdx++ - - case aVal.Address < bVal.Address: - // a is smallest, pick it and continue - res.Addresses = append(res.Addresses, aVal) + case -1: + // a < b, keep a and move on aIdx++ - - default: - // b is smallest, pick it and continue - res.Addresses = append(res.Addresses, bVal) + case 1: + // a > b, insert b before a + a.Addresses = append(a.Addresses[:aIdx], append([]DatabaseAddress{b.Addresses[bIdx]}, a.Addresses[aIdx:]...)...) bIdx++ } } - return res + if bIdx < len(b.Addresses) { + a.Addresses = append(a.Addresses, b.Addresses[bIdx:]...) + } + + return a } // expire returns the list of addresses after removing expired entries. @@ -414,10 +401,17 @@ func expire(addrs []DatabaseAddress, now time.Time) []DatabaseAddress { cutoff := now.UnixNano() naddrs := addrs[:0] for i := range addrs { + if i > 0 && addrs[i].Address == addrs[i-1].Address { + // Skip duplicates + continue + } if addrs[i].Expires >= cutoff { naddrs = append(naddrs, addrs[i]) } } + if len(naddrs) == 0 { + return nil + } return naddrs } @@ -427,3 +421,7 @@ func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) { } return cmp.Compare(d.Expires, other.Expires) } + +func (d DatabaseAddress) Equal(other DatabaseAddress) bool { + return d.Address == other.Address +} diff --git a/cmd/stdiscosrv/database_test.go b/cmd/stdiscosrv/database_test.go index 89a557175..b9bd9c6b5 100644 --- a/cmd/stdiscosrv/database_test.go +++ b/cmd/stdiscosrv/database_test.go @@ -167,6 +167,88 @@ func TestFilter(t *testing.T) { } } +func TestMerge(t *testing.T) { + cases := []struct { + a, b, res []DatabaseAddress + }{ + {nil, nil, nil}, + { + nil, + []DatabaseAddress{{Address: "a", Expires: 10}}, + []DatabaseAddress{{Address: "a", Expires: 10}}, + }, + { + nil, + []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 10}, {Address: "c", Expires: 10}}, + []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 10}, {Address: "c", Expires: 10}}, + }, + { + []DatabaseAddress{{Address: "a", Expires: 10}}, + []DatabaseAddress{{Address: "a", Expires: 15}}, + []DatabaseAddress{{Address: "a", Expires: 15}}, + }, + { + []DatabaseAddress{{Address: "a", Expires: 10}}, + []DatabaseAddress{{Address: "b", Expires: 15}}, + []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}}, + }, + { + []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}}, + []DatabaseAddress{{Address: "a", Expires: 15}, {Address: "b", Expires: 15}}, + []DatabaseAddress{{Address: "a", Expires: 15}, {Address: "b", Expires: 15}}, + }, + { + []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}}, + []DatabaseAddress{{Address: "b", Expires: 15}, {Address: "c", Expires: 20}}, + []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "c", Expires: 20}}, + }, + { + []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}}, + []DatabaseAddress{{Address: "b", Expires: 5}, {Address: "c", Expires: 20}}, + []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "c", Expires: 20}}, + }, + { + []DatabaseAddress{{Address: "y", Expires: 10}, {Address: "z", Expires: 10}}, + []DatabaseAddress{{Address: "a", Expires: 5}, {Address: "b", Expires: 15}}, + []DatabaseAddress{{Address: "a", Expires: 5}, {Address: "b", Expires: 15}, {Address: "y", Expires: 10}, {Address: "z", Expires: 10}}, + }, + { + []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "d", Expires: 10}}, + []DatabaseAddress{{Address: "b", Expires: 5}, {Address: "c", Expires: 20}}, + []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}, {Address: "c", Expires: 20}, {Address: "d", Expires: 10}}, + }, + } + + for _, tc := range cases { + rec := merge(DatabaseRecord{Addresses: tc.a}, DatabaseRecord{Addresses: tc.b}) + if fmt.Sprint(rec.Addresses) != fmt.Sprint(tc.res) { + t.Errorf("Incorrect result %v, expected %v", rec.Addresses, tc.res) + } + rec = merge(DatabaseRecord{Addresses: tc.b}, DatabaseRecord{Addresses: tc.a}) + if fmt.Sprint(rec.Addresses) != fmt.Sprint(tc.res) { + t.Errorf("Incorrect result %v, expected %v", rec.Addresses, tc.res) + } + } +} + +func BenchmarkMergeEqual(b *testing.B) { + for i := 0; i < b.N; i++ { + ar := []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 15}} + br := []DatabaseAddress{{Address: "a", Expires: 15}, {Address: "b", Expires: 10}} + res := merge(DatabaseRecord{Addresses: ar}, DatabaseRecord{Addresses: br}) + if len(res.Addresses) != 2 { + b.Fatal("wrong length") + } + if res.Addresses[0].Address != "a" || res.Addresses[1].Address != "b" { + b.Fatal("wrong address") + } + if res.Addresses[0].Expires != 15 || res.Addresses[1].Expires != 15 { + b.Fatal("wrong expiry") + } + } + b.ReportAllocs() // should be zero per operation +} + type testClock struct { now time.Time }