diff --git a/lib/connections/connections_test.go b/lib/connections/connections_test.go index 8ddec00ca..72ff3cc52 100644 --- a/lib/connections/connections_test.go +++ b/lib/connections/connections_test.go @@ -8,12 +8,26 @@ package connections import ( "context" + "crypto/tls" "errors" + "fmt" + "io/ioutil" + "math/rand" + "net" "net/url" + "os" + "strings" "testing" + "time" + + "github.com/thejerf/suture/v4" "github.com/syncthing/syncthing/lib/config" + "github.com/syncthing/syncthing/lib/events" + "github.com/syncthing/syncthing/lib/nat" "github.com/syncthing/syncthing/lib/protocol" + "github.com/syncthing/syncthing/lib/sync" + "github.com/syncthing/syncthing/lib/tlsutil" ) func TestFixupPort(t *testing.T) { @@ -216,3 +230,196 @@ func TestConnectionStatus(t *testing.T) { check(nil, nil) } + +func BenchmarkConnections(pb *testing.B) { + addrs := []string{ + "tcp://127.0.0.1:0", + "quic://127.0.0.1:0", + "relay://127.0.0.1:22067", + } + sizes := []int{ + 1 << 10, + 1 << 15, + 1 << 20, + 1 << 22, + } + haveRelay := false + // Check if we have a relay running locally + conn, err := net.DialTimeout("tcp", "127.0.0.1:22067", 100*time.Millisecond) + if err == nil { + haveRelay = true + _ = conn.Close() + } + for _, addr := range addrs { + for _, sz := range sizes { + for _, direction := range []string{"cs", "sc"} { + proto := strings.SplitN(addr, ":", 2)[0] + pb.Run(fmt.Sprintf("%s_%d_%s", proto, sz, direction), func(b *testing.B) { + if proto == "relay" && !haveRelay { + b.Skip("could not connect to relay") + } + withConnectionPair(b, addr, func(client, server internalConn) { + if direction == "sc" { + server, client = client, server + } + data := make([]byte, sz) + if _, err := rand.Read(data); err != nil { + b.Fatal(err) + } + + total := 0 + wg := sync.NewWaitGroup() + b.ResetTimer() + for i := 0; i < b.N; i++ { + wg.Add(2) + go func() { + if err := sendMsg(client, data); err != nil { + b.Fatal(err) + } + wg.Done() + }() + go func() { + if err := recvMsg(server, data); err != nil { + b.Fatal(err) + } + total += sz + wg.Done() + }() + wg.Wait() + } + b.ReportAllocs() + b.SetBytes(int64(total / b.N)) + }) + }) + } + + } + } +} + +func sendMsg(c internalConn, buf []byte) error { + n, err := c.Write(buf) + if n != len(buf) || err != nil { + return err + } + return nil +} + +func recvMsg(c internalConn, buf []byte) error { + for read := 0; read != len(buf); { + n, err := c.Read(buf) + read += n + if err != nil { + return err + } + } + return nil +} + +func withConnectionPair(b *testing.B, connUri string, h func(client, server internalConn)) { + // Root of the service tree. + supervisor := suture.New("main", suture.Spec{ + PassThroughPanics: true, + }) + + cert := mustGetCert(b) + deviceId := protocol.NewDeviceID(cert.Certificate[0]) + tlsCfg := tlsutil.SecureDefaultTLS13() + tlsCfg.Certificates = []tls.Certificate{cert} + tlsCfg.NextProtos = []string{"bench"} + tlsCfg.ClientAuth = tls.RequestClientCert + tlsCfg.SessionTicketsDisabled = true + tlsCfg.InsecureSkipVerify = true + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + supervisor.ServeBackground(ctx) + + cfg := config.Configuration{ + Options: config.OptionsConfiguration{ + RelaysEnabled: true, + }, + } + wcfg := config.Wrap("", cfg, deviceId, events.NoopLogger) + uri, err := url.Parse(connUri) + if err != nil { + b.Fatal(err) + } + lf, err := getListenerFactory(cfg, uri) + if err != nil { + b.Fatal(err) + } + natSvc := nat.NewService(deviceId, wcfg) + conns := make(chan internalConn, 1) + listenSvc := lf.New(uri, wcfg, tlsCfg, conns, natSvc) + supervisor.Add(listenSvc) + + var addr *url.URL + for { + addrs := listenSvc.LANAddresses() + if len(addrs) > 0 { + if !strings.HasSuffix(addrs[0].Host, ":0") { + addr = addrs[0] + break + } + } + time.Sleep(time.Millisecond) + } + + df, err := getDialerFactory(cfg, addr) + if err != nil { + b.Fatal(err) + } + dialer := df.New(cfg.Options, tlsCfg) + + // Relays might take some time to register the device, so dial multiple times + clientConn, err := dialer.Dial(ctx, deviceId, addr) + if err != nil { + for i := 0; i < 10 && err != nil; i++ { + clientConn, err = dialer.Dial(ctx, deviceId, addr) + time.Sleep(100 * time.Millisecond) + } + if err != nil { + b.Fatal(err) + } + } + + data := []byte("hello") + + // Quic does not start a stream until some data is sent through, so send something for the AcceptStream + // to fire on the other side. + if err := sendMsg(clientConn, data); err != nil { + b.Fatal(err) + } + + serverConn := <-conns + + if err := recvMsg(serverConn, data); err != nil { + b.Fatal(err) + } + + h(clientConn, serverConn) + + _ = clientConn.Close() + _ = serverConn.Close() +} + +func mustGetCert(b *testing.B) tls.Certificate { + f1, err := ioutil.TempFile("", "") + if err != nil { + b.Fatal(err) + } + f1.Close() + f2, err := ioutil.TempFile("", "") + if err != nil { + b.Fatal(err) + } + f2.Close() + cert, err := tlsutil.NewCertificate(f1.Name(), f2.Name(), "bench", 10) + if err != nil { + b.Fatal(err) + } + _ = os.Remove(f1.Name()) + _ = os.Remove(f2.Name()) + return cert +} diff --git a/lib/connections/quic_listen.go b/lib/connections/quic_listen.go index 411368782..61c43e31d 100644 --- a/lib/connections/quic_listen.go +++ b/lib/connections/quic_listen.go @@ -48,6 +48,7 @@ type quicListener struct { factory listenerFactory address *url.URL + laddr net.Addr mut sync.Mutex } @@ -87,10 +88,8 @@ func (t *quicListener) serve(ctx context.Context) error { l.Infoln("Listen (BEP/quic):", err) return err } - defer func() { _ = packetConn.Close() }() svc, conn := stun.New(t.cfg, t, packetConn) - defer func() { _ = conn.Close() }() wrapped := &stunConnQUICWrapper{ PacketConn: conn, underlying: packetConn.(*net.UDPConn), @@ -99,7 +98,6 @@ func (t *quicListener) serve(ctx context.Context) error { go svc.Serve(ctx) registry.Register(t.uri.Scheme, wrapped) - defer registry.Unregister(t.uri.Scheme, wrapped) listener, err := quic.Listen(wrapped, t.tlsCfg, quicConfig) if err != nil { @@ -107,11 +105,23 @@ func (t *quicListener) serve(ctx context.Context) error { return err } t.notifyAddressesChanged(t) - defer listener.Close() - defer t.clearAddresses(t) l.Infof("QUIC listener (%v) starting", packetConn.LocalAddr()) - defer l.Infof("QUIC listener (%v) shutting down", packetConn.LocalAddr()) + t.mut.Lock() + t.laddr = packetConn.LocalAddr() + t.mut.Unlock() + + defer func() { + l.Infof("QUIC listener (%v) shutting down", packetConn.LocalAddr()) + t.mut.Lock() + t.laddr = nil + t.mut.Unlock() + registry.Unregister(t.uri.Scheme, wrapped) + t.clearAddresses(t) + _ = listener.Close() + _ = conn.Close() + _ = packetConn.Close() + }() acceptFailures := 0 const maxAcceptFailures = 10 @@ -164,8 +174,8 @@ func (t *quicListener) URI() *url.URL { } func (t *quicListener) WANAddresses() []*url.URL { - uris := []*url.URL{t.uri} t.mut.Lock() + uris := []*url.URL{maybeReplacePort(t.uri, t.laddr)} if t.address != nil { uris = append(uris, t.address) } @@ -174,9 +184,12 @@ func (t *quicListener) WANAddresses() []*url.URL { } func (t *quicListener) LANAddresses() []*url.URL { - addrs := []*url.URL{t.uri} - network := strings.ReplaceAll(t.uri.Scheme, "quic", "udp") - addrs = append(addrs, getURLsForAllAdaptersIfUnspecified(network, t.uri)...) + t.mut.Lock() + uri := maybeReplacePort(t.uri, t.laddr) + t.mut.Unlock() + addrs := []*url.URL{uri} + network := strings.ReplaceAll(uri.Scheme, "quic", "udp") + addrs = append(addrs, getURLsForAllAdaptersIfUnspecified(network, uri)...) return addrs } diff --git a/lib/connections/tcp_listen.go b/lib/connections/tcp_listen.go index 51dfd39bc..6fc3235e9 100644 --- a/lib/connections/tcp_listen.go +++ b/lib/connections/tcp_listen.go @@ -40,6 +40,7 @@ type tcpListener struct { natService *nat.Service mapping *nat.Mapping + laddr net.Addr mut sync.RWMutex } @@ -60,26 +61,36 @@ func (t *tcpListener) serve(ctx context.Context) error { l.Infoln("Listen (BEP/tcp):", err) return err } + + // We might bind to :0, so use the port we've been given. + tcaddr = listener.Addr().(*net.TCPAddr) + t.notifyAddressesChanged(t) registry.Register(t.uri.Scheme, tcaddr) - defer listener.Close() - defer t.clearAddresses(t) - defer registry.Unregister(t.uri.Scheme, tcaddr) - - l.Infof("TCP listener (%v) starting", listener.Addr()) - defer l.Infof("TCP listener (%v) shutting down", listener.Addr()) + l.Infof("TCP listener (%v) starting", tcaddr) mapping := t.natService.NewMapping(nat.TCP, tcaddr.IP, tcaddr.Port) mapping.OnChanged(func(_ *nat.Mapping, _, _ []nat.Address) { t.notifyAddressesChanged(t) }) - defer t.natService.RemoveMapping(mapping) t.mut.Lock() t.mapping = mapping + t.laddr = tcaddr t.mut.Unlock() + defer func() { + l.Infof("TCP listener (%v) shutting down", tcaddr) + t.natService.RemoveMapping(mapping) + t.mut.Lock() + t.laddr = nil + t.mut.Unlock() + registry.Unregister(t.uri.Scheme, tcaddr) + t.clearAddresses(t) + _ = listener.Close() + }() + acceptFailures := 0 const maxAcceptFailures = 10 @@ -146,8 +157,10 @@ func (t *tcpListener) URI() *url.URL { } func (t *tcpListener) WANAddresses() []*url.URL { - uris := []*url.URL{t.uri} t.mut.RLock() + uris := []*url.URL{ + maybeReplacePort(t.uri, t.laddr), + } if t.mapping != nil { addrs := t.mapping.ExternalAddresses() for _, addr := range addrs { @@ -179,8 +192,11 @@ func (t *tcpListener) WANAddresses() []*url.URL { } func (t *tcpListener) LANAddresses() []*url.URL { - addrs := []*url.URL{t.uri} - addrs = append(addrs, getURLsForAllAdaptersIfUnspecified(t.uri.Scheme, t.uri)...) + t.mut.RLock() + uri := maybeReplacePort(t.uri, t.laddr) + t.mut.RUnlock() + addrs := []*url.URL{uri} + addrs = append(addrs, getURLsForAllAdaptersIfUnspecified(uri.Scheme, uri)...) return addrs } diff --git a/lib/connections/util.go b/lib/connections/util.go index 1738c8c85..3dd5dc92c 100644 --- a/lib/connections/util.go +++ b/lib/connections/util.go @@ -117,3 +117,30 @@ func isV4Local(ip net.IP) bool { } return false } + +func maybeReplacePort(uri *url.URL, laddr net.Addr) *url.URL { + if laddr == nil { + return uri + } + + host, portStr, err := net.SplitHostPort(uri.Host) + if err != nil { + return uri + } + port, err := strconv.Atoi(portStr) + if err != nil { + return uri + } + if port != 0 { + return uri + } + + _, lportStr, err := net.SplitHostPort(laddr.String()) + if err != nil { + return uri + } + + uriCopy := *uri + uriCopy.Host = net.JoinHostPort(host, lportStr) + return &uriCopy +}