From 680b0b14db1b722f44010cba0d05e00cbf912e13 Mon Sep 17 00:00:00 2001 From: Simon Frei Date: Tue, 25 Feb 2020 21:18:31 +0100 Subject: [PATCH] lib/connections: Refactor status for testing (ref #6361) (#6362) --- lib/connections/connections_test.go | 38 +++++++++++++++++++++++++++++ lib/connections/service.go | 26 ++++++++++++++------ 2 files changed, 56 insertions(+), 8 deletions(-) diff --git a/lib/connections/connections_test.go b/lib/connections/connections_test.go index f42f18390..8aa8cfcae 100644 --- a/lib/connections/connections_test.go +++ b/lib/connections/connections_test.go @@ -7,6 +7,8 @@ package connections import ( + "context" + "errors" "net/url" "testing" @@ -167,3 +169,39 @@ func TestGetDialer(t *testing.T) { } } } + +func TestConnectionStatus(t *testing.T) { + s := newConnectionStatusHandler() + + addr := "testAddr" + testErr := errors.New("testErr") + + if stats := s.ConnectionStatus(); len(stats) != 0 { + t.Fatal("newly created connectionStatusHandler isn't empty:", len(stats)) + } + + check := func(in, out error) { + t.Helper() + s.setConnectionStatus(addr, in) + switch stat, ok := s.ConnectionStatus()[addr]; { + case !ok: + t.Fatal("entry missing") + case out == nil: + if stat.Error != nil { + t.Fatal("expected nil error, got", stat.Error) + } + case *stat.Error != out.Error(): + t.Fatalf("expected %v error, got %v", out.Error(), *stat.Error) + } + } + + check(nil, nil) + + check(context.Canceled, nil) + + check(testErr, testErr) + + check(context.Canceled, testErr) + + check(nil, nil) +} diff --git a/lib/connections/service.go b/lib/connections/service.go index 68d03ca5d..2c9831d29 100644 --- a/lib/connections/service.go +++ b/lib/connections/service.go @@ -110,6 +110,8 @@ type ConnectionStatusEntry struct { type service struct { *suture.Supervisor + connectionStatusHandler + cfg config.Wrapper myID protocol.DeviceID model Model @@ -127,9 +129,6 @@ type service struct { listeners map[string]genericListener listenerTokens map[string]suture.ServiceToken listenerSupervisor *suture.Supervisor - - connectionStatusMut sync.RWMutex - connectionStatus map[string]ConnectionStatusEntry // address -> latest error/status } func NewService(cfg config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder, bepProtocolName string, tlsDefaultCommonName string, evLogger events.Logger) Service { @@ -140,6 +139,8 @@ func NewService(cfg config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *t }, PassThroughPanics: true, }), + connectionStatusHandler: newConnectionStatusHandler(), + cfg: cfg, myID: myID, model: mdl, @@ -168,9 +169,6 @@ func NewService(cfg config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *t FailureBackoff: 600 * time.Second, PassThroughPanics: true, }), - - connectionStatusMut: sync.NewRWMutex(), - connectionStatus: make(map[string]ConnectionStatusEntry), } cfg.Subscribe(service) @@ -702,7 +700,19 @@ func (s *service) ListenerStatus() map[string]ListenerStatusEntry { return result } -func (s *service) ConnectionStatus() map[string]ConnectionStatusEntry { +type connectionStatusHandler struct { + connectionStatusMut sync.RWMutex + connectionStatus map[string]ConnectionStatusEntry // address -> latest error/status +} + +func newConnectionStatusHandler() connectionStatusHandler { + return connectionStatusHandler{ + connectionStatusMut: sync.NewRWMutex(), + connectionStatus: make(map[string]ConnectionStatusEntry), + } +} + +func (s *connectionStatusHandler) ConnectionStatus() map[string]ConnectionStatusEntry { result := make(map[string]ConnectionStatusEntry) s.connectionStatusMut.RLock() for k, v := range s.connectionStatus { @@ -712,7 +722,7 @@ func (s *service) ConnectionStatus() map[string]ConnectionStatusEntry { return result } -func (s *service) setConnectionStatus(address string, err error) { +func (s *connectionStatusHandler) setConnectionStatus(address string, err error) { if errors.Cause(err) == context.Canceled { return }