diff --git a/lib/api/api_csrf.go b/lib/api/api_csrf.go index 951e85dc0..97e0f3357 100644 --- a/lib/api/api_csrf.go +++ b/lib/api/api_csrf.go @@ -46,6 +46,7 @@ type apiKeyValidator interface { func newCsrfManager(unique string, prefix string, apiKeyValidator apiKeyValidator, next http.Handler, saveLocation string) *csrfManager { m := &csrfManager{ tokensMut: sync.NewMutex(), + tokens: make([]string, 0, maxCsrfTokens), unique: unique, prefix: prefix, apiKeyValidator: apiKeyValidator, @@ -108,7 +109,7 @@ func (m *csrfManager) validToken(token string) bool { // Move this token to the head of the list. Copy the tokens at // the front one step to the right and then replace the token // at the head. - copy(m.tokens[1:], m.tokens[:i+1]) + copy(m.tokens[1:], m.tokens[:i]) m.tokens[0] = token } return true @@ -121,12 +122,14 @@ func (m *csrfManager) newToken() string { token := rand.String(32) m.tokensMut.Lock() - m.tokens = append([]string{token}, m.tokens...) - if len(m.tokens) > maxCsrfTokens { - m.tokens = m.tokens[:maxCsrfTokens] - } defer m.tokensMut.Unlock() + if len(m.tokens) < maxCsrfTokens { + m.tokens = append(m.tokens, "") + } + copy(m.tokens[1:], m.tokens) + m.tokens[0] = token + m.save() return token diff --git a/lib/api/api_test.go b/lib/api/api_test.go index e149824b0..7dd3a1955 100644 --- a/lib/api/api_test.go +++ b/lib/api/api_test.go @@ -18,6 +18,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "reflect" "runtime" "strconv" "strings" @@ -73,10 +74,10 @@ func TestMain(m *testing.M) { func TestCSRFToken(t *testing.T) { t.Parallel() - max := 250 + max := 10 * maxCsrfTokens int := 5 if testing.Short() { - max = 20 + max = 1 + maxCsrfTokens int = 2 } @@ -90,6 +91,11 @@ func TestCSRFToken(t *testing.T) { t.Fatal("t3 should be valid") } + valid := make(map[string]struct{}, maxCsrfTokens) + for _, token := range m.tokens { + valid[token] = struct{}{} + } + for i := 0; i < max; i++ { if i%int == 0 { // t1 and t2 should remain valid by virtue of us checking them now @@ -102,11 +108,27 @@ func TestCSRFToken(t *testing.T) { } } + if len(m.tokens) == maxCsrfTokens { + // We're about to add a token, which will remove the last token + // from m.tokens. + delete(valid, m.tokens[len(m.tokens)-1]) + } + // The newly generated token is always valid t4 := m.newToken() if !m.validToken(t4) { t.Fatal("t4 should be valid at iteration", i) } + valid[t4] = struct{}{} + + v := make(map[string]struct{}, maxCsrfTokens) + for _, token := range m.tokens { + v[token] = struct{}{} + } + + if !reflect.DeepEqual(v, valid) { + t.Fatalf("want valid tokens %v, got %v", valid, v) + } } if m.validToken(t3) {