From 203d775190b09a85a40bac19864e9f2ac58ab631 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Fri, 22 Mar 2019 20:30:29 +0100 Subject: [PATCH] restic: Make JSON unmarshal for ID more efficient This commit reduces several allocations in UnmarshalJSON() by decoding the hex string directly in a single step. --- internal/restic/id.go | 31 +++++++++++++++++++++++++----- internal/restic/id_test.go | 39 +++++++++++++++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 6 deletions(-) diff --git a/internal/restic/id.go b/internal/restic/id.go index ffe818a83..bc9749e77 100644 --- a/internal/restic/id.go +++ b/internal/restic/id.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "fmt" "io" "github.com/restic/restic/internal/errors" @@ -101,13 +102,33 @@ func (id ID) MarshalJSON() ([]byte, error) { // UnmarshalJSON parses the JSON-encoded data and stores the result in id. func (id *ID) UnmarshalJSON(b []byte) error { - var s string - err := json.Unmarshal(b, &s) - if err != nil { - return errors.Wrap(err, "Unmarshal") + // check string length + if len(b) < 2 { + return fmt.Errorf("invalid ID: %q", b) } - _, err = hex.Decode(id[:], []byte(s)) + if len(b)%2 != 0 { + return fmt.Errorf("invalid ID length: %q", b) + } + + // check string delimiters + if b[0] != '"' && b[0] != '\'' { + return fmt.Errorf("invalid start of string: %q", b[0]) + } + + last := len(b) - 1 + if b[0] != b[last] { + return fmt.Errorf("starting string delimiter (%q) does not match end (%q)", b[0], b[last]) + } + + // strip JSON string delimiters + b = b[1:last] + + if len(b) != 2*len(id) { + return fmt.Errorf("invalid length for ID") + } + + _, err := hex.Decode(id[:], b) if err != nil { return errors.Wrap(err, "hex.Decode") } diff --git a/internal/restic/id_test.go b/internal/restic/id_test.go index 2e9634a19..ff1dc54e0 100644 --- a/internal/restic/id_test.go +++ b/internal/restic/id_test.go @@ -51,10 +51,47 @@ func TestID(t *testing.T) { var id3 ID err = id3.UnmarshalJSON(buf) if err != nil { - t.Fatal(err) + t.Fatalf("error for %q: %v", buf, err) } if !reflect.DeepEqual(id, id3) { t.Error("ids are not equal") } } } + +func TestIDUnmarshal(t *testing.T) { + var tests = []struct { + s string + valid bool + }{ + {`"`, false}, + {`""`, false}, + {`'`, false}, + {`"`, false}, + {`"c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4"`, false}, + {`"c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f"`, false}, + {`"c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2"`, true}, + } + + wantID, err := ParseID("c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2") + if err != nil { + t.Fatal(err) + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + id := &ID{} + err := id.UnmarshalJSON([]byte(test.s)) + if test.valid && err != nil { + t.Fatal(err) + } + if !test.valid && err == nil { + t.Fatalf("want error for invalid value, got nil") + } + + if test.valid && !id.Equal(wantID) { + t.Fatalf("wrong ID returned, want %s, got %s", wantID, id) + } + }) + } +}