diff --git a/internal/repository/repository_internal_test.go b/internal/repository/repository_internal_test.go index e5ab6e5b7..2a9976ace 100644 --- a/internal/repository/repository_internal_test.go +++ b/internal/repository/repository_internal_test.go @@ -3,8 +3,10 @@ package repository import ( "math/rand" "sort" + "strings" "testing" + "github.com/restic/restic/internal/crypto" "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" ) @@ -72,3 +74,101 @@ func BenchmarkSortCachedPacksFirst(b *testing.B) { sortCachedPacksFirst(cache, cpy[:]) } } + +func TestBlobVerification(t *testing.T) { + repo := TestRepository(t).(*Repository) + + type DamageType string + const ( + damageData DamageType = "data" + damageCompressed DamageType = "compressed" + damageCiphertext DamageType = "ciphertext" + ) + + for _, test := range []struct { + damage DamageType + msg string + }{ + {"", ""}, + {damageData, "hash mismatch"}, + {damageCompressed, "decompression failed"}, + {damageCiphertext, "ciphertext verification failed"}, + } { + plaintext := rtest.Random(800, 1234) + id := restic.Hash(plaintext) + if test.damage == damageData { + plaintext[42] ^= 0x42 + } + + uncompressedLength := uint(len(plaintext)) + plaintext = repo.getZstdEncoder().EncodeAll(plaintext, nil) + + if test.damage == damageCompressed { + plaintext = plaintext[:len(plaintext)-8] + } + + nonce := crypto.NewRandomNonce() + ciphertext := append([]byte{}, nonce...) + ciphertext = repo.Key().Seal(ciphertext, nonce, plaintext, nil) + + if test.damage == damageCiphertext { + ciphertext[42] ^= 0x42 + } + + err := repo.verifyCiphertext(ciphertext, int(uncompressedLength), id) + if test.msg == "" { + rtest.Assert(t, err == nil, "expected no error, got %v", err) + } else { + rtest.Assert(t, strings.Contains(err.Error(), test.msg), "expected error to contain %q, got %q", test.msg, err) + } + } +} + +func TestUnpackedVerification(t *testing.T) { + repo := TestRepository(t).(*Repository) + + type DamageType string + const ( + damageData DamageType = "data" + damageCompressed DamageType = "compressed" + damageCiphertext DamageType = "ciphertext" + ) + + for _, test := range []struct { + damage DamageType + msg string + }{ + {"", ""}, + {damageData, "data mismatch"}, + {damageCompressed, "decompression failed"}, + {damageCiphertext, "ciphertext verification failed"}, + } { + plaintext := rtest.Random(800, 1234) + orig := append([]byte{}, plaintext...) + if test.damage == damageData { + plaintext[42] ^= 0x42 + } + + compressed := []byte{2} + compressed = repo.getZstdEncoder().EncodeAll(plaintext, compressed) + + if test.damage == damageCompressed { + compressed = compressed[:len(compressed)-8] + } + + nonce := crypto.NewRandomNonce() + ciphertext := append([]byte{}, nonce...) + ciphertext = repo.Key().Seal(ciphertext, nonce, compressed, nil) + + if test.damage == damageCiphertext { + ciphertext[42] ^= 0x42 + } + + err := repo.verifyUnpacked(ciphertext, restic.IndexFile, orig) + if test.msg == "" { + rtest.Assert(t, err == nil, "expected no error, got %v", err) + } else { + rtest.Assert(t, strings.Contains(err.Error(), test.msg), "expected error to contain %q, got %q", test.msg, err) + } + } +}