diff --git a/lib/connections/quic_dial.go b/lib/connections/quic_dial.go index 88ff5fa1d..baf84069c 100644 --- a/lib/connections/quic_dial.go +++ b/lib/connections/quic_dial.go @@ -45,23 +45,26 @@ func (d *quicDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, err } var conn net.PacketConn - closeConn := false + // We need to track who created the conn. + // Given we always pass the connection to quic, it assumes it's a remote connection it never closes it, + // So our wrapper around it needs to close it, but it only needs to close it if it's not the listening connection. + var createdConn net.PacketConn if listenConn := registry.Get(uri.Scheme, packetConnLess); listenConn != nil { conn = listenConn.(net.PacketConn) } else { if packetConn, err := net.ListenPacket("udp", ":0"); err != nil { return internalConn{}, err } else { - closeConn = true conn = packetConn + createdConn = packetConn } } ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) session, err := quic.DialContext(ctx, conn, addr, uri.Host, d.tlsCfg, quicConfig) if err != nil { - if closeConn { - _ = conn.Close() + if createdConn != nil { + _ = createdConn.Close() } return internalConn{}, err } @@ -85,13 +88,13 @@ func (d *quicDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, err if err != nil { // It's ok to close these, this does not close the underlying packetConn. _ = session.Close() - if closeConn { - _ = conn.Close() + if createdConn != nil { + _ = createdConn.Close() } return internalConn{}, err } - return internalConn{&quicTlsConn{session, stream}, connTypeQUICClient, quicPriority}, nil + return internalConn{&quicTlsConn{session, stream, createdConn}, connTypeQUICClient, quicPriority}, nil } func (d *quicDialer) RedialFrequency() time.Duration { diff --git a/lib/connections/quic_listen.go b/lib/connections/quic_listen.go index a192008ca..d81aabde9 100644 --- a/lib/connections/quic_listen.go +++ b/lib/connections/quic_listen.go @@ -59,7 +59,8 @@ func (t *quicListener) OnNATTypeChanged(natType stun.NATType) { func (t *quicListener) OnExternalAddressChanged(address *stun.Host, via string) { var uri *url.URL if address != nil { - uri = &(*t.uri) + copy := *t.uri + uri = © uri.Host = address.TransportAddr() } @@ -165,7 +166,7 @@ func (t *quicListener) Serve() { continue } - t.conns <- internalConn{&quicTlsConn{session, stream}, connTypeQUICServer, quicPriority} + t.conns <- internalConn{&quicTlsConn{session, stream, nil}, connTypeQUICServer, quicPriority} } } diff --git a/lib/connections/quic_misc.go b/lib/connections/quic_misc.go index 93c086008..37758f2e0 100644 --- a/lib/connections/quic_misc.go +++ b/lib/connections/quic_misc.go @@ -24,15 +24,24 @@ var ( type quicTlsConn struct { quic.Session quic.Stream + // If we created this connection, we should be the ones closing it. + createdConn net.PacketConn } func (q *quicTlsConn) Close() error { sterr := q.Stream.Close() seerr := q.Session.Close() + var pcerr error + if q.createdConn != nil { + pcerr = q.createdConn.Close() + } if sterr != nil { return sterr } - return seerr + if seerr != nil { + return seerr + } + return pcerr } // Sort available packet connections by ip address, preferring unspecified local address.