diff --git a/lib/discover/global.go b/lib/discover/global.go index d16acc147..f3b180edd 100644 --- a/lib/discover/global.go +++ b/lib/discover/global.go @@ -42,6 +42,7 @@ type httpClient interface { const ( defaultReannounceInterval = 30 * time.Minute announceErrorRetryInterval = 5 * time.Minute + requestTimeout = 5 * time.Second ) type announcement struct { @@ -73,6 +74,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, rela // certificate to prove our identity, and may or may not verify the server // certificate depending on the insecure setting. var announceClient httpClient = &http.Client{ + Timeout: requestTimeout, Transport: &http.Transport{ Dial: dialer.Dial, Proxy: http.ProxyFromEnvironment, @@ -89,6 +91,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, rela // The http.Client used for queries. We don't need to present our // certificate here, so lets not include it. May be insecure if requested. var queryClient httpClient = &http.Client{ + Timeout: requestTimeout, Transport: &http.Transport{ Dial: dialer.Dial, Proxy: http.ProxyFromEnvironment, diff --git a/lib/discover/global_test.go b/lib/discover/global_test.go index 000824241..360747236 100644 --- a/lib/discover/global_test.go +++ b/lib/discover/global_test.go @@ -79,11 +79,26 @@ func TestGlobalOverHTTP(t *testing.T) { mux.HandleFunc("/", s.handler) go http.Serve(list, mux) + // This should succeed direct, relays, err := testLookup("http://" + list.Addr().String() + "?insecure&noannounce") if err != nil { t.Fatalf("unexpected error: %v", err) } + if !testing.Short() { + // This should time out + _, _, err = testLookup("http://" + list.Addr().String() + "/block?insecure&noannounce") + if err == nil { + t.Fatalf("unexpected nil error, should have been a timeout") + } + } + + // This should work again + _, _, err = testLookup("http://" + list.Addr().String() + "?insecure&noannounce") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(direct) != 1 || direct[0] != "tcp://192.0.2.42::22000" { t.Errorf("incorrect direct list: %+v", direct) } @@ -231,6 +246,11 @@ type fakeDiscoveryServer struct { } func (s *fakeDiscoveryServer) handler(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/block" { + // Never return for requests here + select {} + } + if r.Method == "POST" { s.announce, _ = ioutil.ReadAll(r.Body) w.WriteHeader(204)