// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). package client import ( "context" "crypto/tls" "errors" "fmt" "net" "net/url" "strconv" "time" "github.com/syncthing/syncthing/lib/dialer" syncthingprotocol "github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/relay/protocol" ) type incorrectResponseCodeErr struct { code int32 msg string } func (e incorrectResponseCodeErr) Error() string { return fmt.Sprintf("incorrect response code %d: %s", e.code, e.msg) } func GetInvitationFromRelay(ctx context.Context, uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate, timeout time.Duration) (protocol.SessionInvitation, error) { if uri.Scheme != "relay" { return protocol.SessionInvitation{}, fmt.Errorf("unsupported relay scheme: %v", uri.Scheme) } ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() rconn, err := dialer.DialContext(ctx, "tcp", uri.Host) if err != nil { return protocol.SessionInvitation{}, err } conn := tls.Client(rconn, configForCerts(certs)) conn.SetDeadline(time.Now().Add(timeout)) if err := performHandshakeAndValidation(conn, uri); err != nil { return protocol.SessionInvitation{}, err } defer conn.Close() request := protocol.ConnectRequest{ ID: id[:], } if err := protocol.WriteMessage(conn, request); err != nil { return protocol.SessionInvitation{}, err } message, err := protocol.ReadMessage(conn) if err != nil { return protocol.SessionInvitation{}, err } switch msg := message.(type) { case protocol.Response: return protocol.SessionInvitation{}, incorrectResponseCodeErr{msg.Code, msg.Message} case protocol.SessionInvitation: l.Debugln("Received invitation", msg, "via", conn.LocalAddr()) ip := net.IP(msg.Address) if len(ip) == 0 || ip.IsUnspecified() { msg.Address = remoteIPBytes(conn) } return msg, nil default: return protocol.SessionInvitation{}, fmt.Errorf("protocol error: unexpected message %v", msg) } } func JoinSession(ctx context.Context, invitation protocol.SessionInvitation) (net.Conn, error) { addr := net.JoinHostPort(net.IP(invitation.Address).String(), strconv.Itoa(int(invitation.Port))) ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() conn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { return nil, err } request := protocol.JoinSessionRequest{ Key: invitation.Key, } conn.SetDeadline(time.Now().Add(10 * time.Second)) err = protocol.WriteMessage(conn, request) if err != nil { return nil, err } message, err := protocol.ReadMessage(conn) if err != nil { return nil, err } conn.SetDeadline(time.Time{}) switch msg := message.(type) { case protocol.Response: if msg.Code != 0 { return nil, fmt.Errorf("incorrect response code %d: %s", msg.Code, msg.Message) } return conn, nil default: return nil, fmt.Errorf("protocol error: expecting response got %v", msg) } } func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) error { id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0]) invs := make(chan protocol.SessionInvitation, 1) c, err := NewClient(uri, certs, invs, timeout) if err != nil { close(invs) return fmt.Errorf("creating client: %w", err) } go c.Serve() defer func() { c.Stop() close(invs) }() for i := 0; i < times; i++ { _, err = GetInvitationFromRelay(ctx, uri, id, certs, timeout) if err == nil { return nil } if !errors.As(err, &incorrectResponseCodeErr{}) { return fmt.Errorf("getting invitation: %w", err) } time.Sleep(sleep) } return fmt.Errorf("getting invitation: %w", err) // last of the above errors } func configForCerts(certs []tls.Certificate) *tls.Config { return &tls.Config{ Certificates: certs, NextProtos: []string{protocol.ProtocolName}, ClientAuth: tls.RequestClientCert, SessionTicketsDisabled: true, InsecureSkipVerify: true, MinVersion: tls.VersionTLS12, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, }, } } func remoteIPBytes(conn net.Conn) []byte { addr := conn.RemoteAddr().String() if host, _, err := net.SplitHostPort(addr); err == nil { addr = host } return net.ParseIP(addr)[:] }