diff --git a/cmd/stdiscosrv/apisrv.go b/cmd/stdiscosrv/apisrv.go index 1d202d7cf..2c7c874c9 100644 --- a/cmd/stdiscosrv/apisrv.go +++ b/cmd/stdiscosrv/apisrv.go @@ -22,7 +22,7 @@ import ( "net" "net/http" "net/url" - "sort" + "slices" "strconv" "strings" "sync" @@ -311,7 +311,7 @@ func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string) // The address slice must always be sorted for database merges to work // properly. - sort.Sort(databaseAddressOrder(dbAddrs)) + slices.SortFunc(dbAddrs, DatabaseAddress.Cmp) seen := now.UnixNano() if s.repl != nil { diff --git a/cmd/stdiscosrv/database.go b/cmd/stdiscosrv/database.go index a6cbd424f..407168260 100644 --- a/cmd/stdiscosrv/database.go +++ b/cmd/stdiscosrv/database.go @@ -11,6 +11,7 @@ package main import ( "bufio" + "cmp" "context" "encoding/binary" "io" @@ -19,7 +20,7 @@ import ( "net/url" "os" "path" - "sort" + "slices" "time" "github.com/aws/aws-sdk-go/aws" @@ -326,6 +327,7 @@ func (s *inMemoryStore) read() error { continue } + slices.SortFunc(rec.Addresses, DatabaseAddress.Cmp) s.m.Store(key, DatabaseRecord{ Addresses: rec.Addresses, Seen: rec.Seen, @@ -339,15 +341,6 @@ func (s *inMemoryStore) read() error { // chosen for any duplicates. func merge(a, b DatabaseRecord) DatabaseRecord { // Both lists must be sorted for this to work. - if !sort.IsSorted(databaseAddressOrder(a.Addresses)) { - log.Println("Warning: bug: addresses not correctly sorted in merge") - a.Addresses = sortedAddressCopy(a.Addresses) - } - if !sort.IsSorted(databaseAddressOrder(b.Addresses)) { - // no warning because this is the side we read from disk and it may - // legitimately predate correct sorting. - b.Addresses = sortedAddressCopy(b.Addresses) - } res := DatabaseRecord{ Addresses: make([]DatabaseAddress, 0, max(len(a.Addresses), len(b.Addresses))), @@ -425,27 +418,6 @@ func expire(addrs []DatabaseAddress, now int64) []DatabaseAddress { return addrs } -func sortedAddressCopy(addrs []DatabaseAddress) []DatabaseAddress { - sorted := make([]DatabaseAddress, len(addrs)) - copy(sorted, addrs) - sort.Sort(databaseAddressOrder(sorted)) - return sorted -} - -type databaseAddressOrder []DatabaseAddress - -func (s databaseAddressOrder) Less(a, b int) bool { - return s[a].Address < s[b].Address -} - -func (s databaseAddressOrder) Swap(a, b int) { - s[a], s[b] = s[b], s[a] -} - -func (s databaseAddressOrder) Len() int { - return len(s) -} - func s3Upload(r io.Reader) error { sess, err := session.NewSession(&aws.Config{ Region: aws.String("fr-par"), @@ -478,3 +450,10 @@ func s3Download(w io.WriterAt) error { }) return err } + +func (d DatabaseAddress) Cmp(other DatabaseAddress) (n int) { + if c := cmp.Compare(d.Address, other.Address); c != 0 { + return c + } + return cmp.Compare(d.Expires, other.Expires) +}