diff --git a/lib/api/api.go b/lib/api/api.go index af7675a38..bbce4489c 100644 --- a/lib/api/api.go +++ b/lib/api/api.go @@ -310,7 +310,7 @@ func (s *service) serve(stop chan struct{}) { // Wrap everything in CSRF protection. The /rest prefix should be // protected, other requests will grant cookies. - handler := csrfMiddleware(s.id.String()[:5], "/rest", guiCfg, mux) + var handler http.Handler = newCsrfManager(s.id.String()[:5], "/rest", guiCfg, mux, locations.Get(locations.CsrfTokens)) // Add our version and ID as a header to responses handler = withDetailsMiddleware(s.id, handler) diff --git a/lib/api/api_auth_test.go b/lib/api/api_auth_test.go index dc41f8804..4735378a5 100644 --- a/lib/api/api_auth_test.go +++ b/lib/api/api_auth_test.go @@ -19,6 +19,8 @@ func init() { } func TestStaticAuthOK(t *testing.T) { + t.Parallel() + ok := authStatic("user", "pass", "user", string(passwordHashBytes)) if !ok { t.Fatalf("should pass auth") @@ -26,6 +28,8 @@ func TestStaticAuthOK(t *testing.T) { } func TestSimpleAuthUsernameFail(t *testing.T) { + t.Parallel() + ok := authStatic("userWRONG", "pass", "user", string(passwordHashBytes)) if ok { t.Fatalf("should fail auth") @@ -33,6 +37,8 @@ func TestSimpleAuthUsernameFail(t *testing.T) { } func TestStaticAuthPasswordFail(t *testing.T) { + t.Parallel() + ok := authStatic("user", "passWRONG", "user", string(passwordHashBytes)) if ok { t.Fatalf("should fail auth") diff --git a/lib/api/api_csrf.go b/lib/api/api_csrf.go index b0f1d4032..951e85dc0 100644 --- a/lib/api/api_csrf.go +++ b/lib/api/api_csrf.go @@ -13,83 +13,103 @@ import ( "os" "strings" - "github.com/syncthing/syncthing/lib/config" - "github.com/syncthing/syncthing/lib/locations" "github.com/syncthing/syncthing/lib/osutil" "github.com/syncthing/syncthing/lib/rand" "github.com/syncthing/syncthing/lib/sync" ) -// csrfTokens is a list of valid tokens. It is sorted so that the most -// recently used token is first in the list. New tokens are added to the front -// of the list (as it is the most recently used at that time). The list is -// pruned to a maximum of maxCsrfTokens, throwing away the least recently used -// tokens. -var csrfTokens []string -var csrfMut = sync.NewMutex() - const maxCsrfTokens = 25 +type csrfManager struct { + // tokens is a list of valid tokens. It is sorted so that the most + // recently used token is first in the list. New tokens are added to the front + // of the list (as it is the most recently used at that time). The list is + // pruned to a maximum of maxCsrfTokens, throwing away the least recently used + // tokens. + tokens []string + tokensMut sync.Mutex + + unique string + prefix string + apiKeyValidator apiKeyValidator + next http.Handler + saveLocation string +} + +type apiKeyValidator interface { + IsValidAPIKey(key string) bool +} + // Check for CSRF token on /rest/ URLs. If a correct one is not given, reject // the request with 403. For / and /index.html, set a new CSRF cookie if none // is currently set. -func csrfMiddleware(unique string, prefix string, cfg config.GUIConfiguration, next http.Handler) http.Handler { - loadCsrfTokens() - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Allow requests carrying a valid API key - if cfg.IsValidAPIKey(r.Header.Get("X-API-Key")) { - // Set the access-control-allow-origin header for CORS requests - // since a valid API key has been provided - w.Header().Add("Access-Control-Allow-Origin", "*") - next.ServeHTTP(w, r) - return - } - - if strings.HasPrefix(r.URL.Path, "/rest/debug") { - // Debugging functions are only available when explicitly - // enabled, and can be accessed without a CSRF token - next.ServeHTTP(w, r) - return - } - - // Allow requests for anything not under the protected path prefix, - // and set a CSRF cookie if there isn't already a valid one. - if !strings.HasPrefix(r.URL.Path, prefix) { - cookie, err := r.Cookie("CSRF-Token-" + unique) - if err != nil || !validCsrfToken(cookie.Value) { - l.Debugln("new CSRF cookie in response to request for", r.URL) - cookie = &http.Cookie{ - Name: "CSRF-Token-" + unique, - Value: newCsrfToken(), - } - http.SetCookie(w, cookie) - } - next.ServeHTTP(w, r) - return - } - - // Verify the CSRF token - token := r.Header.Get("X-CSRF-Token-" + unique) - if !validCsrfToken(token) { - http.Error(w, "CSRF Error", 403) - return - } - - next.ServeHTTP(w, r) - }) +func newCsrfManager(unique string, prefix string, apiKeyValidator apiKeyValidator, next http.Handler, saveLocation string) *csrfManager { + m := &csrfManager{ + tokensMut: sync.NewMutex(), + unique: unique, + prefix: prefix, + apiKeyValidator: apiKeyValidator, + next: next, + saveLocation: saveLocation, + } + m.load() + return m } -func validCsrfToken(token string) bool { - csrfMut.Lock() - defer csrfMut.Unlock() - for i, t := range csrfTokens { +func (m *csrfManager) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Allow requests carrying a valid API key + if m.apiKeyValidator.IsValidAPIKey(r.Header.Get("X-API-Key")) { + // Set the access-control-allow-origin header for CORS requests + // since a valid API key has been provided + w.Header().Add("Access-Control-Allow-Origin", "*") + m.next.ServeHTTP(w, r) + return + } + + if strings.HasPrefix(r.URL.Path, "/rest/debug") { + // Debugging functions are only available when explicitly + // enabled, and can be accessed without a CSRF token + m.next.ServeHTTP(w, r) + return + } + + // Allow requests for anything not under the protected path prefix, + // and set a CSRF cookie if there isn't already a valid one. + if !strings.HasPrefix(r.URL.Path, m.prefix) { + cookie, err := r.Cookie("CSRF-Token-" + m.unique) + if err != nil || !m.validToken(cookie.Value) { + l.Debugln("new CSRF cookie in response to request for", r.URL) + cookie = &http.Cookie{ + Name: "CSRF-Token-" + m.unique, + Value: m.newToken(), + } + http.SetCookie(w, cookie) + } + m.next.ServeHTTP(w, r) + return + } + + // Verify the CSRF token + token := r.Header.Get("X-CSRF-Token-" + m.unique) + if !m.validToken(token) { + http.Error(w, "CSRF Error", http.StatusForbidden) + return + } + + m.next.ServeHTTP(w, r) +} + +func (m *csrfManager) validToken(token string) bool { + m.tokensMut.Lock() + defer m.tokensMut.Unlock() + for i, t := range m.tokens { if t == token { if i > 0 { // Move this token to the head of the list. Copy the tokens at // the front one step to the right and then replace the token // at the head. - copy(csrfTokens[1:], csrfTokens[:i+1]) - csrfTokens[0] = token + copy(m.tokens[1:], m.tokens[:i+1]) + m.tokens[0] = token } return true } @@ -97,40 +117,47 @@ func validCsrfToken(token string) bool { return false } -func newCsrfToken() string { +func (m *csrfManager) newToken() string { token := rand.String(32) - csrfMut.Lock() - csrfTokens = append([]string{token}, csrfTokens...) - if len(csrfTokens) > maxCsrfTokens { - csrfTokens = csrfTokens[:maxCsrfTokens] + m.tokensMut.Lock() + m.tokens = append([]string{token}, m.tokens...) + if len(m.tokens) > maxCsrfTokens { + m.tokens = m.tokens[:maxCsrfTokens] } - defer csrfMut.Unlock() + defer m.tokensMut.Unlock() - saveCsrfTokens() + m.save() return token } -func saveCsrfTokens() { +func (m *csrfManager) save() { // We're ignoring errors in here. It's not super critical and there's // nothing relevant we can do about them anyway... - name := locations.Get(locations.CsrfTokens) - f, err := osutil.CreateAtomic(name) + if m.saveLocation == "" { + return + } + + f, err := osutil.CreateAtomic(m.saveLocation) if err != nil { return } - for _, t := range csrfTokens { + for _, t := range m.tokens { fmt.Fprintln(f, t) } f.Close() } -func loadCsrfTokens() { - f, err := os.Open(locations.Get(locations.CsrfTokens)) +func (m *csrfManager) load() { + if m.saveLocation == "" { + return + } + + f, err := os.Open(m.saveLocation) if err != nil { return } @@ -138,6 +165,6 @@ func loadCsrfTokens() { s := bufio.NewScanner(f) for s.Scan() { - csrfTokens = append(csrfTokens, s.Text()) + m.tokens = append(m.tokens, s.Text()) } } diff --git a/lib/api/api_test.go b/lib/api/api_test.go index 7b5bab602..13d1b738b 100644 --- a/lib/api/api_test.go +++ b/lib/api/api_test.go @@ -53,7 +53,7 @@ func TestMain(m *testing.M) { } func TestCSRFToken(t *testing.T) { - defer os.Remove(token) + t.Parallel() max := 250 int := 5 @@ -62,11 +62,13 @@ func TestCSRFToken(t *testing.T) { int = 2 } - t1 := newCsrfToken() - t2 := newCsrfToken() + m := newCsrfManager("unique", "prefix", config.GUIConfiguration{}, nil, "") - t3 := newCsrfToken() - if !validCsrfToken(t3) { + t1 := m.newToken() + t2 := m.newToken() + + t3 := m.newToken() + if !m.validToken(t3) { t.Fatal("t3 should be valid") } @@ -74,27 +76,29 @@ func TestCSRFToken(t *testing.T) { if i%int == 0 { // t1 and t2 should remain valid by virtue of us checking them now // and then. - if !validCsrfToken(t1) { + if !m.validToken(t1) { t.Fatal("t1 should be valid at iteration", i) } - if !validCsrfToken(t2) { + if !m.validToken(t2) { t.Fatal("t2 should be valid at iteration", i) } } // The newly generated token is always valid - t4 := newCsrfToken() - if !validCsrfToken(t4) { + t4 := m.newToken() + if !m.validToken(t4) { t.Fatal("t4 should be valid at iteration", i) } } - if validCsrfToken(t3) { + if m.validToken(t3) { t.Fatal("t3 should have expired by now") } } func TestStopAfterBrokenConfig(t *testing.T) { + t.Parallel() + cfg := config.Configuration{ GUI: config.GUIConfiguration{ RawAddress: "127.0.0.1:0", @@ -135,6 +139,8 @@ func TestStopAfterBrokenConfig(t *testing.T) { } func TestAssetsDir(t *testing.T) { + t.Parallel() + // For any given request to $FILE, we should return the first found of // - assetsdir/$THEME/$FILE // - compiled in asset $THEME/$FILE @@ -209,6 +215,8 @@ func expectURLToContain(t *testing.T, url, exp string) { } func TestDirNames(t *testing.T) { + t.Parallel() + names := dirNames("testdata") expected := []string{"config", "default", "foo", "testfolder"} if diff, equal := messagediff.PrettyDiff(expected, names); !equal { @@ -225,6 +233,8 @@ type httpTestCase struct { } func TestAPIServiceRequests(t *testing.T) { + t.Parallel() + const testAPIKey = "foobarbaz" cfg := new(mockedConfig) cfg.gui.APIKey = testAPIKey @@ -435,6 +445,8 @@ func testHTTPRequest(t *testing.T, baseURL string, tc httpTestCase, apikey strin } func TestHTTPLogin(t *testing.T) { + t.Parallel() + cfg := new(mockedConfig) cfg.gui.User = "üser" cfg.gui.Password = "$2a$10$IdIZTxTg/dCNuNEGlmLynOjqg4B1FvDKuIV5e0BB3pnWVHNb8.GSq" // bcrypt of "räksmörgås" in UTF-8 @@ -542,6 +554,8 @@ func startHTTP(cfg *mockedConfig) (string, error) { } func TestCSRFRequired(t *testing.T) { + t.Parallel() + const testAPIKey = "foobarbaz" cfg := new(mockedConfig) cfg.gui.APIKey = testAPIKey @@ -615,6 +629,8 @@ func TestCSRFRequired(t *testing.T) { } func TestRandomString(t *testing.T) { + t.Parallel() + const testAPIKey = "foobarbaz" cfg := new(mockedConfig) cfg.gui.APIKey = testAPIKey @@ -664,6 +680,8 @@ func TestRandomString(t *testing.T) { } func TestConfigPostOK(t *testing.T) { + t.Parallel() + cfg := bytes.NewBuffer([]byte(`{ "version": 15, "folders": [ @@ -685,6 +703,8 @@ func TestConfigPostOK(t *testing.T) { } func TestConfigPostDupFolder(t *testing.T) { + t.Parallel() + cfg := bytes.NewBuffer([]byte(`{ "version": 15, "folders": [ @@ -720,6 +740,8 @@ func testConfigPost(data io.Reader) (*http.Response, error) { } func TestHostCheck(t *testing.T) { + t.Parallel() + // An API service bound to localhost should reject non-localhost host Headers cfg := new(mockedConfig) @@ -878,6 +900,8 @@ func TestHostCheck(t *testing.T) { } func TestAddressIsLocalhost(t *testing.T) { + t.Parallel() + testcases := []struct { address string result bool @@ -921,6 +945,8 @@ func TestAddressIsLocalhost(t *testing.T) { } func TestAccessControlAllowOriginHeader(t *testing.T) { + t.Parallel() + const testAPIKey = "foobarbaz" cfg := new(mockedConfig) cfg.gui.APIKey = testAPIKey @@ -949,6 +975,8 @@ func TestAccessControlAllowOriginHeader(t *testing.T) { } func TestOptionsRequest(t *testing.T) { + t.Parallel() + const testAPIKey = "foobarbaz" cfg := new(mockedConfig) cfg.gui.APIKey = testAPIKey @@ -982,6 +1010,8 @@ func TestOptionsRequest(t *testing.T) { } func TestEventMasks(t *testing.T) { + t.Parallel() + cfg := new(mockedConfig) defSub := new(mockedEventSub) diskSub := new(mockedEventSub) @@ -1014,6 +1044,8 @@ func TestEventMasks(t *testing.T) { } func TestBrowse(t *testing.T) { + t.Parallel() + pathSep := string(os.PathSeparator) tmpDir, err := ioutil.TempDir("", "syncthing") @@ -1066,6 +1098,8 @@ func TestBrowse(t *testing.T) { } func TestPrefixMatch(t *testing.T) { + t.Parallel() + cases := []struct { s string prefix string