mirror of
https://github.com/octoleo/syncthing.git
synced 2024-11-09 23:00:58 +00:00
lib/connections: Fix and optimize registry (#7996)
Registry.Get used a full sort to get the minimum of a list, and the sort was broken because util.AddressUnspecifiedLess assumed it could find out whether an address is IPv4 or IPv6 from its Network method. However, net.(TCP|UDP)Addr.Network always returns "tcp"/"udp".
This commit is contained in:
parent
c94b797f00
commit
7c292cc812
@ -58,7 +58,8 @@ func (d *quicDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL
|
|||||||
// Given we always pass the connection to quic, it assumes it's a remote connection it never closes it,
|
// Given we always pass the connection to quic, it assumes it's a remote connection it never closes it,
|
||||||
// So our wrapper around it needs to close it, but it only needs to close it if it's not the listening connection.
|
// So our wrapper around it needs to close it, but it only needs to close it if it's not the listening connection.
|
||||||
var createdConn net.PacketConn
|
var createdConn net.PacketConn
|
||||||
if listenConn := registry.Get(uri.Scheme, packetConnLess); listenConn != nil {
|
listenConn := registry.Get(uri.Scheme, packetConnUnspecified)
|
||||||
|
if listenConn != nil {
|
||||||
conn = listenConn.(net.PacketConn)
|
conn = listenConn.(net.PacketConn)
|
||||||
} else {
|
} else {
|
||||||
if packetConn, err := net.ListenPacket("udp", ":0"); err != nil {
|
if packetConn, err := net.ListenPacket("udp", ":0"); err != nil {
|
||||||
|
@ -15,7 +15,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
"github.com/syncthing/syncthing/lib/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -63,7 +62,10 @@ func (q *quicTlsConn) ConnectionState() tls.ConnectionState {
|
|||||||
return q.Session.ConnectionState().TLS.ConnectionState
|
return q.Session.ConnectionState().TLS.ConnectionState
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort available packet connections by ip address, preferring unspecified local address.
|
func packetConnUnspecified(conn interface{}) bool {
|
||||||
func packetConnLess(i interface{}, j interface{}) bool {
|
// Since QUIC connections are wrapped, we can't do a simple typecheck
|
||||||
return util.AddressUnspecifiedLess(i.(net.PacketConn).LocalAddr(), j.(net.PacketConn).LocalAddr())
|
// on *net.UDPAddr here.
|
||||||
|
addr := conn.(net.PacketConn).LocalAddr()
|
||||||
|
host, _, err := net.SplitHostPort(addr.String())
|
||||||
|
return err == nil && net.ParseIP(host).IsUnspecified()
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,6 @@
|
|||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/syncthing/syncthing/lib/sync"
|
"github.com/syncthing/syncthing/lib/sync"
|
||||||
@ -46,7 +45,7 @@ func (r *Registry) Unregister(scheme string, item interface{}) {
|
|||||||
candidates := r.available[scheme]
|
candidates := r.available[scheme]
|
||||||
for i, existingItem := range candidates {
|
for i, existingItem := range candidates {
|
||||||
if existingItem == item {
|
if existingItem == item {
|
||||||
copy(candidates[i:], candidates[i+1:])
|
candidates[i] = candidates[len(candidates)-1]
|
||||||
candidates[len(candidates)-1] = nil
|
candidates[len(candidates)-1] = nil
|
||||||
r.available[scheme] = candidates[:len(candidates)-1]
|
r.available[scheme] = candidates[:len(candidates)-1]
|
||||||
break
|
break
|
||||||
@ -54,26 +53,37 @@ func (r *Registry) Unregister(scheme string, item interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Registry) Get(scheme string, less func(i, j interface{}) bool) interface{} {
|
// Get returns an item for a schema compatible with the given scheme.
|
||||||
|
// If any item satisfies preferred, that has precedence over other items.
|
||||||
|
func (r *Registry) Get(scheme string, preferred func(interface{}) bool) interface{} {
|
||||||
r.mut.Lock()
|
r.mut.Lock()
|
||||||
defer r.mut.Unlock()
|
defer r.mut.Unlock()
|
||||||
|
|
||||||
candidates := make([]interface{}, 0)
|
var (
|
||||||
|
best interface{}
|
||||||
|
bestPref bool
|
||||||
|
bestScheme string
|
||||||
|
)
|
||||||
for availableScheme, items := range r.available {
|
for availableScheme, items := range r.available {
|
||||||
// quic:// should be considered ok for both quic4:// and quic6://
|
// quic:// should be considered ok for both quic4:// and quic6://
|
||||||
if strings.HasPrefix(scheme, availableScheme) {
|
if !strings.HasPrefix(scheme, availableScheme) {
|
||||||
candidates = append(candidates, items...)
|
continue
|
||||||
|
}
|
||||||
|
for _, item := range items {
|
||||||
|
better := best == nil
|
||||||
|
pref := preferred(item)
|
||||||
|
if !better {
|
||||||
|
// In case of a tie, prefer "quic" to "quic[46]" etc.
|
||||||
|
better = pref &&
|
||||||
|
(!bestPref || len(availableScheme) < len(bestScheme))
|
||||||
|
}
|
||||||
|
if !better {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
best, bestPref, bestScheme = item, pref, availableScheme
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return best
|
||||||
if len(candidates) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(candidates, func(i, j int) bool {
|
|
||||||
return less(candidates[i], candidates[j])
|
|
||||||
})
|
|
||||||
return candidates[0]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Register(scheme string, item interface{}) {
|
func Register(scheme string, item interface{}) {
|
||||||
@ -84,6 +94,6 @@ func Unregister(scheme string, item interface{}) {
|
|||||||
Default.Unregister(scheme, item)
|
Default.Unregister(scheme, item)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Get(scheme string, less func(i, j interface{}) bool) interface{} {
|
func Get(scheme string, preferred func(interface{}) bool) interface{} {
|
||||||
return Default.Get(scheme, less)
|
return Default.Get(scheme, preferred)
|
||||||
}
|
}
|
||||||
|
@ -7,13 +7,18 @@
|
|||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRegistry(t *testing.T) {
|
func TestRegistry(t *testing.T) {
|
||||||
r := New()
|
r := New()
|
||||||
|
|
||||||
if res := r.Get("int", intLess); res != nil {
|
want := func(i int) func(interface{}) bool {
|
||||||
|
return func(x interface{}) bool { return x.(int) == i }
|
||||||
|
}
|
||||||
|
|
||||||
|
if res := r.Get("int", want(1)); res != nil {
|
||||||
t.Error("unexpected")
|
t.Error("unexpected")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,30 +29,28 @@ func TestRegistry(t *testing.T) {
|
|||||||
r.Register("int6", 6)
|
r.Register("int6", 6)
|
||||||
r.Register("int6", 66)
|
r.Register("int6", 66)
|
||||||
|
|
||||||
if res := r.Get("int", intLess).(int); res != 1 {
|
if res := r.Get("int", want(1)).(int); res != 1 {
|
||||||
t.Error("unexpected", res)
|
t.Error("unexpected", res)
|
||||||
}
|
}
|
||||||
|
|
||||||
// int is prefix of int4, so returns 1
|
// int is prefix of int4, so returns 1
|
||||||
if res := r.Get("int4", intLess).(int); res != 1 {
|
if res := r.Get("int4", want(1)).(int); res != 1 {
|
||||||
t.Error("unexpected", res)
|
t.Error("unexpected", res)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Unregister("int", 1)
|
r.Unregister("int", 1)
|
||||||
|
|
||||||
// Check that falls through to 11
|
if res := r.Get("int", want(1)).(int); res == 1 {
|
||||||
if res := r.Get("int", intLess).(int); res != 11 {
|
|
||||||
t.Error("unexpected", res)
|
t.Error("unexpected", res)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6 is smaller than 11 available in int.
|
if res := r.Get("int6", want(6)).(int); res != 6 {
|
||||||
if res := r.Get("int6", intLess).(int); res != 6 {
|
|
||||||
t.Error("unexpected", res)
|
t.Error("unexpected", res)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unregister 11, int should be impossible to find
|
// Unregister 11, int should be impossible to find
|
||||||
r.Unregister("int", 11)
|
r.Unregister("int", 11)
|
||||||
if res := r.Get("int", intLess); res != nil {
|
if res := r.Get("int", want(11)); res != nil {
|
||||||
t.Error("unexpected")
|
t.Error("unexpected")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,13 +62,35 @@ func TestRegistry(t *testing.T) {
|
|||||||
r.Register("int", 1)
|
r.Register("int", 1)
|
||||||
r.Unregister("int", 1)
|
r.Unregister("int", 1)
|
||||||
|
|
||||||
if res := r.Get("int4", intLess).(int); res != 1 {
|
if res := r.Get("int4", want(1)).(int); res != 1 {
|
||||||
t.Error("unexpected", res)
|
t.Error("unexpected", res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func intLess(i, j interface{}) bool {
|
func TestShortSchemeFirst(t *testing.T) {
|
||||||
iInt := i.(int)
|
r := New()
|
||||||
jInt := j.(int)
|
r.Register("foo", 0)
|
||||||
return iInt < jInt
|
r.Register("foobar", 1)
|
||||||
|
|
||||||
|
// If we don't care about the value, we should get the one with "foo".
|
||||||
|
res := r.Get("foo", func(interface{}) bool { return false })
|
||||||
|
if res != 0 {
|
||||||
|
t.Error("unexpected", res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGet(b *testing.B) {
|
||||||
|
r := New()
|
||||||
|
for _, addr := range []string{"192.168.1.1", "172.1.1.1", "10.1.1.1"} {
|
||||||
|
r.Register("tcp", &net.TCPAddr{IP: net.ParseIP(addr)})
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.Get("tcp", func(x interface{}) bool {
|
||||||
|
return x.(*net.TCPAddr).IP.IsUnspecified()
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/syncthing/syncthing/lib/util"
|
|
||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -61,11 +60,6 @@ func socksDialerFunction(u *url.URL, forward proxy.Dialer) (proxy.Dialer, error)
|
|||||||
return proxy.SOCKS5("tcp", u.Host, auth, forward)
|
return proxy.SOCKS5("tcp", u.Host, auth, forward)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort available addresses, preferring unspecified address.
|
|
||||||
func tcpAddrLess(i interface{}, j interface{}) bool {
|
|
||||||
return util.AddressUnspecifiedLess(i.(*net.TCPAddr), j.(*net.TCPAddr))
|
|
||||||
}
|
|
||||||
|
|
||||||
// dialerConn is needed because proxy dialed connections have RemoteAddr() pointing at the proxy,
|
// dialerConn is needed because proxy dialed connections have RemoteAddr() pointing at the proxy,
|
||||||
// which then screws up various things such as IsLAN checks, and "let's populate the relay invitation address from
|
// which then screws up various things such as IsLAN checks, and "let's populate the relay invitation address from
|
||||||
// existing connection" shenanigans.
|
// existing connection" shenanigans.
|
||||||
|
@ -110,7 +110,9 @@ func DialContextReusePort(ctx context.Context, network, addr string) (net.Conn,
|
|||||||
return DialContext(ctx, network, addr)
|
return DialContext(ctx, network, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
localAddrInterface := registry.Get(network, tcpAddrLess)
|
localAddrInterface := registry.Get(network, func(addr interface{}) bool {
|
||||||
|
return addr.(*net.TCPAddr).IP.IsUnspecified()
|
||||||
|
})
|
||||||
if localAddrInterface == nil {
|
if localAddrInterface == nil {
|
||||||
// Nothing listening, nothing to reuse.
|
// Nothing listening, nothing to reuse.
|
||||||
return DialContext(ctx, network, addr)
|
return DialContext(ctx, network, addr)
|
||||||
|
@ -9,7 +9,6 @@ package util
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -231,25 +230,6 @@ func Address(network, host string) string {
|
|||||||
return u.String()
|
return u.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddressUnspecifiedLess is a comparator function preferring least specific network address (most widely listening,
|
|
||||||
// namely preferring 0.0.0.0 over some IP), if both IPs are equal, it prefers the less restrictive network (prefers tcp
|
|
||||||
// over tcp4)
|
|
||||||
func AddressUnspecifiedLess(a, b net.Addr) bool {
|
|
||||||
aIsUnspecified := false
|
|
||||||
bIsUnspecified := false
|
|
||||||
if host, _, err := net.SplitHostPort(a.String()); err == nil {
|
|
||||||
aIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
|
|
||||||
}
|
|
||||||
if host, _, err := net.SplitHostPort(b.String()); err == nil {
|
|
||||||
bIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
|
|
||||||
}
|
|
||||||
|
|
||||||
if aIsUnspecified == bIsUnspecified {
|
|
||||||
return len(a.Network()) < len(b.Network())
|
|
||||||
}
|
|
||||||
return aIsUnspecified
|
|
||||||
}
|
|
||||||
|
|
||||||
func CallWithContext(ctx context.Context, fn func() error) error {
|
func CallWithContext(ctx context.Context, fn func() error) error {
|
||||||
var err error
|
var err error
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
@ -225,50 +225,6 @@ func TestCopyMatching(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockedAddr struct {
|
|
||||||
network string
|
|
||||||
addr string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a mockedAddr) Network() string {
|
|
||||||
return a.network
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a mockedAddr) String() string {
|
|
||||||
return a.addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInspecifiedAddressLess(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
netA string
|
|
||||||
addrA string
|
|
||||||
netB string
|
|
||||||
addrB string
|
|
||||||
}{
|
|
||||||
// B is assumed the winner.
|
|
||||||
{"tcp", "127.0.0.1:1234", "tcp", ":1235"},
|
|
||||||
{"tcp", "127.0.0.1:1234", "tcp", "0.0.0.0:1235"},
|
|
||||||
{"tcp4", "0.0.0.0:1234", "tcp", "0.0.0.0:1235"}, // tcp4 on the first one
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, testCase := range cases {
|
|
||||||
addrs := []mockedAddr{
|
|
||||||
{testCase.netA, testCase.addrA},
|
|
||||||
{testCase.netB, testCase.addrB},
|
|
||||||
}
|
|
||||||
|
|
||||||
if AddressUnspecifiedLess(addrs[0], addrs[1]) {
|
|
||||||
t.Error(i, "unexpected")
|
|
||||||
}
|
|
||||||
if !AddressUnspecifiedLess(addrs[1], addrs[0]) {
|
|
||||||
t.Error(i, "unexpected")
|
|
||||||
}
|
|
||||||
if AddressUnspecifiedLess(addrs[0], addrs[0]) || AddressUnspecifiedLess(addrs[1], addrs[1]) {
|
|
||||||
t.Error(i, "unexpected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFillNil(t *testing.T) {
|
func TestFillNil(t *testing.T) {
|
||||||
type A struct {
|
type A struct {
|
||||||
Slice []int
|
Slice []int
|
||||||
|
Loading…
Reference in New Issue
Block a user