cmd/strelaypoolsrv: Expose check error to client, fix incorrect response code handling

This commit is contained in:
Jakob Borg 2020-04-04 09:21:52 +02:00
parent 66262392c3
commit 362da59396
4 changed files with 25 additions and 16 deletions

View File

@ -10,7 +10,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -491,11 +490,11 @@ func handleRelayTest(request request) {
if debug { if debug {
log.Println("Request for", request.relay) log.Println("Request for", request.relay)
} }
if !client.TestRelay(context.TODO(), request.relay.uri, []tls.Certificate{testCert}, time.Second, 2*time.Second, 3) { if err := client.TestRelay(context.TODO(), request.relay.uri, []tls.Certificate{testCert}, time.Second, 2*time.Second, 3); err != nil {
if debug { if debug {
log.Println("Test for relay", request.relay, "failed") log.Println("Test for relay", request.relay, "failed:", err)
} }
request.result <- result{errors.New("connection test failed"), 0} request.result <- result{err, 0}
return return
} }

View File

@ -107,10 +107,10 @@ func main() {
connectToStdio(stdin, conn) connectToStdio(stdin, conn)
log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr()) log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr())
} else if test { } else if test {
if client.TestRelay(ctx, uri, []tls.Certificate{cert}, time.Second, 2*time.Second, 4) { if err := client.TestRelay(ctx, uri, []tls.Certificate{cert}, time.Second, 2*time.Second, 4); err == nil {
log.Println("OK") log.Println("OK")
} else { } else {
log.Println("FAIL") log.Println("FAIL:", err)
} }
} else { } else {
log.Fatal("Requires either join or connect") log.Fatal("Requires either join or connect")

View File

@ -5,11 +5,11 @@ package client
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/syncthing/syncthing/lib/dialer" "github.com/syncthing/syncthing/lib/dialer"
@ -17,6 +17,15 @@ import (
"github.com/syncthing/syncthing/lib/relay/protocol" "github.com/syncthing/syncthing/lib/relay/protocol"
) )
type incorrectResponseCodeErr struct {
code int32
msg string
}
func (e incorrectResponseCodeErr) Error() string {
return fmt.Sprintf("incorrect response code %d: %s", e.code, e.msg)
}
func GetInvitationFromRelay(ctx context.Context, uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate, timeout time.Duration) (protocol.SessionInvitation, error) { func GetInvitationFromRelay(ctx context.Context, uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate, timeout time.Duration) (protocol.SessionInvitation, error) {
if uri.Scheme != "relay" { if uri.Scheme != "relay" {
return protocol.SessionInvitation{}, fmt.Errorf("unsupported relay scheme: %v", uri.Scheme) return protocol.SessionInvitation{}, fmt.Errorf("unsupported relay scheme: %v", uri.Scheme)
@ -53,7 +62,7 @@ func GetInvitationFromRelay(ctx context.Context, uri *url.URL, id syncthingproto
switch msg := message.(type) { switch msg := message.(type) {
case protocol.Response: case protocol.Response:
return protocol.SessionInvitation{}, fmt.Errorf("incorrect response code %d: %s", msg.Code, msg.Message) return protocol.SessionInvitation{}, incorrectResponseCodeErr{msg.Code, msg.Message}
case protocol.SessionInvitation: case protocol.SessionInvitation:
l.Debugln("Received invitation", msg, "via", conn.LocalAddr()) l.Debugln("Received invitation", msg, "via", conn.LocalAddr())
ip := net.IP(msg.Address) ip := net.IP(msg.Address)
@ -104,13 +113,13 @@ func JoinSession(ctx context.Context, invitation protocol.SessionInvitation) (ne
} }
} }
func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) bool { func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) error {
id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0]) id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0])
invs := make(chan protocol.SessionInvitation, 1) invs := make(chan protocol.SessionInvitation, 1)
c, err := NewClient(uri, certs, invs, timeout) c, err := NewClient(uri, certs, invs, timeout)
if err != nil { if err != nil {
close(invs) close(invs)
return false return fmt.Errorf("creating client: %w", err)
} }
go c.Serve() go c.Serve()
defer func() { defer func() {
@ -119,16 +128,17 @@ func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep
}() }()
for i := 0; i < times; i++ { for i := 0; i < times; i++ {
_, err := GetInvitationFromRelay(ctx, uri, id, certs, timeout) _, err = GetInvitationFromRelay(ctx, uri, id, certs, timeout)
if err == nil { if err == nil {
return true return nil
} }
if !strings.Contains(err.Error(), "Incorrect response code") { if !errors.As(err, &incorrectResponseCodeErr{}) {
return false return fmt.Errorf("getting invitation: %w", err)
} }
time.Sleep(sleep) time.Sleep(sleep)
} }
return false
return fmt.Errorf("getting invitation: %w", err) // last of the above errors
} }
func configForCerts(certs []tls.Certificate) *tls.Config { func configForCerts(certs []tls.Certificate) *tls.Config {

View File

@ -201,7 +201,7 @@ func (c *staticClient) join() error {
switch msg := message.(type) { switch msg := message.(type) {
case protocol.Response: case protocol.Response:
if msg.Code != 0 { if msg.Code != 0 {
return fmt.Errorf("incorrect response code %d: %s", msg.Code, msg.Message) return incorrectResponseCodeErr{msg.Code, msg.Message}
} }
case protocol.RelayFull: case protocol.RelayFull: