diff --git a/go.mod b/go.mod index 26943e6d5..db383a235 100644 --- a/go.mod +++ b/go.mod @@ -46,7 +46,6 @@ require ( gopkg.in/asn1-ber.v1 v1.0.0-20170511165959-379148ca0225 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/ldap.v2 v2.5.1 - gopkg.in/yaml.v2 v2.2.2 // indirect ) go 1.12 diff --git a/go.sum b/go.sum index 2f223a324..dba1af70c 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,7 @@ github.com/AudriusButkevicius/pfilter v0.0.0-20190627213056-c55ef6137fc6/go.mod github.com/AudriusButkevicius/recli v0.0.5 h1:xUa55PvWTHBm17T6RvjElRO3y5tALpdceH86vhzQ5wg= github.com/AudriusButkevicius/recli v0.0.5/go.mod h1:Q2E26yc6RvWWEz/TJ/goUp6yXvipYdJI096hpoaqsNs= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6 h1:fLjPD/aNc3UIOA6tDi6QXUemppXK3P9BI7mr2hd6gx8= github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -37,6 +38,7 @@ github.com/getsentry/raven-go v0.2.0 h1:no+xWJRb5ZI7eE8TWgIq1jLulQiIoLG0IfYxv5JY github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-ole/go-ole v1.2.1 h1:2lOsA72HgjxAuMlKpFiCbHTvu44PIVkZ5hqm3RSdI/E= github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gobwas/glob v0.0.0-20170212200151-51eb1ee00b6d h1:IngNQgbqr5ZOU0exk395Szrvkzes9Ilk1fmJfkw7d+M= diff --git a/lib/connections/quic_dial.go b/lib/connections/quic_dial.go index baf84069c..62aba5c0f 100644 --- a/lib/connections/quic_dial.go +++ b/lib/connections/quic_dial.go @@ -36,7 +36,7 @@ type quicDialer struct { tlsCfg *tls.Config } -func (d *quicDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) { +func (d *quicDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) { uri = fixupPort(uri, config.DefaultQUICPort) addr, err := net.ResolveUDPAddr("udp", uri.Host) diff --git a/lib/connections/service.go b/lib/connections/service.go index e861562b9..167b7da01 100644 --- a/lib/connections/service.go +++ b/lib/connections/service.go @@ -231,7 +231,7 @@ func (s *service) handle(stop chan struct{}) { continue } - c.SetDeadline(time.Now().Add(20 * time.Second)) + _ = c.SetDeadline(time.Now().Add(20 * time.Second)) hello, err := protocol.ExchangeHello(c, s.model.GetHello(remoteID)) if err != nil { if protocol.IsVersionMismatch(err) { @@ -255,7 +255,7 @@ func (s *service) handle(stop chan struct{}) { c.Close() continue } - c.SetDeadline(time.Time{}) + _ = c.SetDeadline(time.Time{}) // The Model will return an error for devices that we don't want to // have a connection with for whatever reason, for example unknown devices. @@ -850,8 +850,15 @@ func (s *service) dialParallel(deviceID protocol.DeviceID, dialTargets []dialTar wg.Add(1) go func(tgt dialTarget) { conn, err := tgt.Dial() - s.setConnectionStatus(tgt.addr, err) if err == nil { + // Closes the connection on error + err = s.validateIdentity(conn, deviceID) + } + s.setConnectionStatus(tgt.addr, err) + if err != nil { + l.Debugln("dialing", deviceID, tgt.uri, "error:", err) + } else { + l.Debugln("dialing", deviceID, tgt.uri, "success:", conn) res <- conn } wg.Done() @@ -884,3 +891,36 @@ func (s *service) dialParallel(deviceID protocol.DeviceID, dialTargets []dialTar } return internalConn{}, false } + +func (s *service) validateIdentity(c internalConn, expectedID protocol.DeviceID) error { + cs := c.ConnectionState() + + // We should have received exactly one certificate from the other + // side. If we didn't, they don't have a device ID and we drop the + // connection. + certs := cs.PeerCertificates + if cl := len(certs); cl != 1 { + l.Infof("Got peer certificate list of length %d != 1 from peer at %s; protocol error", cl, c) + c.Close() + return fmt.Errorf("expected 1 certificate, got %d", cl) + } + remoteCert := certs[0] + remoteID := protocol.NewDeviceID(remoteCert.Raw) + + // The device ID should not be that of ourselves. It can happen + // though, especially in the presence of NAT hairpinning, multiple + // clients between the same NAT gateway, and global discovery. + if remoteID == s.myID { + l.Infof("Connected to myself (%s) at %s - should not happen", remoteID, c) + c.Close() + return fmt.Errorf("connected to self") + } + + // We should see the expected device ID + if !remoteID.Equals(expectedID) { + c.Close() + return fmt.Errorf("unexpected device id, expected %s got %s", expectedID, remoteID) + } + + return nil +} diff --git a/lib/connections/structs.go b/lib/connections/structs.go index 00d5eaff3..a780b3602 100644 --- a/lib/connections/structs.go +++ b/lib/connections/structs.go @@ -215,11 +215,5 @@ type dialTarget struct { func (t dialTarget) Dial() (internalConn, error) { l.Debugln("dialing", t.deviceID, t.uri, "prio", t.priority) - conn, err := t.dialer.Dial(t.deviceID, t.uri) - if err != nil { - l.Debugln("dialing", t.deviceID, t.uri, "error:", err) - } else { - l.Debugln("dialing", t.deviceID, t.uri, "success:", conn) - } - return conn, err + return t.dialer.Dial(t.deviceID, t.uri) } diff --git a/lib/connections/tcp_dial.go b/lib/connections/tcp_dial.go index db083e45f..d1db852af 100644 --- a/lib/connections/tcp_dial.go +++ b/lib/connections/tcp_dial.go @@ -30,7 +30,7 @@ type tcpDialer struct { tlsCfg *tls.Config } -func (d *tcpDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) { +func (d *tcpDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) { uri = fixupPort(uri, config.DefaultTCPPort) conn, err := dialer.DialTimeout(uri.Scheme, uri.Host, 10*time.Second)