diff --git a/lib/connections/connections_test.go b/lib/connections/connections_test.go index b9a10d5ca..b54d52725 100644 --- a/lib/connections/connections_test.go +++ b/lib/connections/connections_test.go @@ -7,10 +7,12 @@ package connections import ( + "bytes" "context" "crypto/tls" "errors" "fmt" + "io" "io/ioutil" "math/rand" "net" @@ -296,7 +298,7 @@ func TestNextDialRegistryCleanup(t *testing.T) { } } -func BenchmarkConnections(pb *testing.B) { +func BenchmarkConnections(b *testing.B) { addrs := []string{ "tcp://127.0.0.1:0", "quic://127.0.0.1:0", @@ -317,9 +319,13 @@ func BenchmarkConnections(pb *testing.B) { } for _, addr := range addrs { for _, sz := range sizes { + data := make([]byte, sz) + if _, err := rand.Read(data); err != nil { + b.Fatal(err) + } for _, direction := range []string{"cs", "sc"} { proto := strings.SplitN(addr, ":", 2)[0] - pb.Run(fmt.Sprintf("%s_%d_%s", proto, sz, direction), func(b *testing.B) { + b.Run(fmt.Sprintf("%s_%d_%s", proto, sz, direction), func(b *testing.B) { if proto == "relay" && !haveRelay { b.Skip("could not connect to relay") } @@ -327,61 +333,81 @@ func BenchmarkConnections(pb *testing.B) { if direction == "sc" { server, client = client, server } - data := make([]byte, sz) - if _, err := rand.Read(data); err != nil { - b.Fatal(err) - } total := 0 - wg := sync.NewWaitGroup() b.ResetTimer() for i := 0; i < b.N; i++ { + wg := sync.NewWaitGroup() wg.Add(2) + errC := make(chan error, 2) go func() { - if err := sendMsg(client, data); err != nil { - b.Fatal(err) + wg.Wait() + close(errC) + }() + go func() { + if _, err := client.Write(data); err != nil { + errC <- err + return } wg.Done() }() go func() { - if err := recvMsg(server, data); err != nil { - b.Fatal(err) + if _, err := io.ReadFull(server, data); err != nil { + errC <- err + return } total += sz wg.Done() }() - wg.Wait() + err := <-errC + if err != nil { + b.Fatal(err) + } } b.ReportAllocs() b.SetBytes(int64(total / b.N)) }) }) } - } } } -func sendMsg(c internalConn, buf []byte) error { - n, err := c.Write(buf) - if n != len(buf) || err != nil { - return err +func TestConnectionEstablishment(t *testing.T) { + addrs := []string{ + "tcp://127.0.0.1:0", + "quic://127.0.0.1:0", + } + + send := make([]byte, 128<<10) + if _, err := rand.Read(send); err != nil { + t.Fatal(err) + } + + for _, addr := range addrs { + proto := strings.SplitN(addr, ":", 2)[0] + + t.Run(proto, func(t *testing.T) { + withConnectionPair(t, addr, func(client, server internalConn) { + if _, err := client.Write(send); err != nil { + t.Fatal(err) + } + + recv := make([]byte, len(send)) + if _, err := io.ReadFull(server, recv); err != nil { + t.Fatal(err) + } + + if !bytes.Equal(recv, send) { + t.Fatal("data mismatch") + } + }) + }) + } - return nil } -func recvMsg(c internalConn, buf []byte) error { - for read := 0; read != len(buf); { - n, err := c.Read(buf) - read += n - if err != nil { - return err - } - } - return nil -} - -func withConnectionPair(b *testing.B, connUri string, h func(client, server internalConn)) { +func withConnectionPair(b interface{ Fatal(...interface{}) }, connUri string, h func(client, server internalConn)) { // Root of the service tree. supervisor := suture.New("main", suture.Spec{ PassThroughPanics: true, @@ -449,19 +475,22 @@ func withConnectionPair(b *testing.B, connUri string, h func(client, server inte } } - data := []byte("hello") - // Quic does not start a stream until some data is sent through, so send something for the AcceptStream // to fire on the other side. - if err := sendMsg(clientConn, data); err != nil { + send := []byte("hello") + if _, err := clientConn.Write(send); err != nil { b.Fatal(err) } serverConn := <-conns - if err := recvMsg(serverConn, data); err != nil { + recv := make([]byte, len(send)) + if _, err := io.ReadFull(serverConn, recv); err != nil { b.Fatal(err) } + if !bytes.Equal(recv, send) { + b.Fatal("data mismatch") + } h(clientConn, serverConn) @@ -469,7 +498,7 @@ func withConnectionPair(b *testing.B, connUri string, h func(client, server inte _ = serverConn.Close() } -func mustGetCert(b *testing.B) tls.Certificate { +func mustGetCert(b interface{ Fatal(...interface{}) }) tls.Certificate { f1, err := ioutil.TempFile("", "") if err != nil { b.Fatal(err)