lib/connections: Refactor address listing into connection service

This commit is contained in:
Audrius Butkevicius 2016-03-25 07:35:18 +00:00 committed by Jakob Borg
parent 690837dbe5
commit 29913dd1e4
9 changed files with 173 additions and 202 deletions

View File

@ -1,127 +0,0 @@
// Copyright (C) 2015 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package main
import (
"fmt"
"net"
"net/url"
"github.com/syncthing/syncthing/lib/config"
)
type addressLister struct {
upnpService *upnpService
cfg *config.Wrapper
}
func newAddressLister(upnpService *upnpService, cfg *config.Wrapper) *addressLister {
return &addressLister{
upnpService: upnpService,
cfg: cfg,
}
}
// ExternalAddresses returns a list of addresses that are our best guess for
// where we are reachable from the outside. As a special case, we may return
// one or more addresses with an empty IP address (0.0.0.0 or ::) and just
// port number - this means that the outside address of a NAT gateway should
// be substituted.
func (e *addressLister) ExternalAddresses() []string {
return e.addresses(false)
}
// AllAddresses returns a list of addresses that are our best guess for where
// we are reachable from the local network. Same conditions as
// ExternalAddresses, but private IPv4 addresses are included.
func (e *addressLister) AllAddresses() []string {
return e.addresses(true)
}
func (e *addressLister) addresses(includePrivateIPV4 bool) []string {
var addrs []string
// Grab our listen addresses from the config. Unspecified ones are passed
// on verbatim (to be interpreted by a global discovery server or local
// discovery peer). Public addresses are passed on verbatim. Private
// addresses are filtered.
for _, addrStr := range e.cfg.Options().ListenAddress {
addrURL, err := url.Parse(addrStr)
if err != nil {
l.Infoln("Listen address", addrStr, "is invalid:", err)
continue
}
addr, err := net.ResolveTCPAddr("tcp", addrURL.Host)
if err != nil {
l.Infoln("Listen address", addrStr, "is invalid:", err)
continue
}
if addr.IP == nil || addr.IP.IsUnspecified() {
// Address like 0.0.0.0:22000 or [::]:22000 or :22000; include as is.
addrs = append(addrs, tcpAddr(addr.String()))
} else if isPublicIPv4(addr.IP) || isPublicIPv6(addr.IP) {
// A public address; include as is.
addrs = append(addrs, tcpAddr(addr.String()))
} else if includePrivateIPV4 && addr.IP.To4().IsGlobalUnicast() {
// A private IPv4 address.
addrs = append(addrs, tcpAddr(addr.String()))
}
}
// Get an external port mapping from the upnpService, if it has one. If so,
// add it as another unspecified address.
if e.upnpService != nil {
if port := e.upnpService.ExternalPort(); port != 0 {
addrs = append(addrs, fmt.Sprintf("tcp://:%d", port))
}
}
return addrs
}
func isPublicIPv4(ip net.IP) bool {
ip = ip.To4()
if ip == nil {
// Not an IPv4 address (IPv6)
return false
}
// IsGlobalUnicast below only checks that it's not link local or
// multicast, and we want to exclude private (NAT:ed) addresses as well.
rfc1918 := []net.IPNet{
{IP: net.IP{10, 0, 0, 0}, Mask: net.IPMask{255, 0, 0, 0}},
{IP: net.IP{172, 16, 0, 0}, Mask: net.IPMask{255, 240, 0, 0}},
{IP: net.IP{192, 168, 0, 0}, Mask: net.IPMask{255, 255, 0, 0}},
}
for _, n := range rfc1918 {
if n.Contains(ip) {
return false
}
}
return ip.IsGlobalUnicast()
}
func isPublicIPv6(ip net.IP) bool {
if ip.To4() != nil {
// Not an IPv6 address (IPv4)
// (To16() returns a v6 mapped v4 address so can't be used to check
// that it's an actual v6 address)
return false
}
return ip.IsGlobalUnicast()
}
func tcpAddr(host string) string {
u := url.URL{
Scheme: "tcp",
Host: host,
}
return u.String()
}

View File

@ -39,6 +39,7 @@ import (
"github.com/syncthing/syncthing/lib/sync" "github.com/syncthing/syncthing/lib/sync"
"github.com/syncthing/syncthing/lib/tlsutil" "github.com/syncthing/syncthing/lib/tlsutil"
"github.com/syncthing/syncthing/lib/upgrade" "github.com/syncthing/syncthing/lib/upgrade"
"github.com/syncthing/syncthing/lib/util"
"github.com/vitrun/qart/qr" "github.com/vitrun/qart/qr"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -733,7 +734,7 @@ func (s *apiService) postSystemConfig(w http.ResponseWriter, r *http.Request) {
if curAcc := s.cfg.Options().URAccepted; to.Options.URAccepted > curAcc { if curAcc := s.cfg.Options().URAccepted; to.Options.URAccepted > curAcc {
// UR was enabled // UR was enabled
to.Options.URAccepted = usageReportVersion to.Options.URAccepted = usageReportVersion
to.Options.URUniqueID = randomString(8) to.Options.URUniqueID = util.RandomString(8)
} else if to.Options.URAccepted < curAcc { } else if to.Options.URAccepted < curAcc {
// UR was disabled // UR was disabled
to.Options.URAccepted = -1 to.Options.URAccepted = -1

View File

@ -17,6 +17,7 @@ import (
"github.com/syncthing/syncthing/lib/config" "github.com/syncthing/syncthing/lib/config"
"github.com/syncthing/syncthing/lib/events" "github.com/syncthing/syncthing/lib/events"
"github.com/syncthing/syncthing/lib/sync" "github.com/syncthing/syncthing/lib/sync"
"github.com/syncthing/syncthing/lib/util"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -90,7 +91,7 @@ func basicAuthAndSessionMiddleware(cookieName string, cfg config.GUIConfiguratio
return return
} }
sessionid := randomString(32) sessionid := util.RandomString(32)
sessionsMut.Lock() sessionsMut.Lock()
sessions[sessionid] = true sessions[sessionid] = true
sessionsMut.Unlock() sessionsMut.Unlock()

View File

@ -16,6 +16,7 @@ import (
"github.com/syncthing/syncthing/lib/config" "github.com/syncthing/syncthing/lib/config"
"github.com/syncthing/syncthing/lib/osutil" "github.com/syncthing/syncthing/lib/osutil"
"github.com/syncthing/syncthing/lib/sync" "github.com/syncthing/syncthing/lib/sync"
"github.com/syncthing/syncthing/lib/util"
) )
// csrfTokens is a list of valid tokens. It is sorted so that the most // csrfTokens is a list of valid tokens. It is sorted so that the most
@ -97,7 +98,7 @@ func validCsrfToken(token string) bool {
} }
func newCsrfToken() string { func newCsrfToken() string {
token := randomString(32) token := util.RandomString(32)
csrfMut.Lock() csrfMut.Lock()
csrfTokens = append([]string{token}, csrfTokens...) csrfTokens = append([]string{token}, csrfTokens...)

View File

@ -44,6 +44,8 @@ import (
"github.com/syncthing/syncthing/lib/symlinks" "github.com/syncthing/syncthing/lib/symlinks"
"github.com/syncthing/syncthing/lib/tlsutil" "github.com/syncthing/syncthing/lib/tlsutil"
"github.com/syncthing/syncthing/lib/upgrade" "github.com/syncthing/syncthing/lib/upgrade"
"github.com/syncthing/syncthing/lib/upnp"
"github.com/syncthing/syncthing/lib/util"
"github.com/thejerf/suture" "github.com/thejerf/suture"
) )
@ -558,7 +560,7 @@ func syncthingMain(runtimeOptions RuntimeOptions) {
// We reinitialize the predictable RNG with our device ID, to get a // We reinitialize the predictable RNG with our device ID, to get a
// sequence that is always the same but unique to this syncthing instance. // sequence that is always the same but unique to this syncthing instance.
predictableRandom.Seed(seedFromBytes(cert.Certificate[0])) util.PredictableRandom.Seed(util.SeedFromBytes(cert.Certificate[0]))
myID = protocol.NewDeviceID(cert.Certificate[0]) myID = protocol.NewDeviceID(cert.Certificate[0])
l.SetPrefix(fmt.Sprintf("[%s] ", myID.String()[:5])) l.SetPrefix(fmt.Sprintf("[%s] ", myID.String()[:5]))
@ -720,21 +722,11 @@ func syncthingMain(runtimeOptions RuntimeOptions) {
l.Fatalln("Bad listen address:", err) l.Fatalln("Bad listen address:", err)
} }
// The externalAddr tracks our external addresses for discovery purposes.
var addrList *addressLister
// Start UPnP // Start UPnP
var upnpService *upnp.Service
if opts.UPnPEnabled { if opts.UPnPEnabled {
upnpService := newUPnPService(cfg, addr.Port) upnpService = upnp.NewUPnPService(cfg, addr.Port)
mainService.Add(upnpService) mainService.Add(upnpService)
// The external address tracker needs to know about the UPnP service
// so it can check for an external mapped port.
addrList = newAddressLister(upnpService, cfg)
} else {
addrList = newAddressLister(nil, cfg)
} }
// Start relay management // Start relay management
@ -750,10 +742,15 @@ func syncthingMain(runtimeOptions RuntimeOptions) {
cachedDiscovery := discover.NewCachingMux() cachedDiscovery := discover.NewCachingMux()
mainService.Add(cachedDiscovery) mainService.Add(cachedDiscovery)
// Start connection management
connectionService := connections.NewConnectionService(cfg, myID, m, tlsCfg, cachedDiscovery, upnpService, relayService, bepProtocolName, tlsDefaultCommonName, lans)
mainService.Add(connectionService)
if cfg.Options().GlobalAnnEnabled { if cfg.Options().GlobalAnnEnabled {
for _, srv := range cfg.GlobalDiscoveryServers() { for _, srv := range cfg.GlobalDiscoveryServers() {
l.Infoln("Using discovery server", srv) l.Infoln("Using discovery server", srv)
gd, err := discover.NewGlobal(srv, cert, addrList, relayService) gd, err := discover.NewGlobal(srv, cert, connectionService, relayService)
if err != nil { if err != nil {
l.Warnln("Global discovery:", err) l.Warnln("Global discovery:", err)
continue continue
@ -768,14 +765,14 @@ func syncthingMain(runtimeOptions RuntimeOptions) {
if cfg.Options().LocalAnnEnabled { if cfg.Options().LocalAnnEnabled {
// v4 broadcasts // v4 broadcasts
bcd, err := discover.NewLocal(myID, fmt.Sprintf(":%d", cfg.Options().LocalAnnPort), addrList, relayService) bcd, err := discover.NewLocal(myID, fmt.Sprintf(":%d", cfg.Options().LocalAnnPort), connectionService, relayService)
if err != nil { if err != nil {
l.Warnln("IPv4 local discovery:", err) l.Warnln("IPv4 local discovery:", err)
} else { } else {
cachedDiscovery.Add(bcd, 0, 0, ipv4LocalDiscoveryPriority) cachedDiscovery.Add(bcd, 0, 0, ipv4LocalDiscoveryPriority)
} }
// v6 multicasts // v6 multicasts
mcd, err := discover.NewLocal(myID, cfg.Options().LocalAnnMCAddr, addrList, relayService) mcd, err := discover.NewLocal(myID, cfg.Options().LocalAnnMCAddr, connectionService, relayService)
if err != nil { if err != nil {
l.Warnln("IPv6 local discovery:", err) l.Warnln("IPv6 local discovery:", err)
} else { } else {
@ -787,11 +784,6 @@ func syncthingMain(runtimeOptions RuntimeOptions) {
setupGUI(mainService, cfg, m, apiSub, cachedDiscovery, relayService, errors, systemLog, runtimeOptions) setupGUI(mainService, cfg, m, apiSub, cachedDiscovery, relayService, errors, systemLog, runtimeOptions)
// Start connection management
connectionService := connections.NewConnectionService(cfg, myID, m, tlsCfg, cachedDiscovery, relayService, bepProtocolName, tlsDefaultCommonName, lans)
mainService.Add(connectionService)
if runtimeOptions.cpuProfile { if runtimeOptions.cpuProfile {
f, err := os.Create(fmt.Sprintf("cpu-%d.pprof", os.Getpid())) f, err := os.Create(fmt.Sprintf("cpu-%d.pprof", os.Getpid()))
if err != nil { if err != nil {
@ -816,7 +808,7 @@ func syncthingMain(runtimeOptions RuntimeOptions) {
if opts.URUniqueID == "" { if opts.URUniqueID == "" {
// Previously the ID was generated from the node ID. We now need // Previously the ID was generated from the node ID. We now need
// to generate a new one. // to generate a new one.
opts.URUniqueID = randomString(8) opts.URUniqueID = util.RandomString(8)
cfg.SetOptions(opts) cfg.SetOptions(opts)
cfg.Save() cfg.Save()
} }

View File

@ -23,6 +23,7 @@ import (
"github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/relay" "github.com/syncthing/syncthing/lib/relay"
"github.com/syncthing/syncthing/lib/relay/client" "github.com/syncthing/syncthing/lib/relay/client"
"github.com/syncthing/syncthing/lib/upnp"
"github.com/thejerf/suture" "github.com/thejerf/suture"
) )
@ -42,9 +43,9 @@ type Model interface {
IsPaused(remoteID protocol.DeviceID) bool IsPaused(remoteID protocol.DeviceID) bool
} }
// The connection connectionService listens on TLS and dials configured unconnected // Service listens on TLS and dials configured unconnected devices. Successful
// devices. Successful connections are handed to the model. // connections are handed to the model.
type connectionService struct { type Service struct {
*suture.Supervisor *suture.Supervisor
cfg *config.Wrapper cfg *config.Wrapper
myID protocol.DeviceID myID protocol.DeviceID
@ -52,6 +53,7 @@ type connectionService struct {
tlsCfg *tls.Config tlsCfg *tls.Config
discoverer discover.Finder discoverer discover.Finder
conns chan model.IntermediateConnection conns chan model.IntermediateConnection
upnpService *upnp.Service
relayService relay.Service relayService relay.Service
bepProtocolName string bepProtocolName string
tlsDefaultCommonName string tlsDefaultCommonName string
@ -66,15 +68,16 @@ type connectionService struct {
relaysEnabled bool relaysEnabled bool
} }
func NewConnectionService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder, relayService relay.Service, func NewConnectionService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder, upnpService *upnp.Service,
bepProtocolName string, tlsDefaultCommonName string, lans []*net.IPNet) suture.Service { relayService relay.Service, bepProtocolName string, tlsDefaultCommonName string, lans []*net.IPNet) *Service {
service := &connectionService{ service := &Service{
Supervisor: suture.NewSimple("connectionService"), Supervisor: suture.NewSimple("connections.Service"),
cfg: cfg, cfg: cfg,
myID: myID, myID: myID,
model: mdl, model: mdl,
tlsCfg: tlsCfg, tlsCfg: tlsCfg,
discoverer: discoverer, discoverer: discoverer,
upnpService: upnpService,
relayService: relayService, relayService: relayService,
conns: make(chan model.IntermediateConnection), conns: make(chan model.IntermediateConnection),
bepProtocolName: bepProtocolName, bepProtocolName: bepProtocolName,
@ -100,7 +103,7 @@ func NewConnectionService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model
// to handle incoming connections, one routine to periodically attempt // to handle incoming connections, one routine to periodically attempt
// outgoing connections, one routine to the the common handling // outgoing connections, one routine to the the common handling
// regardless of whether the connection was incoming or outgoing. // regardless of whether the connection was incoming or outgoing.
// Furthermore, a relay connectionService which handles incoming requests to connect // Furthermore, a relay service which handles incoming requests to connect
// via the relays. // via the relays.
// //
// TODO: Clean shutdown, and/or handling config changes on the fly. We // TODO: Clean shutdown, and/or handling config changes on the fly. We
@ -137,7 +140,7 @@ func NewConnectionService(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model
return service return service
} }
func (s *connectionService) handle() { func (s *Service) handle() {
next: next:
for c := range s.conns { for c := range s.conns {
cs := c.Conn.ConnectionState() cs := c.Conn.ConnectionState()
@ -259,7 +262,7 @@ next:
} }
} }
func (s *connectionService) connect() { func (s *Service) connect() {
delay := time.Second delay := time.Second
for { for {
l.Debugln("Reconnect loop") l.Debugln("Reconnect loop")
@ -342,7 +345,7 @@ func (s *connectionService) connect() {
} }
} }
func (s *connectionService) resolveAddresses(deviceID protocol.DeviceID, inAddrs []string) (addrs []string, relays []discover.Relay) { func (s *Service) resolveAddresses(deviceID protocol.DeviceID, inAddrs []string) (addrs []string, relays []discover.Relay) {
for _, addr := range inAddrs { for _, addr := range inAddrs {
if addr == "dynamic" { if addr == "dynamic" {
if s.discoverer != nil { if s.discoverer != nil {
@ -358,7 +361,7 @@ func (s *connectionService) resolveAddresses(deviceID protocol.DeviceID, inAddrs
return return
} }
func (s *connectionService) connectDirect(deviceID protocol.DeviceID, addr string) *tls.Conn { func (s *Service) connectDirect(deviceID protocol.DeviceID, addr string) *tls.Conn {
uri, err := url.Parse(addr) uri, err := url.Parse(addr)
if err != nil { if err != nil {
l.Infoln("Failed to parse connection url:", addr, err) l.Infoln("Failed to parse connection url:", addr, err)
@ -381,7 +384,7 @@ func (s *connectionService) connectDirect(deviceID protocol.DeviceID, addr strin
return conn return conn
} }
func (s *connectionService) connectViaRelay(deviceID protocol.DeviceID, addr discover.Relay) *tls.Conn { func (s *Service) connectViaRelay(deviceID protocol.DeviceID, addr discover.Relay) *tls.Conn {
uri, err := url.Parse(addr.URL) uri, err := url.Parse(addr.URL)
if err != nil { if err != nil {
l.Infoln("Failed to parse relay connection url:", addr, err) l.Infoln("Failed to parse relay connection url:", addr, err)
@ -420,7 +423,7 @@ func (s *connectionService) connectViaRelay(deviceID protocol.DeviceID, addr dis
return tc return tc
} }
func (s *connectionService) acceptRelayConns() { func (s *Service) acceptRelayConns() {
for { for {
conn := s.relayService.Accept() conn := s.relayService.Accept()
s.conns <- model.IntermediateConnection{ s.conns <- model.IntermediateConnection{
@ -430,7 +433,7 @@ func (s *connectionService) acceptRelayConns() {
} }
} }
func (s *connectionService) shouldLimit(addr net.Addr) bool { func (s *Service) shouldLimit(addr net.Addr) bool {
if s.cfg.Options().LimitBandwidthInLan { if s.cfg.Options().LimitBandwidthInLan {
return true return true
} }
@ -447,11 +450,11 @@ func (s *connectionService) shouldLimit(addr net.Addr) bool {
return !tcpaddr.IP.IsLoopback() return !tcpaddr.IP.IsLoopback()
} }
func (s *connectionService) VerifyConfiguration(from, to config.Configuration) error { func (s *Service) VerifyConfiguration(from, to config.Configuration) error {
return nil return nil
} }
func (s *connectionService) CommitConfiguration(from, to config.Configuration) bool { func (s *Service) CommitConfiguration(from, to config.Configuration) bool {
s.mut.Lock() s.mut.Lock()
s.relaysEnabled = to.Options.RelaysEnabled s.relaysEnabled = to.Options.RelaysEnabled
s.mut.Unlock() s.mut.Unlock()
@ -472,6 +475,106 @@ func (s *connectionService) CommitConfiguration(from, to config.Configuration) b
return true return true
} }
// ExternalAddresses returns a list of addresses that are our best guess for
// where we are reachable from the outside. As a special case, we may return
// one or more addresses with an empty IP address (0.0.0.0 or ::) and just
// port number - this means that the outside address of a NAT gateway should
// be substituted.
func (s *Service) ExternalAddresses() []string {
return s.addresses(false)
}
// AllAddresses returns a list of addresses that are our best guess for where
// we are reachable from the local network. Same conditions as
// ExternalAddresses, but private IPv4 addresses are included.
func (s *Service) AllAddresses() []string {
return s.addresses(true)
}
func (s *Service) addresses(includePrivateIPV4 bool) []string {
var addrs []string
// Grab our listen addresses from the config. Unspecified ones are passed
// on verbatim (to be interpreted by a global discovery server or local
// discovery peer). Public addresses are passed on verbatim. Private
// addresses are filtered.
for _, addrStr := range s.cfg.Options().ListenAddress {
addrURL, err := url.Parse(addrStr)
if err != nil {
l.Infoln("Listen address", addrStr, "is invalid:", err)
continue
}
addr, err := net.ResolveTCPAddr("tcp", addrURL.Host)
if err != nil {
l.Infoln("Listen address", addrStr, "is invalid:", err)
continue
}
if addr.IP == nil || addr.IP.IsUnspecified() {
// Address like 0.0.0.0:22000 or [::]:22000 or :22000; include as is.
addrs = append(addrs, tcpAddr(addr.String()))
} else if isPublicIPv4(addr.IP) || isPublicIPv6(addr.IP) {
// A public address; include as is.
addrs = append(addrs, tcpAddr(addr.String()))
} else if includePrivateIPV4 && addr.IP.To4().IsGlobalUnicast() {
// A private IPv4 address.
addrs = append(addrs, tcpAddr(addr.String()))
}
}
// Get an external port mapping from the upnpService, if it has one. If so,
// add it as another unspecified address.
if s.upnpService != nil {
if port := s.upnpService.ExternalPort(); port != 0 {
addrs = append(addrs, fmt.Sprintf("tcp://:%d", port))
}
}
return addrs
}
func isPublicIPv4(ip net.IP) bool {
ip = ip.To4()
if ip == nil {
// Not an IPv4 address (IPv6)
return false
}
// IsGlobalUnicast below only checks that it's not link local or
// multicast, and we want to exclude private (NAT:ed) addresses as well.
rfc1918 := []net.IPNet{
{IP: net.IP{10, 0, 0, 0}, Mask: net.IPMask{255, 0, 0, 0}},
{IP: net.IP{172, 16, 0, 0}, Mask: net.IPMask{255, 240, 0, 0}},
{IP: net.IP{192, 168, 0, 0}, Mask: net.IPMask{255, 255, 0, 0}},
}
for _, n := range rfc1918 {
if n.Contains(ip) {
return false
}
}
return ip.IsGlobalUnicast()
}
func isPublicIPv6(ip net.IP) bool {
if ip.To4() != nil {
// Not an IPv6 address (IPv4)
// (To16() returns a v6 mapped v4 address so can't be used to check
// that it's an actual v6 address)
return false
}
return ip.IsGlobalUnicast()
}
func tcpAddr(host string) string {
u := url.URL{
Scheme: "tcp",
Host: host,
}
return u.String()
}
// serviceFunc wraps a function to create a suture.Service without stop // serviceFunc wraps a function to create a suture.Service without stop
// functionality. // functionality.
type serviceFunc func() type serviceFunc func()

View File

@ -4,7 +4,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file, // License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/. // You can obtain one at http://mozilla.org/MPL/2.0/.
package main package upnp
import ( import (
"fmt" "fmt"
@ -13,12 +13,12 @@ import (
"github.com/syncthing/syncthing/lib/config" "github.com/syncthing/syncthing/lib/config"
"github.com/syncthing/syncthing/lib/events" "github.com/syncthing/syncthing/lib/events"
"github.com/syncthing/syncthing/lib/sync" "github.com/syncthing/syncthing/lib/sync"
"github.com/syncthing/syncthing/lib/upnp" "github.com/syncthing/syncthing/lib/util"
) )
// The UPnP service runs a loop for discovery of IGDs (Internet Gateway // Service runs a loop for discovery of IGDs (Internet Gateway Devices) and
// Devices) and setup/renewal of a port mapping. // setup/renewal of a port mapping.
type upnpService struct { type Service struct {
cfg *config.Wrapper cfg *config.Wrapper
localPort int localPort int
extPort int extPort int
@ -26,20 +26,20 @@ type upnpService struct {
stop chan struct{} stop chan struct{}
} }
func newUPnPService(cfg *config.Wrapper, localPort int) *upnpService { func NewUPnPService(cfg *config.Wrapper, localPort int) *Service {
return &upnpService{ return &Service{
cfg: cfg, cfg: cfg,
localPort: localPort, localPort: localPort,
extPortMut: sync.NewMutex(), extPortMut: sync.NewMutex(),
} }
} }
func (s *upnpService) Serve() { func (s *Service) Serve() {
foundIGD := true foundIGD := true
s.stop = make(chan struct{}) s.stop = make(chan struct{})
for { for {
igds := upnp.Discover(time.Duration(s.cfg.Options().UPnPTimeoutS) * time.Second) igds := Discover(time.Duration(s.cfg.Options().UPnPTimeoutS) * time.Second)
if len(igds) > 0 { if len(igds) > 0 {
foundIGD = true foundIGD = true
s.extPortMut.Lock() s.extPortMut.Lock()
@ -72,18 +72,18 @@ func (s *upnpService) Serve() {
} }
} }
func (s *upnpService) Stop() { func (s *Service) Stop() {
close(s.stop) close(s.stop)
} }
func (s *upnpService) ExternalPort() int { func (s *Service) ExternalPort() int {
s.extPortMut.Lock() s.extPortMut.Lock()
port := s.extPort port := s.extPort
s.extPortMut.Unlock() s.extPortMut.Unlock()
return port return port
} }
func (s *upnpService) tryIGDs(igds []upnp.IGD, prevExtPort int) int { func (s *Service) tryIGDs(igds []IGD, prevExtPort int) int {
// Lets try all the IGDs we found and use the first one that works. // Lets try all the IGDs we found and use the first one that works.
// TODO: Use all of them, and sort out the resulting mess to the // TODO: Use all of them, and sort out the resulting mess to the
// discovery announcement code... // discovery announcement code...
@ -105,14 +105,14 @@ func (s *upnpService) tryIGDs(igds []upnp.IGD, prevExtPort int) int {
return 0 return 0
} }
func (s *upnpService) tryIGD(igd upnp.IGD, suggestedPort int) (int, error) { func (s *Service) tryIGD(igd IGD, suggestedPort int) (int, error) {
var err error var err error
leaseTime := s.cfg.Options().UPnPLeaseM * 60 leaseTime := s.cfg.Options().UPnPLeaseM * 60
if suggestedPort != 0 { if suggestedPort != 0 {
// First try renewing our existing mapping. // First try renewing our existing mapping.
name := fmt.Sprintf("syncthing-%d", suggestedPort) name := fmt.Sprintf("syncthing-%d", suggestedPort)
err = igd.AddPortMapping(upnp.TCP, suggestedPort, s.localPort, name, leaseTime) err = igd.AddPortMapping(TCP, suggestedPort, s.localPort, name, leaseTime)
if err == nil { if err == nil {
return suggestedPort, nil return suggestedPort, nil
} }
@ -120,9 +120,9 @@ func (s *upnpService) tryIGD(igd upnp.IGD, suggestedPort int) (int, error) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
// Then try up to ten random ports. // Then try up to ten random ports.
extPort := 1024 + predictableRandom.Intn(65535-1024) extPort := 1024 + util.PredictableRandom.Intn(65535-1024)
name := fmt.Sprintf("syncthing-%d", extPort) name := fmt.Sprintf("syncthing-%d", extPort)
err = igd.AddPortMapping(upnp.TCP, extPort, s.localPort, name, leaseTime) err = igd.AddPortMapping(TCP, extPort, s.localPort, name, leaseTime)
if err == nil { if err == nil {
return extPort, nil return extPort, nil
} }

View File

@ -4,7 +4,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file, // License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/. // You can obtain one at http://mozilla.org/MPL/2.0/.
package main package util
import ( import (
"crypto/md5" "crypto/md5"
@ -17,19 +17,19 @@ import (
// randomCharset contains the characters that can make up a randomString(). // randomCharset contains the characters that can make up a randomString().
const randomCharset = "01234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-" const randomCharset = "01234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-"
// predictableRandom is an RNG that will always have the same sequence. It // PredictableRandom is an RNG that will always have the same sequence. It
// will be seeded with the device ID during startup, so that the sequence is // will be seeded with the device ID during startup, so that the sequence is
// predictable but varies between instances. // predictable but varies between instances.
var predictableRandom = mathRand.New(mathRand.NewSource(42)) var PredictableRandom = mathRand.New(mathRand.NewSource(42))
func init() { func init() {
// The default RNG should be seeded with something good. // The default RNG should be seeded with something good.
mathRand.Seed(randomInt64()) mathRand.Seed(RandomInt64())
} }
// randomString returns a string of random characters (taken from // RandomString returns a string of random characters (taken from
// randomCharset) of the specified length. // randomCharset) of the specified length.
func randomString(l int) string { func RandomString(l int) string {
bs := make([]byte, l) bs := make([]byte, l)
for i := range bs { for i := range bs {
bs[i] = randomCharset[mathRand.Intn(len(randomCharset))] bs[i] = randomCharset[mathRand.Intn(len(randomCharset))]
@ -37,19 +37,19 @@ func randomString(l int) string {
return string(bs) return string(bs)
} }
// randomInt64 returns a strongly random int64, slowly // RandomInt64 returns a strongly random int64, slowly
func randomInt64() int64 { func RandomInt64() int64 {
var bs [8]byte var bs [8]byte
_, err := io.ReadFull(cryptoRand.Reader, bs[:]) _, err := io.ReadFull(cryptoRand.Reader, bs[:])
if err != nil { if err != nil {
panic("randomness failure: " + err.Error()) panic("randomness failure: " + err.Error())
} }
return seedFromBytes(bs[:]) return SeedFromBytes(bs[:])
} }
// seedFromBytes calculates a weak 64 bit hash from the given byte slice, // SeedFromBytes calculates a weak 64 bit hash from the given byte slice,
// suitable for use a predictable random seed. // suitable for use a predictable random seed.
func seedFromBytes(bs []byte) int64 { func SeedFromBytes(bs []byte) int64 {
h := md5.New() h := md5.New()
h.Write(bs) h.Write(bs)
s := h.Sum(nil) s := h.Sum(nil)

View File

@ -4,7 +4,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file, // License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/. // You can obtain one at http://mozilla.org/MPL/2.0/.
package main package util
import ( import (
"runtime" "runtime"
@ -21,7 +21,7 @@ func TestPredictableRandom(t *testing.T) {
predictableRandomTest.Do(func() { predictableRandomTest.Do(func() {
// predictable random sequence is predictable // predictable random sequence is predictable
e := int64(3440579354231278675) e := int64(3440579354231278675)
if v := int64(predictableRandom.Int()); v != e { if v := int64(PredictableRandom.Int()); v != e {
t.Errorf("Unexpected random value %d != %d", v, e) t.Errorf("Unexpected random value %d != %d", v, e)
} }
}) })
@ -38,7 +38,7 @@ func TestSeedFromBytes(t *testing.T) {
} }
for _, tc := range tcs { for _, tc := range tcs {
if v := seedFromBytes(tc.bs); v != tc.v { if v := SeedFromBytes(tc.bs); v != tc.v {
t.Errorf("Unexpected seed value %d != %d", v, tc.v) t.Errorf("Unexpected seed value %d != %d", v, tc.v)
} }
} }
@ -46,7 +46,7 @@ func TestSeedFromBytes(t *testing.T) {
func TestRandomString(t *testing.T) { func TestRandomString(t *testing.T) {
for _, l := range []int{0, 1, 2, 3, 4, 8, 42} { for _, l := range []int{0, 1, 2, 3, 4, 8, 42} {
s := randomString(l) s := RandomString(l)
if len(s) != l { if len(s) != l {
t.Errorf("Incorrect length %d != %d", len(s), l) t.Errorf("Incorrect length %d != %d", len(s), l)
} }
@ -54,7 +54,7 @@ func TestRandomString(t *testing.T) {
strings := make([]string, 1000) strings := make([]string, 1000)
for i := range strings { for i := range strings {
strings[i] = randomString(8) strings[i] = RandomString(8)
for j := range strings { for j := range strings {
if i == j { if i == j {
continue continue
@ -69,7 +69,7 @@ func TestRandomString(t *testing.T) {
func TestRandomInt64(t *testing.T) { func TestRandomInt64(t *testing.T) {
ints := make([]int64, 1000) ints := make([]int64, 1000)
for i := range ints { for i := range ints {
ints[i] = randomInt64() ints[i] = RandomInt64()
for j := range ints { for j := range ints {
if i == j { if i == j {
continue continue