diff --git a/crypto/crypto.go b/crypto/crypto.go index 77dadb7a5..bfb00d953 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -204,13 +204,23 @@ func (k *EncryptionKey) UnmarshalJSON(data []byte) error { return nil } +// ErrInvalidCiphertext is returned when trying to encrypt into the slice that +// holds the plaintext. +var ErrInvalidCiphertext = errors.New("invalid ciphertext, same slice used for plaintext") + // Encrypt encrypts and signs data. Stored in ciphertext is IV || Ciphertext || // MAC. Encrypt returns the new ciphertext slice, which is extended when // necessary. ciphertext and plaintext may not point to (exactly) the same // slice or non-intersecting slices. func Encrypt(ks *Key, ciphertext, plaintext []byte) ([]byte, error) { - // extend ciphertext slice if necessary ciphertext = ciphertext[:cap(ciphertext)] + + // test for same slice, if possible + if len(plaintext) > 0 && len(ciphertext) > 0 && &plaintext[0] == &ciphertext[0] { + return nil, ErrInvalidCiphertext + } + + // extend ciphertext slice if necessary if len(ciphertext) < len(plaintext)+Extension { ext := len(plaintext) + Extension - cap(ciphertext) ciphertext = append(ciphertext, make([]byte, ext)...) diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index bf1863abf..2a8375035 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -95,6 +95,28 @@ func TestSameBuffer(t *testing.T) { "wrong plaintext returned") } +func TestCornerCases(t *testing.T) { + k := crypto.NewKey() + + // nil plaintext should encrypt to the empty string + // nil ciphertext should allocate a new slice for the ciphertext + c, err := crypto.Encrypt(k, nil, nil) + OK(t, err) + + Assert(t, len(c) == crypto.Extension, + "wrong length returned for ciphertext, expected 0, got %d", + len(c)) + + // this should decrypt to an empty slice + p, err := crypto.Decrypt(k, nil, c) + OK(t, err) + Equals(t, []byte{}, p) + + // test encryption for same slice, this should return an error + _, err = crypto.Encrypt(k, c, c) + Equals(t, crypto.ErrInvalidCiphertext, err) +} + func TestLargeEncrypt(t *testing.T) { if !*testLargeCrypto { t.SkipNow()