Use a single socket for relaying

This commit is contained in:
AudriusButkevicius 2015-09-02 21:35:52 +01:00
parent f407ff8861
commit 9b85a6fb7c
4 changed files with 79 additions and 111 deletions

View File

@ -11,6 +11,7 @@ import (
"time"
syncthingprotocol "github.com/syncthing/protocol"
"github.com/syncthing/syncthing/lib/tlsutil"
"github.com/syncthing/relaysrv/protocol"
)
@ -21,14 +22,16 @@ var (
numConnections int64
)
func protocolListener(addr string, config *tls.Config) {
listener, err := net.Listen("tcp", addr)
func listener(addr string, config *tls.Config) {
tcpListener, err := net.Listen("tcp", addr)
if err != nil {
log.Fatalln(err)
}
listener := tlsutil.DowngradingListener{tcpListener, nil}
for {
conn, err := listener.Accept()
conn, isTLS, err := listener.AcceptNoWrap()
if err != nil {
if debug {
log.Println(err)
@ -39,10 +42,15 @@ func protocolListener(addr string, config *tls.Config) {
setTCPOptions(conn)
if debug {
log.Println("Protocol listener accepted connection from", conn.RemoteAddr())
log.Println("Listener accepted connection from", conn.RemoteAddr(), "tls", isTLS)
}
if isTLS {
go protocolConnectionHandler(conn, config)
} else {
go sessionConnectionHandler(conn)
}
go protocolConnectionHandler(conn, config)
}
}
@ -220,6 +228,65 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
}
}
func sessionConnectionHandler(conn net.Conn) {
if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil {
if debug {
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
}
return
}
message, err := protocol.ReadMessage(conn)
if err != nil {
return
}
switch msg := message.(type) {
case protocol.JoinSessionRequest:
ses := findSession(string(msg.Key))
if debug {
log.Println(conn.RemoteAddr(), "session lookup", ses)
}
if ses == nil {
protocol.WriteMessage(conn, protocol.ResponseNotFound)
conn.Close()
return
}
if !ses.AddConnection(conn) {
if debug {
log.Println("Failed to add", conn.RemoteAddr(), "to session", ses)
}
protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
conn.Close()
return
}
if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil {
if debug {
log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses)
}
return
}
if err := conn.SetDeadline(time.Time{}); err != nil {
if debug {
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
}
conn.Close()
return
}
default:
if debug {
log.Println("Unexpected message from", conn.RemoteAddr(), message)
}
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
conn.Close()
}
}
func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) {
atomic.AddInt64(&numConnections, 1)
defer atomic.AddInt64(&numConnections, -1)

View File

@ -17,9 +17,8 @@ import (
)
var (
listenProtocol string
listenSession string
debug bool
listen string
debug bool
sessionAddress []byte
sessionPort uint16
@ -39,9 +38,7 @@ var (
func main() {
var dir, extAddress string
flag.StringVar(&listenProtocol, "protocol-listen", ":22067", "Protocol listen address")
flag.StringVar(&listenSession, "session-listen", ":22068", "Session listen address")
flag.StringVar(&extAddress, "external-address", "", "External address to advertise, defaults no IP and session-listen port, causing clients to use the remote IP from the protocol connection")
flag.StringVar(&listen, "listen", ":22067", "Protocol listen address")
flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored")
flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations")
flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent")
@ -54,7 +51,7 @@ func main() {
flag.Parse()
if extAddress == "" {
extAddress = listenSession
extAddress = listen
}
addr, err := net.ResolveTCPAddr("tcp", extAddress)
@ -100,11 +97,9 @@ func main() {
globalLimiter = ratelimit.NewBucketWithRate(float64(globalLimitBps), int64(2*globalLimitBps))
}
go sessionListener(listenSession)
if statusAddr != "" {
go statusService(statusAddr)
}
protocolListener(listenProtocol, tlsCfg)
listener(listen, tlsCfg)
}

View File

@ -7,8 +7,9 @@ package protocol
import (
"fmt"
syncthingprotocol "github.com/syncthing/protocol"
"net"
syncthingprotocol "github.com/syncthing/protocol"
)
const (

View File

@ -1,95 +0,0 @@
// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
package main
import (
"log"
"net"
"time"
"github.com/syncthing/relaysrv/protocol"
)
func sessionListener(addr string) {
listener, err := net.Listen("tcp", addr)
if err != nil {
log.Fatalln(err)
}
for {
conn, err := listener.Accept()
if err != nil {
if debug {
log.Println(err)
}
continue
}
setTCPOptions(conn)
if debug {
log.Println("Session listener accepted connection from", conn.RemoteAddr())
}
go sessionConnectionHandler(conn)
}
}
func sessionConnectionHandler(conn net.Conn) {
if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil {
if debug {
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
}
return
}
message, err := protocol.ReadMessage(conn)
if err != nil {
return
}
switch msg := message.(type) {
case protocol.JoinSessionRequest:
ses := findSession(string(msg.Key))
if debug {
log.Println(conn.RemoteAddr(), "session lookup", ses)
}
if ses == nil {
protocol.WriteMessage(conn, protocol.ResponseNotFound)
conn.Close()
return
}
if !ses.AddConnection(conn) {
if debug {
log.Println("Failed to add", conn.RemoteAddr(), "to session", ses)
}
protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
conn.Close()
return
}
if err := protocol.WriteMessage(conn, protocol.ResponseSuccess); err != nil {
if debug {
log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses)
}
return
}
if err := conn.SetDeadline(time.Time{}); err != nil {
if debug {
log.Println("Weird error setting deadline:", err, "on", conn.RemoteAddr())
}
conn.Close()
return
}
default:
if debug {
log.Println("Unexpected message from", conn.RemoteAddr(), message)
}
protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
conn.Close()
}
}