mirror of
https://github.com/octoleo/syncthing.git
synced 2024-11-09 23:00:58 +00:00
lib/connections: Fix and optimize registry (#7996)
Registry.Get used a full sort to get the minimum of a list, and the sort was broken because util.AddressUnspecifiedLess assumed it could find out whether an address is IPv4 or IPv6 from its Network method. However, net.(TCP|UDP)Addr.Network always returns "tcp"/"udp".
This commit is contained in:
parent
c94b797f00
commit
7c292cc812
@ -58,7 +58,8 @@ func (d *quicDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL
|
||||
// Given we always pass the connection to quic, it assumes it's a remote connection it never closes it,
|
||||
// So our wrapper around it needs to close it, but it only needs to close it if it's not the listening connection.
|
||||
var createdConn net.PacketConn
|
||||
if listenConn := registry.Get(uri.Scheme, packetConnLess); listenConn != nil {
|
||||
listenConn := registry.Get(uri.Scheme, packetConnUnspecified)
|
||||
if listenConn != nil {
|
||||
conn = listenConn.(net.PacketConn)
|
||||
} else {
|
||||
if packetConn, err := net.ListenPacket("udp", ":0"); err != nil {
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/syncthing/syncthing/lib/util"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -63,7 +62,10 @@ func (q *quicTlsConn) ConnectionState() tls.ConnectionState {
|
||||
return q.Session.ConnectionState().TLS.ConnectionState
|
||||
}
|
||||
|
||||
// Sort available packet connections by ip address, preferring unspecified local address.
|
||||
func packetConnLess(i interface{}, j interface{}) bool {
|
||||
return util.AddressUnspecifiedLess(i.(net.PacketConn).LocalAddr(), j.(net.PacketConn).LocalAddr())
|
||||
func packetConnUnspecified(conn interface{}) bool {
|
||||
// Since QUIC connections are wrapped, we can't do a simple typecheck
|
||||
// on *net.UDPAddr here.
|
||||
addr := conn.(net.PacketConn).LocalAddr()
|
||||
host, _, err := net.SplitHostPort(addr.String())
|
||||
return err == nil && net.ParseIP(host).IsUnspecified()
|
||||
}
|
||||
|
@ -10,7 +10,6 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/syncthing/syncthing/lib/sync"
|
||||
@ -46,7 +45,7 @@ func (r *Registry) Unregister(scheme string, item interface{}) {
|
||||
candidates := r.available[scheme]
|
||||
for i, existingItem := range candidates {
|
||||
if existingItem == item {
|
||||
copy(candidates[i:], candidates[i+1:])
|
||||
candidates[i] = candidates[len(candidates)-1]
|
||||
candidates[len(candidates)-1] = nil
|
||||
r.available[scheme] = candidates[:len(candidates)-1]
|
||||
break
|
||||
@ -54,26 +53,37 @@ func (r *Registry) Unregister(scheme string, item interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Registry) Get(scheme string, less func(i, j interface{}) bool) interface{} {
|
||||
// Get returns an item for a schema compatible with the given scheme.
|
||||
// If any item satisfies preferred, that has precedence over other items.
|
||||
func (r *Registry) Get(scheme string, preferred func(interface{}) bool) interface{} {
|
||||
r.mut.Lock()
|
||||
defer r.mut.Unlock()
|
||||
|
||||
candidates := make([]interface{}, 0)
|
||||
var (
|
||||
best interface{}
|
||||
bestPref bool
|
||||
bestScheme string
|
||||
)
|
||||
for availableScheme, items := range r.available {
|
||||
// quic:// should be considered ok for both quic4:// and quic6://
|
||||
if strings.HasPrefix(scheme, availableScheme) {
|
||||
candidates = append(candidates, items...)
|
||||
if !strings.HasPrefix(scheme, availableScheme) {
|
||||
continue
|
||||
}
|
||||
for _, item := range items {
|
||||
better := best == nil
|
||||
pref := preferred(item)
|
||||
if !better {
|
||||
// In case of a tie, prefer "quic" to "quic[46]" etc.
|
||||
better = pref &&
|
||||
(!bestPref || len(availableScheme) < len(bestScheme))
|
||||
}
|
||||
if !better {
|
||||
continue
|
||||
}
|
||||
best, bestPref, bestScheme = item, pref, availableScheme
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return less(candidates[i], candidates[j])
|
||||
})
|
||||
return candidates[0]
|
||||
return best
|
||||
}
|
||||
|
||||
func Register(scheme string, item interface{}) {
|
||||
@ -84,6 +94,6 @@ func Unregister(scheme string, item interface{}) {
|
||||
Default.Unregister(scheme, item)
|
||||
}
|
||||
|
||||
func Get(scheme string, less func(i, j interface{}) bool) interface{} {
|
||||
return Default.Get(scheme, less)
|
||||
func Get(scheme string, preferred func(interface{}) bool) interface{} {
|
||||
return Default.Get(scheme, preferred)
|
||||
}
|
||||
|
@ -7,13 +7,18 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRegistry(t *testing.T) {
|
||||
r := New()
|
||||
|
||||
if res := r.Get("int", intLess); res != nil {
|
||||
want := func(i int) func(interface{}) bool {
|
||||
return func(x interface{}) bool { return x.(int) == i }
|
||||
}
|
||||
|
||||
if res := r.Get("int", want(1)); res != nil {
|
||||
t.Error("unexpected")
|
||||
}
|
||||
|
||||
@ -24,30 +29,28 @@ func TestRegistry(t *testing.T) {
|
||||
r.Register("int6", 6)
|
||||
r.Register("int6", 66)
|
||||
|
||||
if res := r.Get("int", intLess).(int); res != 1 {
|
||||
if res := r.Get("int", want(1)).(int); res != 1 {
|
||||
t.Error("unexpected", res)
|
||||
}
|
||||
|
||||
// int is prefix of int4, so returns 1
|
||||
if res := r.Get("int4", intLess).(int); res != 1 {
|
||||
if res := r.Get("int4", want(1)).(int); res != 1 {
|
||||
t.Error("unexpected", res)
|
||||
}
|
||||
|
||||
r.Unregister("int", 1)
|
||||
|
||||
// Check that falls through to 11
|
||||
if res := r.Get("int", intLess).(int); res != 11 {
|
||||
if res := r.Get("int", want(1)).(int); res == 1 {
|
||||
t.Error("unexpected", res)
|
||||
}
|
||||
|
||||
// 6 is smaller than 11 available in int.
|
||||
if res := r.Get("int6", intLess).(int); res != 6 {
|
||||
if res := r.Get("int6", want(6)).(int); res != 6 {
|
||||
t.Error("unexpected", res)
|
||||
}
|
||||
|
||||
// Unregister 11, int should be impossible to find
|
||||
r.Unregister("int", 11)
|
||||
if res := r.Get("int", intLess); res != nil {
|
||||
if res := r.Get("int", want(11)); res != nil {
|
||||
t.Error("unexpected")
|
||||
}
|
||||
|
||||
@ -59,13 +62,35 @@ func TestRegistry(t *testing.T) {
|
||||
r.Register("int", 1)
|
||||
r.Unregister("int", 1)
|
||||
|
||||
if res := r.Get("int4", intLess).(int); res != 1 {
|
||||
if res := r.Get("int4", want(1)).(int); res != 1 {
|
||||
t.Error("unexpected", res)
|
||||
}
|
||||
}
|
||||
|
||||
func intLess(i, j interface{}) bool {
|
||||
iInt := i.(int)
|
||||
jInt := j.(int)
|
||||
return iInt < jInt
|
||||
func TestShortSchemeFirst(t *testing.T) {
|
||||
r := New()
|
||||
r.Register("foo", 0)
|
||||
r.Register("foobar", 1)
|
||||
|
||||
// If we don't care about the value, we should get the one with "foo".
|
||||
res := r.Get("foo", func(interface{}) bool { return false })
|
||||
if res != 0 {
|
||||
t.Error("unexpected", res)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGet(b *testing.B) {
|
||||
r := New()
|
||||
for _, addr := range []string{"192.168.1.1", "172.1.1.1", "10.1.1.1"} {
|
||||
r.Register("tcp", &net.TCPAddr{IP: net.ParseIP(addr)})
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.Get("tcp", func(x interface{}) bool {
|
||||
return x.(*net.TCPAddr).IP.IsUnspecified()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -13,7 +13,6 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/syncthing/syncthing/lib/util"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
@ -61,11 +60,6 @@ func socksDialerFunction(u *url.URL, forward proxy.Dialer) (proxy.Dialer, error)
|
||||
return proxy.SOCKS5("tcp", u.Host, auth, forward)
|
||||
}
|
||||
|
||||
// Sort available addresses, preferring unspecified address.
|
||||
func tcpAddrLess(i interface{}, j interface{}) bool {
|
||||
return util.AddressUnspecifiedLess(i.(*net.TCPAddr), j.(*net.TCPAddr))
|
||||
}
|
||||
|
||||
// dialerConn is needed because proxy dialed connections have RemoteAddr() pointing at the proxy,
|
||||
// which then screws up various things such as IsLAN checks, and "let's populate the relay invitation address from
|
||||
// existing connection" shenanigans.
|
||||
|
@ -110,7 +110,9 @@ func DialContextReusePort(ctx context.Context, network, addr string) (net.Conn,
|
||||
return DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
localAddrInterface := registry.Get(network, tcpAddrLess)
|
||||
localAddrInterface := registry.Get(network, func(addr interface{}) bool {
|
||||
return addr.(*net.TCPAddr).IP.IsUnspecified()
|
||||
})
|
||||
if localAddrInterface == nil {
|
||||
// Nothing listening, nothing to reuse.
|
||||
return DialContext(ctx, network, addr)
|
||||
|
@ -9,7 +9,6 @@ package util
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@ -231,25 +230,6 @@ func Address(network, host string) string {
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// AddressUnspecifiedLess is a comparator function preferring least specific network address (most widely listening,
|
||||
// namely preferring 0.0.0.0 over some IP), if both IPs are equal, it prefers the less restrictive network (prefers tcp
|
||||
// over tcp4)
|
||||
func AddressUnspecifiedLess(a, b net.Addr) bool {
|
||||
aIsUnspecified := false
|
||||
bIsUnspecified := false
|
||||
if host, _, err := net.SplitHostPort(a.String()); err == nil {
|
||||
aIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(b.String()); err == nil {
|
||||
bIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
|
||||
}
|
||||
|
||||
if aIsUnspecified == bIsUnspecified {
|
||||
return len(a.Network()) < len(b.Network())
|
||||
}
|
||||
return aIsUnspecified
|
||||
}
|
||||
|
||||
func CallWithContext(ctx context.Context, fn func() error) error {
|
||||
var err error
|
||||
done := make(chan struct{})
|
||||
|
@ -225,50 +225,6 @@ func TestCopyMatching(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type mockedAddr struct {
|
||||
network string
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a mockedAddr) Network() string {
|
||||
return a.network
|
||||
}
|
||||
|
||||
func (a mockedAddr) String() string {
|
||||
return a.addr
|
||||
}
|
||||
|
||||
func TestInspecifiedAddressLess(t *testing.T) {
|
||||
cases := []struct {
|
||||
netA string
|
||||
addrA string
|
||||
netB string
|
||||
addrB string
|
||||
}{
|
||||
// B is assumed the winner.
|
||||
{"tcp", "127.0.0.1:1234", "tcp", ":1235"},
|
||||
{"tcp", "127.0.0.1:1234", "tcp", "0.0.0.0:1235"},
|
||||
{"tcp4", "0.0.0.0:1234", "tcp", "0.0.0.0:1235"}, // tcp4 on the first one
|
||||
}
|
||||
|
||||
for i, testCase := range cases {
|
||||
addrs := []mockedAddr{
|
||||
{testCase.netA, testCase.addrA},
|
||||
{testCase.netB, testCase.addrB},
|
||||
}
|
||||
|
||||
if AddressUnspecifiedLess(addrs[0], addrs[1]) {
|
||||
t.Error(i, "unexpected")
|
||||
}
|
||||
if !AddressUnspecifiedLess(addrs[1], addrs[0]) {
|
||||
t.Error(i, "unexpected")
|
||||
}
|
||||
if AddressUnspecifiedLess(addrs[0], addrs[0]) || AddressUnspecifiedLess(addrs[1], addrs[1]) {
|
||||
t.Error(i, "unexpected")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillNil(t *testing.T) {
|
||||
type A struct {
|
||||
Slice []int
|
||||
|
Loading…
Reference in New Issue
Block a user