From d03460010ff47f1beeaf376d40c5423ee238ab58 Mon Sep 17 00:00:00 2001 From: greatroar <61184462+greatroar@users.noreply.github.com> Date: Sat, 15 Oct 2022 14:12:45 +0200 Subject: [PATCH] internal/restic: Fix ID.UnmarshalJSON, ParseID ID.UnmarshalJSON accepted non-JSON input with ' as the string delimiter. Also, the error message for non-hex input was less informative than it could be and it performed too many checks. Changed ParseID to keep the error messages consistent. --- internal/restic/id.go | 40 ++++++++++++---------------------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/internal/restic/id.go b/internal/restic/id.go index 5a25e0ebe..e71c6d71b 100644 --- a/internal/restic/id.go +++ b/internal/restic/id.go @@ -6,8 +6,6 @@ import ( "fmt" "io" - "github.com/restic/restic/internal/errors" - "github.com/minio/sha256-simd" ) @@ -24,14 +22,13 @@ type ID [idSize]byte // ParseID converts the given string to an ID. func ParseID(s string) (ID, error) { - b, err := hex.DecodeString(s) - - if err != nil { - return ID{}, errors.Wrap(err, "hex.DecodeString") + if len(s) != hex.EncodedLen(idSize) { + return ID{}, fmt.Errorf("invalid length for ID: %q", s) } - if len(b) != idSize { - return ID{}, errors.New("invalid length for hash") + b, err := hex.DecodeString(s) + if err != nil { + return ID{}, fmt.Errorf("invalid ID: %s", err) } id := ID{} @@ -96,34 +93,21 @@ func (id ID) MarshalJSON() ([]byte, error) { // UnmarshalJSON parses the JSON-encoded data and stores the result in id. func (id *ID) UnmarshalJSON(b []byte) error { // check string length - if len(b) < 2 { - return fmt.Errorf("invalid ID: %q", b) + if len(b) != len(`""`)+hex.EncodedLen(idSize) { + return fmt.Errorf("invalid length for ID: %q", b) } - if len(b)%2 != 0 { - return fmt.Errorf("invalid ID length: %q", b) - } - - // check string delimiters - if b[0] != '"' && b[0] != '\'' { + if b[0] != '"' { return fmt.Errorf("invalid start of string: %q", b[0]) } - last := len(b) - 1 - if b[0] != b[last] { - return fmt.Errorf("starting string delimiter (%q) does not match end (%q)", b[0], b[last]) - } - - // strip JSON string delimiters - b = b[1:last] - - if len(b) != 2*len(id) { - return fmt.Errorf("invalid length for ID") - } + // Strip JSON string delimiters. The json.Unmarshaler contract says we get + // a valid JSON value, so we don't need to check that b[len(b)-1] == '"'. + b = b[1 : len(b)-1] _, err := hex.Decode(id[:], b) if err != nil { - return errors.Wrap(err, "hex.Decode") + return fmt.Errorf("invalid ID: %s", err) } return nil