diff --git a/internal/restic/snapshot_find.go b/internal/restic/snapshot_find.go index 9b0aa8d82..5eae464a7 100644 --- a/internal/restic/snapshot_find.go +++ b/internal/restic/snapshot_find.go @@ -123,6 +123,8 @@ func (f *SnapshotFilter) FindLatest(ctx context.Context, be Lister, loader Loade type SnapshotFindCb func(string, *Snapshot, error) error +var ErrInvalidSnapshotSyntax = errors.New("snapshot:path syntax not allowed") + // FindAll yields Snapshots, either given explicitly by `snapshotIDs` or filtered from the list of all snapshots. func (f *SnapshotFilter) FindAll(ctx context.Context, be Lister, loader LoaderUnpacked, snapshotIDs []string, fn SnapshotFindCb) error { if len(snapshotIDs) != 0 { @@ -148,11 +150,13 @@ func (f *SnapshotFilter) FindAll(ctx context.Context, be Lister, loader LoaderUn if sn != nil { ids.Insert(*sn.ID()) } + } else if strings.HasPrefix(s, "latest:") { + err = ErrInvalidSnapshotSyntax } else { var subpath string sn, subpath, err = FindSnapshot(ctx, be, loader, s) if err == nil && subpath != "" { - err = errors.New("snapshot:path syntax not allowed") + err = ErrInvalidSnapshotSyntax } else if err == nil { if ids.Has(*sn.ID()) { continue diff --git a/internal/restic/snapshot_find_test.go b/internal/restic/snapshot_find_test.go index d330a5b01..30d9eaff4 100644 --- a/internal/restic/snapshot_find_test.go +++ b/internal/restic/snapshot_find_test.go @@ -6,6 +6,7 @@ import ( "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/restic" + "github.com/restic/restic/internal/test" ) func TestFindLatestSnapshot(t *testing.T) { @@ -43,3 +44,48 @@ func TestFindLatestSnapshotWithMaxTimestamp(t *testing.T) { t.Errorf("FindLatest returned wrong snapshot ID: %v", *sn.ID()) } } + +func TestFindLatestWithSubpath(t *testing.T) { + repo := repository.TestRepository(t) + restic.TestCreateSnapshot(t, repo, parseTimeUTC("2015-05-05 05:05:05"), 1, 0) + desiredSnapshot := restic.TestCreateSnapshot(t, repo, parseTimeUTC("2017-07-07 07:07:07"), 1, 0) + + for _, exp := range []struct { + query string + subpath string + }{ + {"latest", ""}, + {"latest:subpath", "subpath"}, + {desiredSnapshot.ID().Str(), ""}, + {desiredSnapshot.ID().Str() + ":subpath", "subpath"}, + {desiredSnapshot.ID().String(), ""}, + {desiredSnapshot.ID().String() + ":subpath", "subpath"}, + } { + t.Run("", func(t *testing.T) { + sn, subpath, err := (&restic.SnapshotFilter{}).FindLatest(context.TODO(), repo.Backend(), repo, exp.query) + if err != nil { + t.Fatalf("FindLatest returned error: %v", err) + } + + test.Assert(t, *sn.ID() == *desiredSnapshot.ID(), "FindLatest returned wrong snapshot ID: %v", *sn.ID()) + test.Assert(t, subpath == exp.subpath, "FindLatest returned wrong path in snapshot: %v", subpath) + }) + } +} + +func TestFindAllSubpathError(t *testing.T) { + repo := repository.TestRepository(t) + desiredSnapshot := restic.TestCreateSnapshot(t, repo, parseTimeUTC("2017-07-07 07:07:07"), 1, 0) + + count := 0 + test.OK(t, (&restic.SnapshotFilter{}).FindAll(context.TODO(), repo.Backend(), repo, + []string{"latest:subpath", desiredSnapshot.ID().Str() + ":subpath"}, + func(id string, sn *restic.Snapshot, err error) error { + if err == restic.ErrInvalidSnapshotSyntax { + count++ + return nil + } + return err + })) + test.Assert(t, count == 2, "unexpected number of subpath errors: %v, wanted %v", count, 2) +}