From 4b6e7e7867047eb41e63b6798c401c10675d9cec Mon Sep 17 00:00:00 2001 From: Jakob Borg Date: Thu, 31 Aug 2017 07:34:48 +0000 Subject: [PATCH] lib/tlsutil: Remove undesired bufio from UnionedConnection (ref #4245) GitHub-Pull-Request: https://github.com/syncthing/syncthing/pull/4335 --- lib/tlsutil/empty_test.go | 10 ---- lib/tlsutil/tlsutil.go | 23 +++++--- lib/tlsutil/tlsutil_test.go | 110 ++++++++++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 18 deletions(-) delete mode 100644 lib/tlsutil/empty_test.go create mode 100644 lib/tlsutil/tlsutil_test.go diff --git a/lib/tlsutil/empty_test.go b/lib/tlsutil/empty_test.go deleted file mode 100644 index 8ba9da697..000000000 --- a/lib/tlsutil/empty_test.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (C) 2016 The Syncthing Authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at https://mozilla.org/MPL/2.0/. - -// The existence of this file means we get 0% test coverage rather than no -// test coverage at all. Remove when implementing an actual test. - -package tlsutil diff --git a/lib/tlsutil/tlsutil.go b/lib/tlsutil/tlsutil.go index 677de308c..3598b8773 100644 --- a/lib/tlsutil/tlsutil.go +++ b/lib/tlsutil/tlsutil.go @@ -7,7 +7,6 @@ package tlsutil import ( - "bufio" "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" @@ -16,7 +15,6 @@ import ( "crypto/x509/pkix" "encoding/pem" "fmt" - "io" "math/big" "net" "os" @@ -130,27 +128,36 @@ func (l *DowngradingListener) AcceptNoWrapTLS() (net.Conn, bool, error) { return nil, false, err } - br := bufio.NewReader(conn) + var first [1]byte conn.SetReadDeadline(time.Now().Add(1 * time.Second)) - bs, err := br.Peek(1) + n, err := conn.Read(first[:]) conn.SetReadDeadline(time.Time{}) - if err != nil { + if err != nil || n == 0 { // We hit a read error here, but the Accept() call succeeded so we must not return an error. // We return the connection as is with a special error which handles this // special case in Accept(). return conn, false, ErrIdentificationFailed } - return &UnionedConnection{br, conn}, bs[0] == 0x16, nil + return &UnionedConnection{&first, conn}, first[0] == 0x16, nil } type UnionedConnection struct { - io.Reader + first *[1]byte net.Conn } func (c *UnionedConnection) Read(b []byte) (n int, err error) { - return c.Reader.Read(b) + if c.first != nil { + if len(b) == 0 { + // this probably doesn't happen, but handle it anyway + return 0, nil + } + b[0] = c.first[0] + c.first = nil + return 1, nil + } + return c.Conn.Read(b) } func publicKey(priv interface{}) interface{} { diff --git a/lib/tlsutil/tlsutil_test.go b/lib/tlsutil/tlsutil_test.go new file mode 100644 index 000000000..e5630344f --- /dev/null +++ b/lib/tlsutil/tlsutil_test.go @@ -0,0 +1,110 @@ +// Copyright (C) 2016 The Syncthing Authors. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +// The existence of this file means we get 0% test coverage rather than no +// test coverage at all. Remove when implementing an actual test. + +package tlsutil + +import ( + "bytes" + "io" + "net" + "testing" + "time" +) + +func TestUnionedConnection(t *testing.T) { + cases := []struct { + data []byte + isTLS bool + }{ + {[]byte{0}, false}, + {[]byte{0x16}, true}, + {[]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, false}, + {[]byte{0x16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, true}, + } + + for i, tc := range cases { + fc := &fakeAccepter{tc.data} + dl := DowngradingListener{fc, nil} + + conn, isTLS, err := dl.AcceptNoWrapTLS() + if err != nil { + t.Fatalf("%d: %v", i, err) + } + if conn == nil { + t.Fatalf("%d: unexpected nil conn", i) + } + if isTLS != tc.isTLS { + t.Errorf("%d: isTLS=%v, expected %v", i, isTLS, tc.isTLS) + } + + // Read all the data, check it's the same + var bs []byte + buf := make([]byte, 128) + for { + n, err := conn.Read(buf) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("%d: read error: %v", i, err) + } + if len(bs) == 0 { + // first read; should return just one byte + if n != 1 { + t.Errorf("%d: first read returned %d bytes, not 1", i, n) + } + // Check that we've nilled out the "first" thing + if conn.(*UnionedConnection).first != nil { + t.Errorf("%d: expected first read to clear out the `first` attribute", i) + } + } + bs = append(bs, buf[:n]...) + } + if !bytes.Equal(bs, tc.data) { + t.Errorf("%d: got wrong data", i) + } + + t.Logf("%d: %v, %x", i, isTLS, bs) + } +} + +type fakeAccepter struct { + data []byte +} + +func (f *fakeAccepter) Accept() (net.Conn, error) { + return &fakeConn{f.data}, nil +} + +func (f *fakeAccepter) Addr() net.Addr { return nil } +func (f *fakeAccepter) Close() error { return nil } + +type fakeConn struct { + data []byte +} + +func (f *fakeConn) Read(b []byte) (int, error) { + if len(f.data) == 0 { + return 0, io.EOF + } + n := copy(b, f.data) + f.data = f.data[n:] + return n, nil +} + +func (f *fakeConn) Write(b []byte) (int, error) { + return len(b), nil +} + +func (f *fakeConn) Close() error { return nil } +func (f *fakeConn) LocalAddr() net.Addr { return nil } +func (f *fakeConn) RemoteAddr() net.Addr { return nil } +func (f *fakeConn) SetDeadline(time.Time) error { return nil } +func (f *fakeConn) SetReadDeadline(time.Time) error { return nil } +func (f *fakeConn) SetWriteDeadline(time.Time) error { return nil }