lib/api: Fix and optimize csrfManager (#8329)

An off-by-one error could cause tokens to be forgotten. Suppose

	tokens := []string{"foo", "bar", "baz", "quux"}
	i := 2
	token := tokens[i] // token == "baz"

Then, after

	copy(tokens[1:], tokens[:i+1])
	tokens[0] = token

we have

	tokens == []string{"baz", "foo", "bar", "baz"}

The short test actually relied on this bug.
This commit is contained in:
greatroar 2022-05-07 12:30:13 +02:00 committed by GitHub
parent 520ca4bcb0
commit 97291c9184
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 7 deletions

View File

@ -46,6 +46,7 @@ type apiKeyValidator interface {
func newCsrfManager(unique string, prefix string, apiKeyValidator apiKeyValidator, next http.Handler, saveLocation string) *csrfManager { func newCsrfManager(unique string, prefix string, apiKeyValidator apiKeyValidator, next http.Handler, saveLocation string) *csrfManager {
m := &csrfManager{ m := &csrfManager{
tokensMut: sync.NewMutex(), tokensMut: sync.NewMutex(),
tokens: make([]string, 0, maxCsrfTokens),
unique: unique, unique: unique,
prefix: prefix, prefix: prefix,
apiKeyValidator: apiKeyValidator, apiKeyValidator: apiKeyValidator,
@ -108,7 +109,7 @@ func (m *csrfManager) validToken(token string) bool {
// Move this token to the head of the list. Copy the tokens at // Move this token to the head of the list. Copy the tokens at
// the front one step to the right and then replace the token // the front one step to the right and then replace the token
// at the head. // at the head.
copy(m.tokens[1:], m.tokens[:i+1]) copy(m.tokens[1:], m.tokens[:i])
m.tokens[0] = token m.tokens[0] = token
} }
return true return true
@ -121,12 +122,14 @@ func (m *csrfManager) newToken() string {
token := rand.String(32) token := rand.String(32)
m.tokensMut.Lock() m.tokensMut.Lock()
m.tokens = append([]string{token}, m.tokens...)
if len(m.tokens) > maxCsrfTokens {
m.tokens = m.tokens[:maxCsrfTokens]
}
defer m.tokensMut.Unlock() defer m.tokensMut.Unlock()
if len(m.tokens) < maxCsrfTokens {
m.tokens = append(m.tokens, "")
}
copy(m.tokens[1:], m.tokens)
m.tokens[0] = token
m.save() m.save()
return token return token

View File

@ -18,6 +18,7 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@ -73,10 +74,10 @@ func TestMain(m *testing.M) {
func TestCSRFToken(t *testing.T) { func TestCSRFToken(t *testing.T) {
t.Parallel() t.Parallel()
max := 250 max := 10 * maxCsrfTokens
int := 5 int := 5
if testing.Short() { if testing.Short() {
max = 20 max = 1 + maxCsrfTokens
int = 2 int = 2
} }
@ -90,6 +91,11 @@ func TestCSRFToken(t *testing.T) {
t.Fatal("t3 should be valid") t.Fatal("t3 should be valid")
} }
valid := make(map[string]struct{}, maxCsrfTokens)
for _, token := range m.tokens {
valid[token] = struct{}{}
}
for i := 0; i < max; i++ { for i := 0; i < max; i++ {
if i%int == 0 { if i%int == 0 {
// t1 and t2 should remain valid by virtue of us checking them now // t1 and t2 should remain valid by virtue of us checking them now
@ -102,11 +108,27 @@ func TestCSRFToken(t *testing.T) {
} }
} }
if len(m.tokens) == maxCsrfTokens {
// We're about to add a token, which will remove the last token
// from m.tokens.
delete(valid, m.tokens[len(m.tokens)-1])
}
// The newly generated token is always valid // The newly generated token is always valid
t4 := m.newToken() t4 := m.newToken()
if !m.validToken(t4) { if !m.validToken(t4) {
t.Fatal("t4 should be valid at iteration", i) t.Fatal("t4 should be valid at iteration", i)
} }
valid[t4] = struct{}{}
v := make(map[string]struct{}, maxCsrfTokens)
for _, token := range m.tokens {
v[token] = struct{}{}
}
if !reflect.DeepEqual(v, valid) {
t.Fatalf("want valid tokens %v, got %v", valid, v)
}
} }
if m.validToken(t3) { if m.validToken(t3) {