lib/connections: Refactor connection loop (#7177)

This breaks out some methods from the connection loop to make it simpler
to manage and understand.

Some slight simplifications to remove the `seen` variable (we can filter
`nextDial` based on times are in the future or not, so we don't need to
track `seen`) and adding a minimum loop interval (5s) in case some
dialer goes haywire and requests a 0s redial interval or such.

Otherwise no significant behavioral changes.
This commit is contained in:
Jakob Borg 2020-12-21 16:40:13 +01:00 committed by GitHub
parent a744dee94c
commit 05f25e600e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,6 +10,7 @@ import (
"context"
"crypto/tls"
"fmt"
"math"
"net"
"net/url"
"sort"
@ -56,6 +57,9 @@ const (
perDeviceWarningIntv = 15 * time.Minute
tlsHandshakeTimeout = 10 * time.Second
minConnectionReplaceAge = 10 * time.Second
minConnectionLoopSleep = 5 * time.Second
stdConnectionLoopSleep = time.Minute
worstDialerPriority = math.MaxInt32
)
// From go/src/crypto/tls/cipher_suites.go
@ -342,159 +346,52 @@ func (s *service) handle(ctx context.Context) error {
}
func (s *service) connect(ctx context.Context) error {
nextDial := make(map[string]time.Time)
// Map of when to earliest dial each given device + address again
nextDialAt := make(map[string]time.Time)
// Used as delay for the first few connection attempts, increases
// exponentially
// Used as delay for the first few connection attempts (adjusted up to
// minConnectionLoopSleep), increased exponentially until it reaches
// stdConnectionLoopSleep, at which time the normal sleep mechanism
// kicks in.
initialRampup := time.Second
// Calculated from actual dialers reconnectInterval
var sleep time.Duration
for {
cfg := s.cfg.RawCopy()
bestDialerPriority := s.bestDialerPriority(cfg)
isInitialRampup := initialRampup < stdConnectionLoopSleep
bestDialerPrio := 1<<31 - 1 // worse prio won't build on 32 bit
for _, df := range dialers {
if df.Valid(cfg) != nil {
continue
}
if prio := df.Priority(); prio < bestDialerPrio {
bestDialerPrio = prio
}
l.Debugln("Connection loop")
if isInitialRampup {
l.Debugln("Connection loop in initial rampup")
}
l.Debugln("Reconnect loop")
// Used for consistency throughout this loop run, as time passes
// while we try connections etc.
now := time.Now()
var seen []string
for _, deviceCfg := range cfg.Devices {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// Attempt to dial all devices that are unconnected or can be connection-upgraded
s.dialDevices(ctx, now, cfg, bestDialerPriority, nextDialAt, isInitialRampup)
deviceID := deviceCfg.DeviceID
if deviceID == s.myID {
continue
}
if deviceCfg.Paused {
continue
}
ct, connected := s.model.Connection(deviceID)
if connected && ct.Priority() == bestDialerPrio {
// Things are already as good as they can get.
continue
}
var addrs []string
for _, addr := range deviceCfg.Addresses {
if addr == "dynamic" {
if s.discoverer != nil {
if t, err := s.discoverer.Lookup(ctx, deviceID); err == nil {
addrs = append(addrs, t...)
}
}
} else {
addrs = append(addrs, addr)
}
}
addrs = util.UniqueTrimmedStrings(addrs)
l.Debugln("Reconnect loop for", deviceID, addrs)
dialTargets := make([]dialTarget, 0)
for _, addr := range addrs {
// Use a special key that is more than just the address, as you might have two devices connected to the same relay
nextDialKey := deviceID.String() + "/" + addr
seen = append(seen, nextDialKey)
nextDialAt, ok := nextDial[nextDialKey]
if ok && initialRampup >= sleep && nextDialAt.After(now) {
l.Debugf("Not dialing %s via %v as sleep is %v, next dial is at %s and current time is %s", deviceID, addr, sleep, nextDialAt, now)
continue
}
// If we fail at any step before actually getting the dialer
// retry in a minute
nextDial[nextDialKey] = now.Add(time.Minute)
uri, err := url.Parse(addr)
if err != nil {
s.setConnectionStatus(addr, err)
l.Infof("Parsing dialer address %s: %v", addr, err)
continue
}
if len(deviceCfg.AllowedNetworks) > 0 {
if !IsAllowedNetwork(uri.Host, deviceCfg.AllowedNetworks) {
s.setConnectionStatus(addr, errors.New("network disallowed"))
l.Debugln("Network for", uri, "is disallowed")
continue
}
}
dialerFactory, err := getDialerFactory(cfg, uri)
if err != nil {
s.setConnectionStatus(addr, err)
}
if errors.Is(err, errUnsupported) {
l.Debugf("Dialer for %v: %v", uri, err)
continue
} else if err != nil {
l.Infof("Dialer for %v: %v", uri, err)
continue
}
priority := dialerFactory.Priority()
if connected && priority >= ct.Priority() {
l.Debugf("Not dialing using %s as priority is less than current connection (%d >= %d)", dialerFactory, dialerFactory.Priority(), ct.Priority())
continue
}
dialer := dialerFactory.New(s.cfg.Options(), s.tlsCfg)
nextDial[nextDialKey] = now.Add(dialer.RedialFrequency())
// For LAN addresses, increase the priority so that we
// try these first.
switch {
case dialerFactory.AlwaysWAN():
// Do nothing.
case s.isLANHost(uri.Host):
priority -= 1
}
dialTargets = append(dialTargets, dialTarget{
addr: addr,
dialer: dialer,
priority: priority,
deviceID: deviceID,
uri: uri,
})
}
conn, ok := s.dialParallel(ctx, deviceCfg.DeviceID, dialTargets)
if ok {
s.conns <- conn
}
}
nextDial, sleep = filterAndFindSleepDuration(nextDial, seen, now)
if initialRampup < sleep {
l.Debugln("initial rampup; sleep", initialRampup, "and update to", initialRampup*2)
var sleep time.Duration
if isInitialRampup {
// We are in the initial rampup time, so we slowly, statically
// increase the sleep time.
sleep = initialRampup
initialRampup *= 2
} else {
l.Debugln("sleep until next dial", sleep)
// The sleep time is until the next dial scheduled in nextDialAt,
// clamped by stdConnectionLoopSleep as we don't want to sleep too
// long (config changes might happen).
sleep = filterAndFindSleepDuration(nextDialAt, now)
}
// ... while making sure not to loop too quickly either.
if sleep < minConnectionLoopSleep {
sleep = minConnectionLoopSleep
}
l.Debugln("Next connection loop in", sleep)
select {
case <-time.After(sleep):
case <-ctx.Done():
@ -503,6 +400,145 @@ func (s *service) connect(ctx context.Context) error {
}
}
func (s *service) bestDialerPriority(cfg config.Configuration) int {
bestDialerPriority := worstDialerPriority
for _, df := range dialers {
if df.Valid(cfg) != nil {
continue
}
if prio := df.Priority(); prio < bestDialerPriority {
bestDialerPriority = prio
}
}
return bestDialerPriority
}
func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Configuration, bestDialerPriority int, nextDialAt map[string]time.Time, initial bool) {
for _, deviceCfg := range cfg.Devices {
// Don't attempt to connect to ourselves...
if deviceCfg.DeviceID == s.myID {
continue
}
// Don't attempt to connect to paused devices...
if deviceCfg.Paused {
continue
}
// See if we are already connected and, if so, what our cutoff is
// for dialer priority.
priorityCutoff := worstDialerPriority
connection, connected := s.model.Connection(deviceCfg.DeviceID)
if connected {
priorityCutoff = connection.Priority()
if bestDialerPriority >= priorityCutoff {
// Our best dialer is not any better than what we already
// have, so nothing to do here.
continue
}
}
dialTargets := s.resolveDialTargets(ctx, now, cfg, deviceCfg, nextDialAt, initial, priorityCutoff)
if conn, ok := s.dialParallel(ctx, deviceCfg.DeviceID, dialTargets); ok {
s.conns <- conn
}
}
}
func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg config.Configuration, deviceCfg config.DeviceConfiguration, nextDialAt map[string]time.Time, initial bool, priorityCutoff int) []dialTarget {
deviceID := deviceCfg.DeviceID
addrs := s.resolveDeviceAddrs(ctx, deviceCfg)
l.Debugln("Resolved device", deviceID, "addresses:", addrs)
dialTargets := make([]dialTarget, 0, len(addrs))
for _, addr := range addrs {
// Use a special key that is more than just the address, as you
// might have two devices connected to the same relay
nextDialKey := deviceID.String() + "/" + addr
when, ok := nextDialAt[nextDialKey]
if ok && !initial && when.After(now) {
l.Debugf("Not dialing %s via %v as it's not time yet", deviceID, addr)
continue
}
// If we fail at any step before actually getting the dialer
// retry in a minute
nextDialAt[nextDialKey] = now.Add(time.Minute)
uri, err := url.Parse(addr)
if err != nil {
s.setConnectionStatus(addr, err)
l.Infof("Parsing dialer address %s: %v", addr, err)
continue
}
if len(deviceCfg.AllowedNetworks) > 0 {
if !IsAllowedNetwork(uri.Host, deviceCfg.AllowedNetworks) {
s.setConnectionStatus(addr, errors.New("network disallowed"))
l.Debugln("Network for", uri, "is disallowed")
continue
}
}
dialerFactory, err := getDialerFactory(cfg, uri)
if err != nil {
s.setConnectionStatus(addr, err)
}
if errors.Is(err, errUnsupported) {
l.Debugf("Dialer for %v: %v", uri, err)
continue
} else if err != nil {
l.Infof("Dialer for %v: %v", uri, err)
continue
}
priority := dialerFactory.Priority()
if priority >= priorityCutoff {
l.Debugf("Not dialing using %s as priority is not better than current connection (%d >= %d)", dialerFactory, dialerFactory.Priority(), priorityCutoff)
continue
}
dialer := dialerFactory.New(s.cfg.Options(), s.tlsCfg)
nextDialAt[nextDialKey] = now.Add(dialer.RedialFrequency())
// For LAN addresses, increase the priority so that we
// try these first.
switch {
case dialerFactory.AlwaysWAN():
// Do nothing.
case s.isLANHost(uri.Host):
priority--
}
dialTargets = append(dialTargets, dialTarget{
addr: addr,
dialer: dialer,
priority: priority,
deviceID: deviceID,
uri: uri,
})
}
return dialTargets
}
func (s *service) resolveDeviceAddrs(ctx context.Context, cfg config.DeviceConfiguration) []string {
var addrs []string
for _, addr := range cfg.Addresses {
if addr == "dynamic" {
if s.discoverer != nil {
if t, err := s.discoverer.Lookup(ctx, cfg.DeviceID); err == nil {
addrs = append(addrs, t...)
}
}
} else {
addrs = append(addrs, addr)
}
}
return util.UniqueTrimmedStrings(addrs)
}
func (s *service) isLANHost(host string) bool {
// Probably we are called with an ip:port combo which we can resolve as
// a TCP address.
@ -778,24 +814,19 @@ func getListenerFactory(cfg config.Configuration, uri *url.URL) (listenerFactory
return listenerFactory, nil
}
func filterAndFindSleepDuration(nextDial map[string]time.Time, seen []string, now time.Time) (map[string]time.Time, time.Duration) {
newNextDial := make(map[string]time.Time)
for _, addr := range seen {
nextDialAt, ok := nextDial[addr]
if ok {
newNextDial[addr] = nextDialAt
func filterAndFindSleepDuration(nextDialAt map[string]time.Time, now time.Time) time.Duration {
sleep := stdConnectionLoopSleep
for key, next := range nextDialAt {
if next.Before(now) {
// Expired entry, address was not seen in last pass(es)
delete(nextDialAt, key)
continue
}
if cur := next.Sub(now); cur < sleep {
sleep = cur
}
}
min := time.Minute
for _, next := range newNextDial {
cur := next.Sub(now)
if cur < min {
min = cur
}
}
return newNextDial, min
return sleep
}
func urlsToStrings(urls []*url.URL) []string {