cmd/strelaysrv: Add uPNP support, ability to set listen protocol (fixes #3503, fixes #3505, fixes #3506)

This commit is contained in:
Audrius Butkevicius 2016-08-23 08:43:27 +02:00 committed by Jakob Borg
parent 1de787fab8
commit be38c2111f
3 changed files with 84 additions and 13 deletions

View File

@ -23,7 +23,7 @@ var (
numConnections int64 numConnections int64
) )
func listener(addr string, config *tls.Config) { func listener(proto, addr string, config *tls.Config) {
tcpListener, err := net.Listen("tcp", addr) tcpListener, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
@ -167,11 +167,17 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
continue continue
} }
peerOutbox <- serverInvitation select {
case peerOutbox <- serverInvitation:
if debug { if debug {
log.Println("Sent invitation from", id, "to", requestedPeer) log.Println("Sent invitation from", id, "to", requestedPeer)
} }
default:
if debug {
log.Println("Could not send invitation from", id, "to", requestedPeer, "as peer disconnected")
}
}
conn.Close() conn.Close()
case protocol.Ping: case protocol.Ping:

View File

@ -24,6 +24,11 @@ import (
"github.com/syncthing/syncthing/lib/relay/protocol" "github.com/syncthing/syncthing/lib/relay/protocol"
"github.com/syncthing/syncthing/lib/tlsutil" "github.com/syncthing/syncthing/lib/tlsutil"
"github.com/syncthing/syncthing/lib/config"
"github.com/syncthing/syncthing/lib/nat"
_ "github.com/syncthing/syncthing/lib/pmp"
_ "github.com/syncthing/syncthing/lib/upnp"
syncthingprotocol "github.com/syncthing/syncthing/lib/protocol" syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
) )
@ -48,6 +53,7 @@ func init() {
var ( var (
listen string listen string
debug bool debug bool
proto string
sessionAddress []byte sessionAddress []byte
sessionPort uint16 sessionPort uint16
@ -70,12 +76,17 @@ var (
pools []string pools []string
providedBy string providedBy string
defaultPoolAddrs = "https://relays.syncthing.net/endpoint" defaultPoolAddrs = "https://relays.syncthing.net/endpoint"
natEnabled bool
natLease int
natRenewal int
natTimeout int
) )
func main() { func main() {
log.SetFlags(log.Lshortfile | log.LstdFlags) log.SetFlags(log.Lshortfile | log.LstdFlags)
var dir, extAddress string var dir, extAddress, proto string
flag.StringVar(&listen, "listen", ":22067", "Protocol listen address") flag.StringVar(&listen, "listen", ":22067", "Protocol listen address")
flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored") flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored")
@ -89,14 +100,22 @@ func main() {
flag.StringVar(&poolAddrs, "pools", defaultPoolAddrs, "Comma separated list of relay pool addresses to join") flag.StringVar(&poolAddrs, "pools", defaultPoolAddrs, "Comma separated list of relay pool addresses to join")
flag.StringVar(&providedBy, "provided-by", "", "An optional description about who provides the relay") flag.StringVar(&providedBy, "provided-by", "", "An optional description about who provides the relay")
flag.StringVar(&extAddress, "ext-address", "", "An optional address to advertise as being available on.\n\tAllows listening on an unprivileged port with port forwarding from e.g. 443, and be connected to on port 443.") flag.StringVar(&extAddress, "ext-address", "", "An optional address to advertise as being available on.\n\tAllows listening on an unprivileged port with port forwarding from e.g. 443, and be connected to on port 443.")
flag.StringVar(&proto, "protocol", "tcp", "Protocol used for listening. 'tcp' for IPv4 and IPv6, 'tcp4' for IPv4, 'tcp6' for IPv6")
flag.BoolVar(&natEnabled, "nat", false, "Use UPnP/NAT-PMP to acquire external port mapping")
flag.IntVar(&natLease, "nat-lease", 60, "NAT lease length in minutes")
flag.IntVar(&natRenewal, "nat-renewal", 30, "NAT renewal frequency in minutes")
flag.IntVar(&natTimeout, "nat-timeout", 10, "NAT discovery timeout in seconds")
flag.Parse() flag.Parse()
if extAddress == "" { if extAddress == "" {
extAddress = listen extAddress = listen
} }
addr, err := net.ResolveTCPAddr("tcp", extAddress) if len(providedBy) > 30 {
log.Fatal("Provided-by cannot be longer than 30 characters")
}
addr, err := net.ResolveTCPAddr(proto, extAddress)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -149,6 +168,37 @@ func main() {
log.Println("ID:", id) log.Println("ID:", id)
} }
wrapper := config.Wrap("config", config.New(id))
wrapper.SetOptions(config.OptionsConfiguration{
NATLeaseM: natLease,
NATRenewalM: natRenewal,
NATTimeoutS: natTimeout,
})
natSvc := nat.NewService(id, wrapper)
mapping := mapping{natSvc.NewMapping(nat.TCP, addr.IP, addr.Port)}
if natEnabled {
go natSvc.Serve()
found := make(chan struct{})
mapping.OnChanged(func(_ *nat.Mapping, _, _ []nat.Address) {
select {
case found <- struct{}{}:
default:
}
})
// Need to wait a few extra seconds, since NAT library waits exactly natTimeout seconds on all interfaces.
timeout := time.Duration(natTimeout+2) * time.Second
log.Printf("Waiting %s to acquire NAT mapping", timeout)
select {
case <-found:
log.Printf("Found NAT mapping: %s", mapping.ExternalAddresses())
case <-time.After(timeout):
log.Println("Timeout out waiting for NAT mapping.")
}
}
if sessionLimitBps > 0 { if sessionLimitBps > 0 {
sessionLimiter = ratelimit.NewBucketWithRate(float64(sessionLimitBps), int64(2*sessionLimitBps)) sessionLimiter = ratelimit.NewBucketWithRate(float64(sessionLimitBps), int64(2*sessionLimitBps))
} }
@ -160,7 +210,7 @@ func main() {
go statusService(statusAddr) go statusService(statusAddr)
} }
uri, err := url.Parse(fmt.Sprintf("relay://%s/?id=%s&pingInterval=%s&networkTimeout=%s&sessionLimitBps=%d&globalLimitBps=%d&statusAddr=%s&providedBy=%s", extAddress, id, pingInterval, networkTimeout, sessionLimitBps, globalLimitBps, statusAddr, providedBy)) uri, err := url.Parse(fmt.Sprintf("relay://%s/?id=%s&pingInterval=%s&networkTimeout=%s&sessionLimitBps=%d&globalLimitBps=%d&statusAddr=%s&providedBy=%s", mapping.Address(), id, pingInterval, networkTimeout, sessionLimitBps, globalLimitBps, statusAddr, providedBy))
if err != nil { if err != nil {
log.Fatalln("Failed to construct URI", err) log.Fatalln("Failed to construct URI", err)
} }
@ -178,11 +228,11 @@ func main() {
for _, pool := range pools { for _, pool := range pools {
pool = strings.TrimSpace(pool) pool = strings.TrimSpace(pool)
if len(pool) > 0 { if len(pool) > 0 {
go poolHandler(pool, uri) go poolHandler(pool, uri, mapping)
} }
} }
go listener(listen, tlsCfg) go listener(proto, listen, tlsCfg)
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
@ -222,3 +272,15 @@ func monitorLimits() {
limitCheckTimer.Reset(time.Minute) limitCheckTimer.Reset(time.Minute)
} }
} }
type mapping struct {
*nat.Mapping
}
func (m *mapping) Address() nat.Address {
ext := m.ExternalAddresses()
if len(ext) > 0 {
return ext[0]
}
return m.Mapping.Address()
}

View File

@ -12,16 +12,19 @@ import (
"time" "time"
) )
func poolHandler(pool string, uri *url.URL) { func poolHandler(pool string, uri *url.URL, mapping mapping) {
if debug { if debug {
log.Println("Joining", pool) log.Println("Joining", pool)
} }
for { for {
uriCopy := *uri
uriCopy.Host = mapping.Address().String()
var b bytes.Buffer var b bytes.Buffer
json.NewEncoder(&b).Encode(struct { json.NewEncoder(&b).Encode(struct {
URL string `json:"url"` URL string `json:"url"`
}{ }{
uri.String(), uriCopy.String(),
}) })
resp, err := http.Post(pool, "application/json", &b) resp, err := http.Post(pool, "application/json", &b)
@ -39,7 +42,7 @@ func poolHandler(pool string, uri *url.URL) {
log.Println(pool, "under load, will retry in a minute") log.Println(pool, "under load, will retry in a minute")
time.Sleep(time.Minute) time.Sleep(time.Minute)
continue continue
} else if resp.StatusCode == 403 { } else if resp.StatusCode == 401 {
log.Println(pool, "failed to join due to IP address not matching external address. Aborting") log.Println(pool, "failed to join due to IP address not matching external address. Aborting")
return return
} else if resp.StatusCode == 200 { } else if resp.StatusCode == 200 {