2
2
mirror of https://github.com/octoleo/restic.git synced 2024-11-26 14:56:29 +00:00

Merge pull request #1571 from restic/rework-backend-list

Rework backend list
This commit is contained in:
Alexander Neumann 2018-01-24 19:43:07 +01:00
commit 44550a88a0
67 changed files with 3324 additions and 678 deletions

8
Gopkg.lock generated
View File

@ -187,6 +187,12 @@
packages = [".","google","internal","jws","jwt"]
revision = "f95fa95eaa936d9d87489b15d1d18b97c1ba9c28"
[[projects]]
branch = "master"
name = "golang.org/x/sync"
packages = ["errgroup"]
revision = "fd80eb99c8f653c847d294a001bdf2a3a6f768f5"
[[projects]]
branch = "master"
name = "golang.org/x/sys"
@ -214,6 +220,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "f0a207197cb502238ac87ca8e07b2640c02ec380a50b036e09ef87e40e31ca2d"
inputs-digest = "a7d099b3ce195ffc37adedb05a4386be38e6158925a1c0fe579efdc20fa11f6a"
solver-name = "gps-cdcl"
solver-version = 1

View File

@ -15,8 +15,6 @@ import (
"github.com/restic/restic/internal/pack"
"github.com/restic/restic/internal/repository"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/worker"
)
var cmdDebug = &cobra.Command{
@ -52,26 +50,18 @@ func prettyPrintJSON(wr io.Writer, item interface{}) error {
}
func debugPrintSnapshots(repo *repository.Repository, wr io.Writer) error {
for id := range repo.List(context.TODO(), restic.SnapshotFile) {
return repo.List(context.TODO(), restic.SnapshotFile, func(id restic.ID, size int64) error {
snapshot, err := restic.LoadSnapshot(context.TODO(), repo, id)
if err != nil {
fmt.Fprintf(os.Stderr, "LoadSnapshot(%v): %v", id.Str(), err)
continue
return err
}
fmt.Fprintf(wr, "snapshot_id: %v\n", id)
err = prettyPrintJSON(wr, snapshot)
if err != nil {
return err
}
}
return nil
return prettyPrintJSON(wr, snapshot)
})
}
const dumpPackWorkers = 10
// Pack is the struct used in printPacks.
type Pack struct {
Name string `json:"name"`
@ -88,49 +78,21 @@ type Blob struct {
}
func printPacks(repo *repository.Repository, wr io.Writer) error {
f := func(ctx context.Context, job worker.Job) (interface{}, error) {
name := job.Data.(string)
h := restic.Handle{Type: restic.DataFile, Name: name}
return repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error {
h := restic.Handle{Type: restic.DataFile, Name: id.String()}
blobInfo, err := repo.Backend().Stat(ctx, h)
blobs, err := pack.List(repo.Key(), restic.ReaderAt(repo.Backend(), h), size)
if err != nil {
return nil, err
fmt.Fprintf(os.Stderr, "error for pack %v: %v\n", id.Str(), err)
return nil
}
blobs, err := pack.List(repo.Key(), restic.ReaderAt(repo.Backend(), h), blobInfo.Size)
if err != nil {
return nil, err
}
return blobs, nil
}
jobCh := make(chan worker.Job)
resCh := make(chan worker.Job)
wp := worker.New(context.TODO(), dumpPackWorkers, f, jobCh, resCh)
go func() {
for name := range repo.Backend().List(context.TODO(), restic.DataFile) {
jobCh <- worker.Job{Data: name}
}
close(jobCh)
}()
for job := range resCh {
name := job.Data.(string)
if job.Error != nil {
fmt.Fprintf(os.Stderr, "error for pack %v: %v\n", name, job.Error)
continue
}
entries := job.Result.([]restic.Blob)
p := Pack{
Name: name,
Blobs: make([]Blob, len(entries)),
Name: id.String(),
Blobs: make([]Blob, len(blobs)),
}
for i, blob := range entries {
for i, blob := range blobs {
p.Blobs[i] = Blob{
Type: blob.Type,
Length: blob.Length,
@ -139,16 +101,14 @@ func printPacks(repo *repository.Repository, wr io.Writer) error {
}
}
prettyPrintJSON(os.Stdout, p)
}
wp.Wait()
return prettyPrintJSON(os.Stdout, p)
})
return nil
}
func dumpIndexes(repo restic.Repository) error {
for id := range repo.List(context.TODO(), restic.IndexFile) {
return repo.List(context.TODO(), restic.IndexFile, func(id restic.ID, size int64) error {
fmt.Printf("index_id: %v\n", id)
idx, err := repository.LoadIndex(context.TODO(), repo, id)
@ -156,13 +116,8 @@ func dumpIndexes(repo restic.Repository) error {
return err
}
err = idx.Dump(os.Stdout)
if err != nil {
return err
}
}
return nil
return idx.Dump(os.Stdout)
})
}
func runDebugDump(gopts GlobalOptions, args []string) error {

View File

@ -32,11 +32,11 @@ func listKeys(ctx context.Context, s *repository.Repository) error {
tab.Header = fmt.Sprintf(" %-10s %-10s %-10s %s", "ID", "User", "Host", "Created")
tab.RowFormat = "%s%-10s %-10s %-10s %s"
for id := range s.List(ctx, restic.KeyFile) {
err := s.List(ctx, restic.KeyFile, func(id restic.ID, size int64) error {
k, err := repository.LoadKey(ctx, s, id.String())
if err != nil {
Warnf("LoadKey() failed: %v\n", err)
continue
return nil
}
var current string
@ -47,6 +47,10 @@ func listKeys(ctx context.Context, s *repository.Repository) error {
}
tab.Rows = append(tab.Rows, []interface{}{current, id.Str(),
k.Username, k.Hostname, k.Created.Format(TimeFormat)})
return nil
})
if err != nil {
return err
}
return tab.Write(globalOptions.stdout)

View File

@ -73,9 +73,8 @@ func runList(opts GlobalOptions, args []string) error {
return errors.Fatal("invalid type")
}
for id := range repo.List(opts.ctx, t) {
return repo.List(opts.ctx, t, func(id restic.ID, size int64) error {
Printf("%s\n", id)
}
return nil
return nil
})
}

View File

@ -120,8 +120,12 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
}
Verbosef("counting files in repo\n")
for range repo.List(ctx, restic.DataFile) {
err = repo.List(ctx, restic.DataFile, func(restic.ID, int64) error {
stats.packs++
return nil
})
if err != nil {
return err
}
Verbosef("building new index for repo\n")

View File

@ -48,8 +48,12 @@ func rebuildIndex(ctx context.Context, repo restic.Repository, ignorePacks resti
Verbosef("counting files in repo\n")
var packs uint64
for range repo.List(ctx, restic.DataFile) {
err := repo.List(ctx, restic.DataFile, func(restic.ID, int64) error {
packs++
return nil
})
if err != nil {
return err
}
bar := newProgressMax(!globalOptions.Quiet, packs-uint64(len(ignorePacks)), "packs")
@ -61,8 +65,12 @@ func rebuildIndex(ctx context.Context, repo restic.Repository, ignorePacks resti
Verbosef("finding old index files\n")
var supersedes restic.IDs
for id := range repo.List(ctx, restic.IndexFile) {
err = repo.List(ctx, restic.IndexFile, func(id restic.ID, size int64) error {
supersedes = append(supersedes, id)
return nil
})
if err != nil {
return err
}
id, err := idx.Save(ctx, repo, supersedes)

View File

@ -58,7 +58,13 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos
return
}
for _, sn := range restic.FindFilteredSnapshots(ctx, repo, host, tags, paths) {
snapshots, err := restic.FindFilteredSnapshots(ctx, repo, host, tags, paths)
if err != nil {
Warnf("could not load snapshots: %v\n", err)
return
}
for _, sn := range snapshots {
select {
case <-ctx.Done():
return

View File

@ -658,10 +658,30 @@ REST Backend
************
Restic can interact with HTTP Backend that respects the following REST
API. The following values are valid for ``{type}``: ``data``, ``keys``,
``locks``, ``snapshots``, ``index``, ``config``. ``{path}`` is a path to
the repository, so that multiple different repositories can be accessed.
The default path is ``/``.
API.
The following values are valid for ``{type}``:
* ``data``
* ``keys``
* ``locks``
* ``snapshots``
* ``index``
* ``config``
The API version is selected via the ``Accept`` HTTP header in the request. The
following values are defined:
* ``application/vnd.x.restic.rest.v1+json`` or empty: Select API version 1
* ``application/vnd.x.restic.rest.v2+json``: Select API version 2
The server will respond with the value of the highest version it supports in
the ``Content-Type`` HTTP response header for the HTTP requests which should
return JSON. Any different value for this header means API version 1.
The placeholder ``{path}`` in this document is a path to the repository, so
that multiple different repositories can be accessed. The default path is
``/``.
POST {path}?create=true
=======================
@ -701,10 +721,48 @@ saved, an HTTP error otherwise.
GET {path}/{type}/
==================
Returns a JSON array containing the names of all the blobs stored for a
given type.
API version 1
-------------
Response format: JSON
Returns a JSON array containing the names of all the blobs stored for a given
type, example:
.. code:: json
[
"245bc4c430d393f74fbe7b13325e30dbde9fb0745e50caad57c446c93d20096b",
"85b420239efa1132c41cea0065452a40ebc20c6f8e0b132a5b2f5848360973ec",
"8e2006bb5931a520f3c7009fe278d1ebb87eb72c3ff92a50c30e90f1b8cf3e60",
"e75c8c407ea31ba399ab4109f28dd18c4c68303d8d86cc275432820c42ce3649"
]
API version 2
-------------
Returns a JSON array containing an object for each file of the given type. The
objects have two keys: ``name`` for the file name, and ``size`` for the size in
bytes.
.. code:: json
[
{
"name": "245bc4c430d393f74fbe7b13325e30dbde9fb0745e50caad57c446c93d20096b",
"size": 2341058
},
{
"name": "85b420239efa1132c41cea0065452a40ebc20c6f8e0b132a5b2f5848360973ec",
"size": 2908900
},
{
"name": "8e2006bb5931a520f3c7009fe278d1ebb87eb72c3ff92a50c30e90f1b8cf3e60",
"size": 3030712
},
{
"name": "e75c8c407ea31ba399ab4109f28dd18c4c68303d8d86cc275432820c42ce3649",
"size": 2804
}
]
HEAD {path}/{type}/{name}
=========================

View File

@ -135,8 +135,12 @@ func (e errReader) Read([]byte) (int, error) {
func countSnapshots(t testing.TB, repo restic.Repository) int {
snapshots := 0
for range repo.List(context.TODO(), restic.SnapshotFile) {
err := repo.List(context.TODO(), restic.SnapshotFile, func(id restic.ID, size int64) error {
snapshots++
return nil
})
if err != nil {
t.Fatal(err)
}
return snapshots
}

View File

@ -60,10 +60,8 @@ func forgetfulBackend() restic.Backend {
return nil
}
be.ListFn = func(ctx context.Context, t restic.FileType) <-chan string {
ch := make(chan string)
close(ch)
return ch
be.ListFn = func(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
return nil
}
be.DeleteFn = func(ctx context.Context) error {

View File

@ -131,9 +131,13 @@ func BenchmarkArchiveDirectory(b *testing.B) {
}
}
func countPacks(repo restic.Repository, t restic.FileType) (n uint) {
for range repo.Backend().List(context.TODO(), t) {
func countPacks(t testing.TB, repo restic.Repository, tpe restic.FileType) (n uint) {
err := repo.Backend().List(context.TODO(), tpe, func(restic.FileInfo) error {
n++
return nil
})
if err != nil {
t.Fatal(err)
}
return n
@ -158,7 +162,7 @@ func archiveWithDedup(t testing.TB) {
t.Logf("archived snapshot %v", sn.ID().Str())
// get archive stats
cnt.before.packs = countPacks(repo, restic.DataFile)
cnt.before.packs = countPacks(t, repo, restic.DataFile)
cnt.before.dataBlobs = repo.Index().Count(restic.DataBlob)
cnt.before.treeBlobs = repo.Index().Count(restic.TreeBlob)
t.Logf("packs %v, data blobs %v, tree blobs %v",
@ -169,7 +173,7 @@ func archiveWithDedup(t testing.TB) {
t.Logf("archived snapshot %v", sn2.ID().Str())
// get archive stats again
cnt.after.packs = countPacks(repo, restic.DataFile)
cnt.after.packs = countPacks(t, repo, restic.DataFile)
cnt.after.dataBlobs = repo.Index().Count(restic.DataBlob)
cnt.after.treeBlobs = repo.Index().Count(restic.TreeBlob)
t.Logf("packs %v, data blobs %v, tree blobs %v",
@ -186,7 +190,7 @@ func archiveWithDedup(t testing.TB) {
t.Logf("archived snapshot %v, parent %v", sn3.ID().Str(), sn2.ID().Str())
// get archive stats again
cnt.after2.packs = countPacks(repo, restic.DataFile)
cnt.after2.packs = countPacks(t, repo, restic.DataFile)
cnt.after2.dataBlobs = repo.Index().Count(restic.DataBlob)
cnt.after2.treeBlobs = repo.Index().Count(restic.TreeBlob)
t.Logf("packs %v, data blobs %v, tree blobs %v",

View File

@ -242,7 +242,11 @@ func (be *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo,
return restic.FileInfo{}, errors.Wrap(err, "blob.GetProperties")
}
return restic.FileInfo{Size: int64(blob.Properties.ContentLength)}, nil
fi := restic.FileInfo{
Size: int64(blob.Properties.ContentLength),
Name: h.Name,
}
return fi, nil
}
// Test returns true if a blob of the given type and name exists in the backend.
@ -271,17 +275,15 @@ func (be *Backend) Remove(ctx context.Context, h restic.Handle) error {
return errors.Wrap(err, "client.RemoveObject")
}
// List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending
// stops.
func (be *Backend) List(ctx context.Context, t restic.FileType) <-chan string {
// List runs fn for each file in the backend which has the type t. When an
// error occurs (or fn returns an error), List stops and returns it.
func (be *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
debug.Log("listing %v", t)
ch := make(chan string)
prefix, _ := be.Basedir(t)
// make sure prefix ends with a slash
if prefix[len(prefix)-1] != '/' {
if !strings.HasSuffix(prefix, "/") {
prefix += "/"
}
@ -290,53 +292,57 @@ func (be *Backend) List(ctx context.Context, t restic.FileType) <-chan string {
Prefix: prefix,
}
go func() {
defer close(ch)
for {
be.sem.GetToken()
obj, err := be.container.ListBlobs(params)
be.sem.ReleaseToken()
for {
be.sem.GetToken()
obj, err := be.container.ListBlobs(params)
be.sem.ReleaseToken()
if err != nil {
return
}
debug.Log("got %v objects", len(obj.Blobs))
for _, item := range obj.Blobs {
m := strings.TrimPrefix(item.Name, prefix)
if m == "" {
continue
}
select {
case ch <- path.Base(m):
case <-ctx.Done():
return
}
}
if obj.NextMarker == "" {
break
}
params.Marker = obj.NextMarker
if err != nil {
return err
}
}()
return ch
debug.Log("got %v objects", len(obj.Blobs))
for _, item := range obj.Blobs {
m := strings.TrimPrefix(item.Name, prefix)
if m == "" {
continue
}
fi := restic.FileInfo{
Name: path.Base(m),
Size: item.Properties.ContentLength,
}
if ctx.Err() != nil {
return ctx.Err()
}
err := fn(fi)
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
}
if obj.NextMarker == "" {
break
}
params.Marker = obj.NextMarker
}
return ctx.Err()
}
// Remove keys for a specified backend type.
func (be *Backend) removeKeys(ctx context.Context, t restic.FileType) error {
for key := range be.List(ctx, restic.DataFile) {
err := be.Remove(ctx, restic.Handle{Type: restic.DataFile, Name: key})
if err != nil {
return err
}
}
return nil
return be.List(ctx, t, func(fi restic.FileInfo) error {
return be.Remove(ctx, restic.Handle{Type: t, Name: fi.Name})
})
}
// Delete removes all restic keys in the bucket. It will not remove the bucket itself.

View File

@ -228,7 +228,7 @@ func (be *b2Backend) Stat(ctx context.Context, h restic.Handle) (bi restic.FileI
debug.Log("Attrs() err %v", err)
return restic.FileInfo{}, errors.Wrap(err, "Stat")
}
return restic.FileInfo{Size: info.Size}, nil
return restic.FileInfo{Size: info.Size, Name: h.Name}, nil
}
// Test returns true if a blob of the given type and name exists in the backend.
@ -262,66 +262,76 @@ func (be *b2Backend) Remove(ctx context.Context, h restic.Handle) error {
// List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending
// stops.
func (be *b2Backend) List(ctx context.Context, t restic.FileType) <-chan string {
func (be *b2Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
debug.Log("List %v", t)
ch := make(chan string)
prefix, _ := be.Basedir(t)
cur := &b2.Cursor{Prefix: prefix}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
defer close(ch)
defer cancel()
for {
be.sem.GetToken()
objs, c, err := be.bucket.ListCurrentObjects(ctx, be.listMaxItems, cur)
be.sem.ReleaseToken()
prefix, _ := be.Basedir(t)
cur := &b2.Cursor{Prefix: prefix}
for {
be.sem.GetToken()
objs, c, err := be.bucket.ListCurrentObjects(ctx, be.listMaxItems, cur)
be.sem.ReleaseToken()
if err != nil && err != io.EOF {
// TODO: return err to caller once err handling in List() is improved
debug.Log("List: %v", err)
return
}
debug.Log("returned %v items", len(objs))
for _, obj := range objs {
// Skip objects returned that do not have the specified prefix.
if !strings.HasPrefix(obj.Name(), prefix) {
continue
}
m := path.Base(obj.Name())
if m == "" {
continue
}
select {
case ch <- m:
case <-ctx.Done():
return
}
}
if err == io.EOF {
return
}
cur = c
if err != nil && err != io.EOF {
debug.Log("List: %v", err)
return err
}
}()
return ch
debug.Log("returned %v items", len(objs))
for _, obj := range objs {
// Skip objects returned that do not have the specified prefix.
if !strings.HasPrefix(obj.Name(), prefix) {
continue
}
m := path.Base(obj.Name())
if m == "" {
continue
}
if ctx.Err() != nil {
return ctx.Err()
}
attrs, err := obj.Attrs(ctx)
if err != nil {
return err
}
fi := restic.FileInfo{
Name: m,
Size: attrs.Size,
}
err = fn(fi)
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
}
if err == io.EOF {
return ctx.Err()
}
cur = c
}
return ctx.Err()
}
// Remove keys for a specified backend type.
func (be *b2Backend) removeKeys(ctx context.Context, t restic.FileType) error {
debug.Log("removeKeys %v", t)
for key := range be.List(ctx, t) {
err := be.Remove(ctx, restic.Handle{Type: t, Name: key})
if err != nil {
return err
}
}
return nil
return be.List(ctx, t, func(fi restic.FileInfo) error {
return be.Remove(ctx, restic.Handle{Type: t, Name: fi.Name})
})
}
// Delete removes all restic keys in the bucket. It will not remove the bucket itself.

View File

@ -333,7 +333,7 @@ func (be *Backend) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInf
return restic.FileInfo{}, errors.Wrap(err, "service.Objects.Get")
}
return restic.FileInfo{Size: int64(obj.Size)}, nil
return restic.FileInfo{Size: int64(obj.Size), Name: h.Name}, nil
}
// Test returns true if a blob of the given type and name exists in the backend.
@ -370,69 +370,72 @@ func (be *Backend) Remove(ctx context.Context, h restic.Handle) error {
return errors.Wrap(err, "client.RemoveObject")
}
// List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending
// stops.
func (be *Backend) List(ctx context.Context, t restic.FileType) <-chan string {
// List runs fn for each file in the backend which has the type t. When an
// error occurs (or fn returns an error), List stops and returns it.
func (be *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
debug.Log("listing %v", t)
ch := make(chan string)
prefix, _ := be.Basedir(t)
// make sure prefix ends with a slash
if prefix[len(prefix)-1] != '/' {
if !strings.HasSuffix(prefix, "/") {
prefix += "/"
}
go func() {
defer close(ch)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
listReq := be.service.Objects.List(be.bucketName).Prefix(prefix).MaxResults(int64(be.listMaxItems))
for {
be.sem.GetToken()
obj, err := listReq.Do()
be.sem.ReleaseToken()
listReq := be.service.Objects.List(be.bucketName).Context(ctx).Prefix(prefix).MaxResults(int64(be.listMaxItems))
for {
be.sem.GetToken()
obj, err := listReq.Do()
be.sem.ReleaseToken()
if err != nil {
fmt.Fprintf(os.Stderr, "error listing %v: %v\n", prefix, err)
return
}
debug.Log("returned %v items", len(obj.Items))
for _, item := range obj.Items {
m := strings.TrimPrefix(item.Name, prefix)
if m == "" {
continue
}
select {
case ch <- path.Base(m):
case <-ctx.Done():
return
}
}
if obj.NextPageToken == "" {
break
}
listReq.PageToken(obj.NextPageToken)
if err != nil {
return err
}
}()
return ch
debug.Log("returned %v items", len(obj.Items))
for _, item := range obj.Items {
m := strings.TrimPrefix(item.Name, prefix)
if m == "" {
continue
}
if ctx.Err() != nil {
return ctx.Err()
}
fi := restic.FileInfo{
Name: path.Base(m),
Size: int64(item.Size),
}
err := fn(fi)
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
}
if obj.NextPageToken == "" {
break
}
listReq.PageToken(obj.NextPageToken)
}
return ctx.Err()
}
// Remove keys for a specified backend type.
func (be *Backend) removeKeys(ctx context.Context, t restic.FileType) error {
for key := range be.List(ctx, restic.DataFile) {
err := be.Remove(ctx, restic.Handle{Type: restic.DataFile, Name: key})
if err != nil {
return err
}
}
return nil
return be.List(ctx, t, func(fi restic.FileInfo) error {
return be.Remove(ctx, restic.Handle{Type: t, Name: fi.Name})
})
}
// Delete removes all restic keys in the bucket. It will not remove the bucket itself.

View File

@ -49,8 +49,13 @@ func TestLayout(t *testing.T) {
}
datafiles := make(map[string]bool)
for id := range be.List(context.TODO(), restic.DataFile) {
datafiles[id] = false
err = be.List(context.TODO(), restic.DataFile, func(fi restic.FileInfo) error {
datafiles[fi.Name] = false
return nil
})
if err != nil {
t.Fatalf("List() returned error %v", err)
}
if len(datafiles) == 0 {

View File

@ -191,7 +191,7 @@ func (b *Local) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, err
return restic.FileInfo{}, errors.Wrap(err, "Stat")
}
return restic.FileInfo{Size: fi.Size()}, nil
return restic.FileInfo{Size: fi.Size(), Name: h.Name}, nil
}
// Test returns true if a blob of the given type and name exists in the backend.
@ -226,52 +226,48 @@ func isFile(fi os.FileInfo) bool {
return fi.Mode()&(os.ModeType|os.ModeCharDevice) == 0
}
// List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this.
func (b *Local) List(ctx context.Context, t restic.FileType) <-chan string {
// List runs fn for each file in the backend which has the type t. When an
// error occurs (or fn returns an error), List stops and returns it.
func (b *Local) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
debug.Log("List %v", t)
ch := make(chan string)
go func() {
defer close(ch)
basedir, subdirs := b.Basedir(t)
err := fs.Walk(basedir, func(path string, fi os.FileInfo, err error) error {
debug.Log("walk on %v\n", path)
if err != nil {
return err
}
if path == basedir {
return nil
}
if !isFile(fi) {
return nil
}
if fi.IsDir() && !subdirs {
return filepath.SkipDir
}
debug.Log("send %v\n", filepath.Base(path))
select {
case ch <- filepath.Base(path):
case <-ctx.Done():
return nil
}
return nil
})
basedir, subdirs := b.Basedir(t)
return fs.Walk(basedir, func(path string, fi os.FileInfo, err error) error {
debug.Log("walk on %v\n", path)
if err != nil {
debug.Log("Walk %v", err)
return err
}
}()
return ch
if path == basedir {
return nil
}
if !isFile(fi) {
return nil
}
if fi.IsDir() && !subdirs {
return filepath.SkipDir
}
debug.Log("send %v\n", filepath.Base(path))
rfi := restic.FileInfo{
Name: filepath.Base(path),
Size: fi.Size(),
}
if ctx.Err() != nil {
return ctx.Err()
}
err = fn(rfi)
if err != nil {
return err
}
return ctx.Err()
})
}
// Delete removes the repository and all files.

View File

@ -143,7 +143,7 @@ func (be *MemoryBackend) Stat(ctx context.Context, h restic.Handle) (restic.File
return restic.FileInfo{}, errNotFound
}
return restic.FileInfo{Size: int64(len(e))}, nil
return restic.FileInfo{Size: int64(len(e)), Name: h.Name}, nil
}
// Remove deletes a file from the backend.
@ -163,34 +163,40 @@ func (be *MemoryBackend) Remove(ctx context.Context, h restic.Handle) error {
}
// List returns a channel which yields entries from the backend.
func (be *MemoryBackend) List(ctx context.Context, t restic.FileType) <-chan string {
func (be *MemoryBackend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
entries := make(map[string]int64)
be.m.Lock()
defer be.m.Unlock()
ch := make(chan string)
var ids []string
for entry := range be.data {
for entry, buf := range be.data {
if entry.Type != t {
continue
}
ids = append(ids, entry.Name)
entries[entry.Name] = int64(len(buf))
}
be.m.Unlock()
for name, size := range entries {
fi := restic.FileInfo{
Name: name,
Size: size,
}
if ctx.Err() != nil {
return ctx.Err()
}
err := fn(fi)
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
}
debug.Log("list %v: %v", t, ids)
go func() {
defer close(ch)
for _, id := range ids {
select {
case ch <- id:
case <-ctx.Done():
return
}
}
}()
return ch
return ctx.Err()
}
// Location returns the location of the backend (RAM).

View File

@ -30,6 +30,11 @@ type restBackend struct {
backend.Layout
}
const (
contentTypeV1 = "application/vnd.x.restic.rest.v1"
contentTypeV2 = "application/vnd.x.restic.rest.v2"
)
// Open opens the REST backend with the given config.
func Open(cfg Config, rt http.RoundTripper) (*restBackend, error) {
client := &http.Client{Transport: rt}
@ -111,8 +116,15 @@ func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (
// make sure that client.Post() cannot close the reader by wrapping it
rd = ioutil.NopCloser(rd)
req, err := http.NewRequest(http.MethodPost, b.Filename(h), rd)
if err != nil {
return errors.Wrap(err, "NewRequest")
}
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("Accept", contentTypeV2)
b.sem.GetToken()
resp, err := ctxhttp.Post(ctx, b.client, b.Filename(h), "binary/octet-stream", rd)
resp, err := ctxhttp.Do(ctx, b.client, req)
b.sem.ReleaseToken()
if resp != nil {
@ -180,7 +192,8 @@ func (b *restBackend) Load(ctx context.Context, h restic.Handle, length int, off
if length > 0 {
byteRange = fmt.Sprintf("bytes=%d-%d", offset, offset+int64(length)-1)
}
req.Header.Add("Range", byteRange)
req.Header.Set("Range", byteRange)
req.Header.Set("Accept", contentTypeV2)
debug.Log("Load(%v) send range %v", h, byteRange)
b.sem.GetToken()
@ -214,8 +227,14 @@ func (b *restBackend) Stat(ctx context.Context, h restic.Handle) (restic.FileInf
return restic.FileInfo{}, err
}
req, err := http.NewRequest(http.MethodHead, b.Filename(h), nil)
if err != nil {
return restic.FileInfo{}, errors.Wrap(err, "NewRequest")
}
req.Header.Set("Accept", contentTypeV2)
b.sem.GetToken()
resp, err := ctxhttp.Head(ctx, b.client, b.Filename(h))
resp, err := ctxhttp.Do(ctx, b.client, req)
b.sem.ReleaseToken()
if err != nil {
return restic.FileInfo{}, errors.Wrap(err, "client.Head")
@ -241,6 +260,7 @@ func (b *restBackend) Stat(ctx context.Context, h restic.Handle) (restic.FileInf
bi := restic.FileInfo{
Size: resp.ContentLength,
Name: h.Name,
}
return bi, nil
@ -266,6 +286,8 @@ func (b *restBackend) Remove(ctx context.Context, h restic.Handle) error {
if err != nil {
return errors.Wrap(err, "http.NewRequest")
}
req.Header.Set("Accept", contentTypeV2)
b.sem.GetToken()
resp, err := ctxhttp.Do(ctx, b.client, req)
b.sem.ReleaseToken()
@ -291,56 +313,105 @@ func (b *restBackend) Remove(ctx context.Context, h restic.Handle) error {
return errors.Wrap(resp.Body.Close(), "Close")
}
// List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending
// stops.
func (b *restBackend) List(ctx context.Context, t restic.FileType) <-chan string {
ch := make(chan string)
// List runs fn for each file in the backend which has the type t. When an
// error occurs (or fn returns an error), List stops and returns it.
func (b *restBackend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
url := b.Dirname(restic.Handle{Type: t})
if !strings.HasSuffix(url, "/") {
url += "/"
}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return errors.Wrap(err, "NewRequest")
}
req.Header.Set("Accept", contentTypeV2)
b.sem.GetToken()
resp, err := ctxhttp.Get(ctx, b.client, url)
resp, err := ctxhttp.Do(ctx, b.client, req)
b.sem.ReleaseToken()
if resp != nil {
defer func() {
_, _ = io.Copy(ioutil.Discard, resp.Body)
e := resp.Body.Close()
if err == nil {
err = errors.Wrap(e, "Close")
}
}()
}
if err != nil {
close(ch)
return ch
return errors.Wrap(err, "Get")
}
if resp.Header.Get("Content-Type") == contentTypeV2 {
return b.listv2(ctx, t, resp, fn)
}
return b.listv1(ctx, t, resp, fn)
}
// listv1 uses the REST protocol v1, where a list HTTP request (e.g. `GET
// /data/`) only returns the names of the files, so we need to issue an HTTP
// HEAD request for each file.
func (b *restBackend) listv1(ctx context.Context, t restic.FileType, resp *http.Response, fn func(restic.FileInfo) error) error {
debug.Log("parsing API v1 response")
dec := json.NewDecoder(resp.Body)
var list []string
if err = dec.Decode(&list); err != nil {
close(ch)
return ch
if err := dec.Decode(&list); err != nil {
return errors.Wrap(err, "Decode")
}
go func() {
defer close(ch)
for _, m := range list {
select {
case ch <- m:
case <-ctx.Done():
return
}
for _, m := range list {
fi, err := b.Stat(ctx, restic.Handle{Name: m, Type: t})
if err != nil {
return err
}
}()
return ch
if ctx.Err() != nil {
return ctx.Err()
}
fi.Name = m
err = fn(fi)
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
}
return ctx.Err()
}
// listv2 uses the REST protocol v2, where a list HTTP request (e.g. `GET
// /data/`) returns the names and sizes of all files.
func (b *restBackend) listv2(ctx context.Context, t restic.FileType, resp *http.Response, fn func(restic.FileInfo) error) error {
debug.Log("parsing API v2 response")
dec := json.NewDecoder(resp.Body)
var list []struct {
Name string `json:"name"`
Size int64 `json:"size"`
}
if err := dec.Decode(&list); err != nil {
return errors.Wrap(err, "Decode")
}
for _, item := range list {
if ctx.Err() != nil {
return ctx.Err()
}
fi := restic.FileInfo{
Name: item.Name,
Size: item.Size,
}
err := fn(fi)
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
}
return ctx.Err()
}
// Close closes all open files.
@ -352,14 +423,9 @@ func (b *restBackend) Close() error {
// Remove keys for a specified backend type.
func (b *restBackend) removeKeys(ctx context.Context, t restic.FileType) error {
for key := range b.List(ctx, restic.DataFile) {
err := b.Remove(ctx, restic.Handle{Type: restic.DataFile, Name: key})
if err != nil {
return err
}
}
return nil
return b.List(ctx, t, func(fi restic.FileInfo) error {
return b.Remove(ctx, restic.Handle{Type: t, Name: fi.Name})
})
}
// Delete removes all data in the backend.

View File

@ -0,0 +1,150 @@
package rest_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strconv"
"testing"
"github.com/restic/restic/internal/backend/rest"
"github.com/restic/restic/internal/restic"
)
func TestListAPI(t *testing.T) {
var tests = []struct {
Name string
ContentType string // response header
Data string // response data
Requests int
Result []restic.FileInfo
}{
{
Name: "content-type-unknown",
ContentType: "application/octet-stream",
Data: `[
"1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985",
"3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352",
"8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b"
]`,
Result: []restic.FileInfo{
{Name: "1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985", Size: 4386},
{Name: "3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352", Size: 15214},
{Name: "8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b", Size: 33393},
},
Requests: 4,
},
{
Name: "content-type-v1",
ContentType: "application/vnd.x.restic.rest.v1",
Data: `[
"1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985",
"3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352",
"8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b"
]`,
Result: []restic.FileInfo{
{Name: "1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985", Size: 4386},
{Name: "3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352", Size: 15214},
{Name: "8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b", Size: 33393},
},
Requests: 4,
},
{
Name: "content-type-v2",
ContentType: "application/vnd.x.restic.rest.v2",
Data: `[
{"name": "1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985", "size": 1001},
{"name": "3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352", "size": 1002},
{"name": "8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b", "size": 1003}
]`,
Result: []restic.FileInfo{
{Name: "1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985", Size: 1001},
{Name: "3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352", Size: 1002},
{Name: "8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b", Size: 1003},
},
Requests: 1,
},
}
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
numRequests := 0
srv := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
numRequests++
t.Logf("req %v %v, accept: %v", req.Method, req.URL.Path, req.Header["Accept"])
var err error
switch {
case req.Method == "GET":
// list files in data/
res.Header().Set("Content-Type", test.ContentType)
_, err = res.Write([]byte(test.Data))
if err != nil {
t.Fatal(err)
}
return
case req.Method == "HEAD":
// stat file in data/, use the first two bytes in the name
// of the file as the size :)
filename := req.URL.Path[6:]
len, err := strconv.ParseInt(filename[:4], 16, 64)
if err != nil {
t.Fatal(err)
}
res.Header().Set("Content-Length", fmt.Sprintf("%d", len))
res.WriteHeader(http.StatusOK)
return
}
t.Errorf("unhandled request %v %v", req.Method, req.URL.Path)
}))
defer srv.Close()
srvURL, err := url.Parse(srv.URL)
if err != nil {
t.Fatal(err)
}
cfg := rest.Config{
Connections: 5,
URL: srvURL,
}
be, err := rest.Open(cfg, http.DefaultTransport)
if err != nil {
t.Fatal(err)
}
var list []restic.FileInfo
err = be.List(context.TODO(), restic.DataFile, func(fi restic.FileInfo) error {
list = append(list, fi)
return nil
})
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(list, test.Result) {
t.Fatalf("wrong response returned, want:\n %v\ngot: %v", test.Result, list)
}
if numRequests != test.Requests {
t.Fatalf("wrong number of HTTP requests executed, want %d, got %d", test.Requests, numRequests)
}
defer func() {
err = be.Close()
if err != nil {
t.Fatal(err)
}
}()
})
}
}

View File

@ -365,7 +365,7 @@ func (be *Backend) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInf
return restic.FileInfo{}, errors.Wrap(err, "Stat")
}
return restic.FileInfo{Size: fi.Size}, nil
return restic.FileInfo{Size: fi.Size, Name: h.Name}, nil
}
// Test returns true if a blob of the given type and name exists in the backend.
@ -402,54 +402,59 @@ func (be *Backend) Remove(ctx context.Context, h restic.Handle) error {
return errors.Wrap(err, "client.RemoveObject")
}
// List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending
// stops.
func (be *Backend) List(ctx context.Context, t restic.FileType) <-chan string {
// List runs fn for each file in the backend which has the type t. When an
// error occurs (or fn returns an error), List stops and returns it.
func (be *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
debug.Log("listing %v", t)
ch := make(chan string)
prefix, recursive := be.Basedir(t)
// make sure prefix ends with a slash
if prefix[len(prefix)-1] != '/' {
if !strings.HasSuffix(prefix, "/") {
prefix += "/"
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// NB: unfortunately we can't protect this with be.sem.GetToken() here.
// Doing so would enable a deadlock situation (gh-1399), as ListObjects()
// starts its own goroutine and returns results via a channel.
listresp := be.client.ListObjects(be.cfg.Bucket, prefix, recursive, ctx.Done())
go func() {
defer close(ch)
for obj := range listresp {
m := strings.TrimPrefix(obj.Key, prefix)
if m == "" {
continue
}
select {
case ch <- path.Base(m):
case <-ctx.Done():
return
}
for obj := range listresp {
m := strings.TrimPrefix(obj.Key, prefix)
if m == "" {
continue
}
}()
return ch
fi := restic.FileInfo{
Name: path.Base(m),
Size: obj.Size,
}
if ctx.Err() != nil {
return ctx.Err()
}
err := fn(fi)
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
}
return ctx.Err()
}
// Remove keys for a specified backend type.
func (be *Backend) removeKeys(ctx context.Context, t restic.FileType) error {
for key := range be.List(ctx, restic.DataFile) {
err := be.Remove(ctx, restic.Handle{Type: restic.DataFile, Name: key})
if err != nil {
return err
}
}
return nil
return be.List(ctx, restic.DataFile, func(fi restic.FileInfo) error {
return be.Remove(ctx, restic.Handle{Type: t, Name: fi.Name})
})
}
// Delete removes all restic keys in the bucket. It will not remove the bucket itself.

View File

@ -56,9 +56,10 @@ func TestLayout(t *testing.T) {
}
datafiles := make(map[string]bool)
for id := range be.List(context.TODO(), restic.DataFile) {
datafiles[id] = false
}
err = be.List(context.TODO(), restic.DataFile, func(fi restic.FileInfo) error {
datafiles[fi.Name] = false
return nil
})
if len(datafiles) == 0 {
t.Errorf("List() returned zero data files")

View File

@ -376,7 +376,7 @@ func (r *SFTP) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, erro
return restic.FileInfo{}, errors.Wrap(err, "Lstat")
}
return restic.FileInfo{Size: fi.Size()}, nil
return restic.FileInfo{Size: fi.Size(), Name: h.Name}, nil
}
// Test returns true if a blob of the given type and name exists in the backend.
@ -408,47 +408,54 @@ func (r *SFTP) Remove(ctx context.Context, h restic.Handle) error {
return r.c.Remove(r.Filename(h))
}
// List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending
// stops.
func (r *SFTP) List(ctx context.Context, t restic.FileType) <-chan string {
// List runs fn for each file in the backend which has the type t. When an
// error occurs (or fn returns an error), List stops and returns it.
func (r *SFTP) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
debug.Log("List %v", t)
ch := make(chan string)
go func() {
defer close(ch)
basedir, subdirs := r.Basedir(t)
walker := r.c.Walk(basedir)
for walker.Step() {
if walker.Err() != nil {
continue
}
if walker.Path() == basedir {
continue
}
if walker.Stat().IsDir() && !subdirs {
walker.SkipDir()
continue
}
if !walker.Stat().Mode().IsRegular() {
continue
}
select {
case ch <- path.Base(walker.Path()):
case <-ctx.Done():
return
}
basedir, subdirs := r.Basedir(t)
walker := r.c.Walk(basedir)
for walker.Step() {
if walker.Err() != nil {
return walker.Err()
}
}()
return ch
if walker.Path() == basedir {
continue
}
if walker.Stat().IsDir() && !subdirs {
walker.SkipDir()
continue
}
fi := walker.Stat()
if !fi.Mode().IsRegular() {
continue
}
debug.Log("send %v\n", path.Base(walker.Path()))
rfi := restic.FileInfo{
Name: path.Base(walker.Path()),
Size: fi.Size(),
}
if ctx.Err() != nil {
return ctx.Err()
}
err := fn(rfi)
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
}
return ctx.Err()
}
var closeTimeout = 2 * time.Second

View File

@ -6,7 +6,6 @@ import (
"io"
"net/http"
"path"
"path/filepath"
"strings"
"time"
@ -203,7 +202,7 @@ func (be *beSwift) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInf
return restic.FileInfo{}, errors.Wrap(err, "conn.Object")
}
return restic.FileInfo{Size: obj.Bytes}, nil
return restic.FileInfo{Size: obj.Bytes, Name: h.Name}, nil
}
// Test returns true if a blob of the given type and name exists in the backend.
@ -237,61 +236,62 @@ func (be *beSwift) Remove(ctx context.Context, h restic.Handle) error {
return errors.Wrap(err, "conn.ObjectDelete")
}
// List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending
// stops.
func (be *beSwift) List(ctx context.Context, t restic.FileType) <-chan string {
// List runs fn for each file in the backend which has the type t. When an
// error occurs (or fn returns an error), List stops and returns it.
func (be *beSwift) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
debug.Log("listing %v", t)
ch := make(chan string)
prefix, _ := be.Basedir(t)
prefix += "/"
go func() {
defer close(ch)
err := be.conn.ObjectsWalk(be.container, &swift.ObjectsOpts{Prefix: prefix},
func(opts *swift.ObjectsOpts) (interface{}, error) {
be.sem.GetToken()
newObjects, err := be.conn.Objects(be.container, opts)
be.sem.ReleaseToken()
err := be.conn.ObjectsWalk(be.container, &swift.ObjectsOpts{Prefix: prefix},
func(opts *swift.ObjectsOpts) (interface{}, error) {
be.sem.GetToken()
newObjects, err := be.conn.ObjectNames(be.container, opts)
be.sem.ReleaseToken()
if err != nil {
return nil, errors.Wrap(err, "conn.ObjectNames")
}
for _, obj := range newObjects {
m := path.Base(strings.TrimPrefix(obj.Name, prefix))
if m == "" {
continue
}
fi := restic.FileInfo{
Name: m,
Size: obj.Bytes,
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
err := fn(fi)
if err != nil {
return nil, errors.Wrap(err, "conn.ObjectNames")
return nil, err
}
for _, obj := range newObjects {
m := filepath.Base(strings.TrimPrefix(obj, prefix))
if m == "" {
continue
}
select {
case ch <- m:
case <-ctx.Done():
return nil, io.EOF
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
return newObjects, nil
})
}
return newObjects, nil
})
if err != nil {
debug.Log("ObjectsWalk returned error: %v", err)
}
}()
if err != nil {
return err
}
return ch
return ctx.Err()
}
// Remove keys for a specified backend type.
func (be *beSwift) removeKeys(ctx context.Context, t restic.FileType) error {
for key := range be.List(ctx, t) {
err := be.Remove(ctx, restic.Handle{Type: t, Name: key})
if err != nil {
return err
}
}
return nil
return be.List(ctx, t, func(fi restic.FileInfo) error {
return be.Remove(ctx, restic.Handle{Type: t, Name: fi.Name})
})
}
// IsNotExist returns true if the error is caused by a not existing file.

View File

@ -249,17 +249,17 @@ func (s *Suite) TestList(t *testing.T) {
b := s.open(t)
defer s.close(t, b)
list1 := restic.NewIDSet()
list1 := make(map[restic.ID]int64)
for i := 0; i < numTestFiles; i++ {
data := []byte(fmt.Sprintf("random test blob %v", i))
data := test.Random(rand.Int(), rand.Intn(100)+55)
id := restic.Hash(data)
h := restic.Handle{Type: restic.DataFile, Name: id.String()}
err := b.Save(context.TODO(), h, bytes.NewReader(data))
if err != nil {
t.Fatal(err)
}
list1.Insert(id)
list1[id] = int64(len(data))
}
t.Logf("wrote %v files", len(list1))
@ -272,7 +272,7 @@ func (s *Suite) TestList(t *testing.T) {
for _, test := range tests {
t.Run(fmt.Sprintf("max-%v", test.maxItems), func(t *testing.T) {
list2 := restic.NewIDSet()
list2 := make(map[restic.ID]int64)
type setter interface {
SetListMaxItems(int)
@ -283,19 +283,37 @@ func (s *Suite) TestList(t *testing.T) {
s.SetListMaxItems(test.maxItems)
}
for name := range b.List(context.TODO(), restic.DataFile) {
id, err := restic.ParseID(name)
err := b.List(context.TODO(), restic.DataFile, func(fi restic.FileInfo) error {
id, err := restic.ParseID(fi.Name)
if err != nil {
t.Fatal(err)
}
list2.Insert(id)
list2[id] = fi.Size
return nil
})
if err != nil {
t.Fatalf("List returned error %v", err)
}
t.Logf("loaded %v IDs from backend", len(list2))
if !list1.Equals(list2) {
t.Errorf("lists are not equal, list1 %d entries, list2 %d entries",
len(list1), len(list2))
for id, size := range list1 {
size2, ok := list2[id]
if !ok {
t.Errorf("id %v not returned by List()", id.Str())
}
if size != size2 {
t.Errorf("wrong size for id %v returned: want %v, got %v", id.Str(), size, size2)
}
}
for id := range list2 {
_, ok := list1[id]
if !ok {
t.Errorf("extra id %v returned by List()", id.Str())
}
}
})
}
@ -312,6 +330,123 @@ func (s *Suite) TestList(t *testing.T) {
}
}
// TestListCancel tests that the context is respected and the error is returned by List.
func (s *Suite) TestListCancel(t *testing.T) {
seedRand(t)
numTestFiles := 5
b := s.open(t)
defer s.close(t, b)
testFiles := make([]restic.Handle, 0, numTestFiles)
for i := 0; i < numTestFiles; i++ {
data := []byte(fmt.Sprintf("random test blob %v", i))
id := restic.Hash(data)
h := restic.Handle{Type: restic.DataFile, Name: id.String()}
err := b.Save(context.TODO(), h, bytes.NewReader(data))
if err != nil {
t.Fatal(err)
}
testFiles = append(testFiles, h)
}
t.Run("Cancelled", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
cancel()
// pass in a cancelled context
err := b.List(ctx, restic.DataFile, func(fi restic.FileInfo) error {
t.Errorf("got FileInfo %v for cancelled context", fi)
return nil
})
if errors.Cause(err) != context.Canceled {
t.Fatalf("expected error not found, want %v, got %v", context.Canceled, errors.Cause(err))
}
})
t.Run("First", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
i := 0
err := b.List(ctx, restic.DataFile, func(fi restic.FileInfo) error {
i++
// cancel the context on the first file
if i == 1 {
cancel()
}
return nil
})
if err != context.Canceled {
t.Fatalf("expected error not found, want %v, got %v", context.Canceled, err)
}
if i != 1 {
t.Fatalf("wrong number of files returned by List, want %v, got %v", 1, i)
}
})
t.Run("Last", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
i := 0
err := b.List(ctx, restic.DataFile, func(fi restic.FileInfo) error {
// cancel the context at the last file
i++
if i == numTestFiles {
cancel()
}
return nil
})
if err != context.Canceled {
t.Fatalf("expected error not found, want %v, got %v", context.Canceled, err)
}
if i != numTestFiles {
t.Fatalf("wrong number of files returned by List, want %v, got %v", numTestFiles, i)
}
})
t.Run("Timeout", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
// rather large timeout, let's try to get at least one item
timeout := time.Second
ctxTimeout, _ := context.WithTimeout(ctx, timeout)
i := 0
// pass in a context with a timeout
err := b.List(ctxTimeout, restic.DataFile, func(fi restic.FileInfo) error {
i++
// wait until the context is cancelled
<-ctxTimeout.Done()
return nil
})
if err != context.DeadlineExceeded {
t.Fatalf("expected error not found, want %#v, got %#v", context.DeadlineExceeded, err)
}
if i > 1 {
t.Fatalf("wrong number of files returned by List, want <= 1, got %v", i)
}
})
err := s.delayedRemove(t, b, testFiles...)
if err != nil {
t.Fatal(err)
}
}
type errorCloser struct {
io.Reader
l int
@ -366,8 +501,12 @@ func (s *Suite) TestSave(t *testing.T) {
fi, err := b.Stat(context.TODO(), h)
test.OK(t, err)
if fi.Name != h.Name {
t.Errorf("Stat() returned wrong name, want %q, got %q", h.Name, fi.Name)
}
if fi.Size != int64(len(data)) {
t.Fatalf("Stat() returned different size, want %q, got %d", len(data), fi.Size)
t.Errorf("Stat() returned different size, want %q, got %d", len(data), fi.Size)
}
err = b.Remove(context.TODO(), h)
@ -556,10 +695,16 @@ func delayedList(t testing.TB, b restic.Backend, tpe restic.FileType, max int, m
list := restic.NewIDSet()
start := time.Now()
for i := 0; i < max; i++ {
for s := range b.List(context.TODO(), tpe) {
id := restic.TestParseID(s)
err := b.List(context.TODO(), tpe, func(fi restic.FileInfo) error {
id := restic.TestParseID(fi.Name)
list.Insert(id)
return nil
})
if err != nil {
t.Fatal(err)
}
if len(list) < max && time.Since(start) < maxwait {
time.Sleep(500 * time.Millisecond)
}

View File

@ -12,6 +12,7 @@ import (
"github.com/restic/restic/internal/fs"
"github.com/restic/restic/internal/hashing"
"github.com/restic/restic/internal/restic"
"golang.org/x/sync/errgroup"
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/pack"
@ -192,13 +193,14 @@ func (c *Checker) Packs(ctx context.Context, errChan chan<- error) {
debug.Log("listing repository packs")
repoPacks := restic.NewIDSet()
for id := range c.repo.List(ctx, restic.DataFile) {
select {
case <-ctx.Done():
return
default:
}
err := c.repo.List(ctx, restic.DataFile, func(id restic.ID, size int64) error {
repoPacks.Insert(id)
return nil
})
if err != nil {
errChan <- err
}
// orphaned: present in the repo but not in c.packs
@ -719,42 +721,58 @@ func (c *Checker) ReadData(ctx context.Context, p *restic.Progress, errChan chan
p.Start()
defer p.Done()
worker := func(wg *sync.WaitGroup, in <-chan restic.ID) {
defer wg.Done()
for {
var id restic.ID
var ok bool
g, ctx := errgroup.WithContext(ctx)
ch := make(chan restic.ID)
// start producer for channel ch
g.Go(func() error {
defer close(ch)
return c.repo.List(ctx, restic.DataFile, func(id restic.ID, size int64) error {
select {
case <-ctx.Done():
return
case id, ok = <-in:
if !ok {
return
case ch <- id:
}
return nil
})
})
// run workers
for i := 0; i < defaultParallelism; i++ {
g.Go(func() error {
for {
var id restic.ID
var ok bool
select {
case <-ctx.Done():
return nil
case id, ok = <-ch:
if !ok {
return nil
}
}
err := checkPack(ctx, c.repo, id)
p.Report(restic.Stat{Blobs: 1})
if err == nil {
continue
}
select {
case <-ctx.Done():
return nil
case errChan <- err:
}
}
})
}
err := checkPack(ctx, c.repo, id)
p.Report(restic.Stat{Blobs: 1})
if err == nil {
continue
}
select {
case <-ctx.Done():
return
case errChan <- err:
}
err := g.Wait()
if err != nil {
select {
case <-ctx.Done():
return
case errChan <- err:
}
}
ch := c.repo.List(ctx, restic.DataFile)
var wg sync.WaitGroup
for i := 0; i < defaultParallelism; i++ {
wg.Add(1)
go worker(&wg, ch)
}
wg.Wait()
}

View File

@ -1,11 +1,10 @@
package errors
import "github.com/pkg/errors"
import (
"net/url"
// Cause returns the cause of an error.
func Cause(err error) error {
return errors.Cause(err)
}
"github.com/pkg/errors"
)
// New creates a new error based on message. Wrapped so that this package does
// not appear in the stack trace.
@ -22,3 +21,29 @@ var Wrap = errors.Wrap
// Wrapf returns an error annotating err with the format specifier. If err is
// nil, Wrapf returns nil.
var Wrapf = errors.Wrapf
// Cause returns the cause of an error. It will also unwrap certain errors,
// e.g. *url.Error returned by the net/http client.
func Cause(err error) error {
type Causer interface {
Cause() error
}
for {
// unwrap *url.Error
if urlErr, ok := err.(*url.Error); ok {
err = urlErr.Err
continue
}
// if err is a Causer, return the cause for this error.
if c, ok := err.(Causer); ok {
err = c.Cause()
continue
}
break
}
return err
}

View File

@ -35,11 +35,17 @@ func testRead(t testing.TB, f *file, offset, length int, data []byte) {
}
func firstSnapshotID(t testing.TB, repo restic.Repository) (first restic.ID) {
for id := range repo.List(context.TODO(), restic.SnapshotFile) {
err := repo.List(context.TODO(), restic.SnapshotFile, func(id restic.ID, size int64) error {
if first.IsNull() {
first = id
}
return nil
})
if err != nil {
t.Fatal(err)
}
return first
}

View File

@ -227,18 +227,24 @@ func isElem(e string, list []string) bool {
const minSnapshotsReloadTime = 60 * time.Second
// update snapshots if repository has changed
func updateSnapshots(ctx context.Context, root *Root) {
func updateSnapshots(ctx context.Context, root *Root) error {
if time.Since(root.lastCheck) < minSnapshotsReloadTime {
return
return nil
}
snapshots, err := restic.FindFilteredSnapshots(ctx, root.repo, root.cfg.Host, root.cfg.Tags, root.cfg.Paths)
if err != nil {
return err
}
snapshots := restic.FindFilteredSnapshots(ctx, root.repo, root.cfg.Host, root.cfg.Tags, root.cfg.Paths)
if root.snCount != len(snapshots) {
root.snCount = len(snapshots)
root.repo.LoadIndex(ctx)
root.snapshots = snapshots
}
root.lastCheck = time.Now()
return nil
}
// read snapshot timestamps from the current repository-state.

View File

@ -115,13 +115,13 @@ func Load(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Ind
index := newIndex()
for id := range repo.List(ctx, restic.IndexFile) {
err := repo.List(ctx, restic.IndexFile, func(id restic.ID, size int64) error {
p.Report(restic.Stat{Blobs: 1})
debug.Log("Load index %v", id.Str())
idx, err := loadIndexJSON(ctx, repo, id)
if err != nil {
return nil, err
return err
}
res := make(map[restic.ID]Pack)
@ -144,12 +144,18 @@ func Load(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Ind
}
if err = index.AddPack(jpack.ID, 0, entries); err != nil {
return nil, err
return err
}
}
results[id] = res
index.IndexIDs.Insert(id)
return nil
})
if err != nil {
return nil, err
}
for superID, list := range supersedes {

View File

@ -28,7 +28,7 @@ func createFilledRepo(t testing.TB, snapshots int, dup float32) (restic.Reposito
}
func validateIndex(t testing.TB, repo restic.Repository, idx *Index) {
for id := range repo.List(context.TODO(), restic.DataFile) {
err := repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error {
p, ok := idx.Packs[id]
if !ok {
t.Errorf("pack %v missing from index", id.Str())
@ -37,6 +37,11 @@ func validateIndex(t testing.TB, repo restic.Repository, idx *Index) {
if !p.ID.Equal(id) {
t.Errorf("pack %v has invalid ID: want %v, got %v", id.Str(), id, p.ID)
}
return nil
})
if err != nil {
t.Fatal(err)
}
}
@ -308,7 +313,14 @@ func TestIndexAddRemovePack(t *testing.T) {
t.Fatalf("Load() returned error %v", err)
}
packID := <-repo.List(context.TODO(), restic.DataFile)
var packID restic.ID
err = repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error {
packID = id
return nil
})
if err != nil {
t.Fatal(err)
}
t.Logf("selected pack %v", packID.Str())

View File

@ -11,7 +11,7 @@ const listPackWorkers = 10
// Lister combines lists packs in a repo and blobs in a pack.
type Lister interface {
List(context.Context, restic.FileType) <-chan restic.ID
List(context.Context, restic.FileType, func(restic.ID, int64) error) error
ListPack(context.Context, restic.ID) ([]restic.Blob, int64, error)
}
@ -55,17 +55,19 @@ func AllPacks(ctx context.Context, repo Lister, ignorePacks restic.IDSet, ch cha
go func() {
defer close(jobCh)
for id := range repo.List(ctx, restic.DataFile) {
_ = repo.List(ctx, restic.DataFile, func(id restic.ID, size int64) error {
if ignorePacks.Has(id) {
continue
return nil
}
select {
case jobCh <- worker.Job{Data: id}:
case <-ctx.Done():
return
return ctx.Err()
}
}
return nil
})
}()
wp.Wait()

View File

@ -59,14 +59,14 @@ func (m *S3Layout) moveFiles(ctx context.Context, be *s3.Backend, l backend.Layo
fmt.Fprintf(os.Stderr, "renaming file returned error: %v\n", err)
}
for name := range be.List(ctx, t) {
h := restic.Handle{Type: t, Name: name}
return be.List(ctx, t, func(fi restic.FileInfo) error {
h := restic.Handle{Type: t, Name: fi.Name}
debug.Log("move %v", h)
retry(maxErrors, printErr, func() error {
return retry(maxErrors, printErr, func() error {
return be.Rename(h, l)
})
}
})
return nil
}

View File

@ -15,7 +15,7 @@ type Backend struct {
SaveFn func(ctx context.Context, h restic.Handle, rd io.Reader) error
LoadFn func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error)
StatFn func(ctx context.Context, h restic.Handle) (restic.FileInfo, error)
ListFn func(ctx context.Context, t restic.FileType) <-chan string
ListFn func(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error
RemoveFn func(ctx context.Context, h restic.Handle) error
TestFn func(ctx context.Context, h restic.Handle) (bool, error)
DeleteFn func(ctx context.Context) error
@ -77,14 +77,12 @@ func (m *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, e
}
// List items of type t.
func (m *Backend) List(ctx context.Context, t restic.FileType) <-chan string {
func (m *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
if m.ListFn == nil {
ch := make(chan string)
close(ch)
return ch
return nil
}
return m.ListFn(ctx, t)
return m.ListFn(ctx, t, fn)
}
// Remove data from the backend.

View File

@ -113,42 +113,48 @@ func OpenKey(ctx context.Context, s *Repository, name string, password string) (
// given password. If none could be found, ErrNoKeyFound is returned. When
// maxKeys is reached, ErrMaxKeysReached is returned. When setting maxKeys to
// zero, all keys in the repo are checked.
func SearchKey(ctx context.Context, s *Repository, password string, maxKeys int) (*Key, error) {
func SearchKey(ctx context.Context, s *Repository, password string, maxKeys int) (k *Key, err error) {
checked := 0
// try at most maxKeysForSearch keys in repo
for name := range s.Backend().List(ctx, restic.KeyFile) {
err = s.Backend().List(ctx, restic.KeyFile, func(fi restic.FileInfo) error {
if maxKeys > 0 && checked > maxKeys {
return nil, ErrMaxKeysReached
return ErrMaxKeysReached
}
_, err := restic.ParseID(name)
_, err := restic.ParseID(fi.Name)
if err != nil {
debug.Log("rejecting key with invalid name: %v", name)
continue
debug.Log("rejecting key with invalid name: %v", fi.Name)
return nil
}
debug.Log("trying key %q", name)
key, err := OpenKey(ctx, s, name, password)
debug.Log("trying key %q", fi.Name)
key, err := OpenKey(ctx, s, fi.Name, password)
if err != nil {
debug.Log("key %v returned error %v", name, err)
debug.Log("key %v returned error %v", fi.Name, err)
// ErrUnauthenticated means the password is wrong, try the next key
if errors.Cause(err) == crypto.ErrUnauthenticated {
continue
return nil
}
if err != nil {
debug.Log("unable to open key %v: %v\n", err)
continue
}
return err
}
debug.Log("successfully opened key %v", name)
return key, nil
debug.Log("successfully opened key %v", fi.Name)
k = key
return nil
})
if err != nil {
return nil, err
}
return nil, ErrNoKeyFound
if k == nil {
return nil, ErrNoKeyFound
}
return k, nil
}
// LoadKey loads a key from the backend.

View File

@ -2,10 +2,10 @@ package repository
import (
"context"
"sync"
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/restic"
"golang.org/x/sync/errgroup"
)
// ParallelWorkFunc gets one file ID to work on. If an error is returned,
@ -17,47 +17,36 @@ type ParallelWorkFunc func(ctx context.Context, id string) error
type ParallelIDWorkFunc func(ctx context.Context, id restic.ID) error
// FilesInParallel runs n workers of f in parallel, on the IDs that
// repo.List(t) yield. If f returns an error, the process is aborted and the
// repo.List(t) yields. If f returns an error, the process is aborted and the
// first error is returned.
func FilesInParallel(ctx context.Context, repo restic.Lister, t restic.FileType, n uint, f ParallelWorkFunc) error {
wg := &sync.WaitGroup{}
ch := repo.List(ctx, t)
errors := make(chan error, n)
func FilesInParallel(ctx context.Context, repo restic.Lister, t restic.FileType, n int, f ParallelWorkFunc) error {
g, ctx := errgroup.WithContext(ctx)
for i := 0; uint(i) < n; i++ {
wg.Add(1)
go func() {
defer wg.Done()
ch := make(chan string, n)
g.Go(func() error {
defer close(ch)
return repo.List(ctx, t, func(fi restic.FileInfo) error {
select {
case <-ctx.Done():
case ch <- fi.Name:
}
return nil
})
})
for {
select {
case id, ok := <-ch:
if !ok {
return
}
err := f(ctx, id)
if err != nil {
errors <- err
return
}
case <-ctx.Done():
return
for i := 0; i < n; i++ {
g.Go(func() error {
for name := range ch {
err := f(ctx, name)
if err != nil {
return err
}
}
}()
return nil
})
}
wg.Wait()
select {
case err := <-errors:
return err
default:
break
}
return nil
return g.Wait()
}
// ParallelWorkFuncParseID converts a function that takes a restic.ID to a

View File

@ -74,24 +74,25 @@ var lister = testIDs{
"34dd044c228727f2226a0c9c06a3e5ceb5e30e31cb7854f8fa1cde846b395a58",
}
func (tests testIDs) List(ctx context.Context, t restic.FileType) <-chan string {
ch := make(chan string)
func (tests testIDs) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error {
for i := 0; i < 500; i++ {
for _, id := range tests {
if ctx.Err() != nil {
return ctx.Err()
}
go func() {
defer close(ch)
fi := restic.FileInfo{
Name: id,
}
for i := 0; i < 500; i++ {
for _, id := range tests {
select {
case ch <- id:
case <-ctx.Done():
return
}
err := fn(fi)
if err != nil {
return err
}
}
}()
}
return ch
return nil
}
func TestFilesInParallel(t *testing.T) {
@ -100,7 +101,7 @@ func TestFilesInParallel(t *testing.T) {
return nil
}
for n := uint(1); n < 5; n++ {
for n := 1; n < 5; n++ {
err := repository.FilesInParallel(context.TODO(), lister, restic.DataFile, n*100, f)
rtest.OK(t, err)
}
@ -109,7 +110,6 @@ func TestFilesInParallel(t *testing.T) {
var errTest = errors.New("test error")
func TestFilesInParallelWithError(t *testing.T) {
f := func(ctx context.Context, id string) error {
time.Sleep(1 * time.Millisecond)
@ -120,8 +120,10 @@ func TestFilesInParallelWithError(t *testing.T) {
return nil
}
for n := uint(1); n < 5; n++ {
for n := 1; n < 5; n++ {
err := repository.FilesInParallel(context.TODO(), lister, restic.DataFile, n*100, f)
rtest.Equals(t, errTest, err)
if err != errTest {
t.Fatalf("wrong error returned, want %q, got %v", errTest, err)
}
}
}

View File

@ -16,7 +16,7 @@ func randomSize(min, max int) int {
}
func random(t testing.TB, length int) []byte {
rd := restic.NewRandReader(rand.New(rand.NewSource(int64(length))))
rd := restic.NewRandReader(rand.New(rand.NewSource(rand.Int63())))
buf := make([]byte, length)
_, err := io.ReadFull(rd, buf)
if err != nil {
@ -74,7 +74,7 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2
blobs := restic.NewBlobSet()
for id := range repo.List(context.TODO(), restic.DataFile) {
err := repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error {
entries, _, err := repo.ListPack(context.TODO(), id)
if err != nil {
t.Fatalf("error listing pack %v: %v", id, err)
@ -84,7 +84,7 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2
h := restic.BlobHandle{ID: entry.ID, Type: entry.Type}
if blobs.Has(h) {
t.Errorf("ignoring duplicate blob %v", h)
continue
return nil
}
blobs.Insert(h)
@ -93,8 +93,11 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2
} else {
list2.Insert(restic.BlobHandle{ID: entry.ID, Type: entry.Type})
}
}
return nil
})
if err != nil {
t.Fatal(err)
}
return list1, list2
@ -102,8 +105,13 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2
func listPacks(t *testing.T, repo restic.Repository) restic.IDSet {
list := restic.NewIDSet()
for id := range repo.List(context.TODO(), restic.DataFile) {
err := repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error {
list.Insert(id)
return nil
})
if err != nil {
t.Fatal(err)
}
return list
@ -153,15 +161,15 @@ func rebuildIndex(t *testing.T, repo restic.Repository) {
t.Fatal(err)
}
for id := range repo.List(context.TODO(), restic.IndexFile) {
err = repo.List(context.TODO(), restic.IndexFile, func(id restic.ID, size int64) error {
h := restic.Handle{
Type: restic.IndexFile,
Name: id.String(),
}
err = repo.Backend().Remove(context.TODO(), h)
if err != nil {
t.Fatal(err)
}
return repo.Backend().Remove(context.TODO(), h)
})
if err != nil {
t.Fatal(err)
}
_, err = idx.Save(context.TODO(), repo, nil)
@ -181,6 +189,10 @@ func TestRepack(t *testing.T) {
repo, cleanup := repository.TestRepository(t)
defer cleanup()
seed := rand.Int63()
rand.Seed(seed)
t.Logf("rand seed is %v", seed)
createRandomBlobs(t, repo, 100, 0.7)
packsBefore := listPacks(t, repo)

View File

@ -536,22 +536,15 @@ func (r *Repository) KeyName() string {
return r.keyName
}
// List returns a channel that yields all IDs of type t in the backend.
func (r *Repository) List(ctx context.Context, t restic.FileType) <-chan restic.ID {
out := make(chan restic.ID)
go func() {
defer close(out)
for strID := range r.be.List(ctx, t) {
if id, err := restic.ParseID(strID); err == nil {
select {
case out <- id:
case <-ctx.Done():
return
}
}
// List runs fn for all files of type t in the repo.
func (r *Repository) List(ctx context.Context, t restic.FileType, fn func(restic.ID, int64) error) error {
return r.be.List(ctx, t, func(fi restic.FileInfo) error {
id, err := restic.ParseID(fi.Name)
if err != nil {
debug.Log("unable to parse %v as an ID", fi.Name)
}
}()
return out
return fn(id, fi.Size)
})
}
// ListPack returns the list of blobs saved in the pack id and the length of

View File

@ -369,7 +369,7 @@ func TestRepositoryIncrementalIndex(t *testing.T) {
packEntries := make(map[restic.ID]map[restic.ID]struct{})
for id := range repo.List(context.TODO(), restic.IndexFile) {
err := repo.List(context.TODO(), restic.IndexFile, func(id restic.ID, size int64) error {
idx, err := repository.LoadIndex(context.TODO(), repo, id)
rtest.OK(t, err)
@ -380,6 +380,10 @@ func TestRepositoryIncrementalIndex(t *testing.T) {
packEntries[pb.PackID][id] = struct{}{}
}
return nil
})
if err != nil {
t.Fatal(err)
}
for packID, ids := range packEntries {

View File

@ -32,10 +32,12 @@ type Backend interface {
// Stat returns information about the File identified by h.
Stat(ctx context.Context, h Handle) (FileInfo, error)
// List returns a channel that yields all names of files of type t in an
// arbitrary order. A goroutine is started for this, which is stopped when
// ctx is cancelled.
List(ctx context.Context, t FileType) <-chan string
// List runs fn for each file in the backend which has the type t. When an
// error occurs (or fn returns an error), List stops and returns it.
//
// The function fn is called in the same Goroutine that List() is called
// from.
List(ctx context.Context, t FileType, fn func(FileInfo) error) error
// IsNotExist returns true if the error was caused by a non-existing file
// in the backend.
@ -45,6 +47,8 @@ type Backend interface {
Delete(ctx context.Context) error
}
// FileInfo is returned by Stat() and contains information about a file in the
// backend.
type FileInfo struct{ Size int64 }
// FileInfo is contains information about a file in the backend.
type FileInfo struct {
Size int64
Name string
}

View File

@ -20,15 +20,23 @@ var ErrMultipleIDMatches = errors.New("multiple IDs with prefix found")
func Find(be Lister, t FileType, prefix string) (string, error) {
match := ""
// TODO: optimize by sorting list etc.
for name := range be.List(context.TODO(), t) {
if prefix == name[:len(prefix)] {
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
err := be.List(ctx, t, func(fi FileInfo) error {
if prefix == fi.Name[:len(prefix)] {
if match == "" {
match = name
match = fi.Name
} else {
return "", ErrMultipleIDMatches
return ErrMultipleIDMatches
}
}
return nil
})
if err != nil {
return "", err
}
if match != "" {
@ -45,8 +53,17 @@ const minPrefixLength = 8
func PrefixLength(be Lister, t FileType) (int, error) {
// load all IDs of the given type
list := make([]string, 0, 100)
for name := range be.List(context.TODO(), t) {
list = append(list, name)
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
err := be.List(ctx, t, func(fi FileInfo) error {
list = append(list, fi.Name)
return nil
})
if err != nil {
return 0, err
}
// select prefixes of length l, test if the last one is the same as the current one

View File

@ -6,11 +6,11 @@ import (
)
type mockBackend struct {
list func(context.Context, FileType) <-chan string
list func(context.Context, FileType, func(FileInfo) error) error
}
func (m mockBackend) List(ctx context.Context, t FileType) <-chan string {
return m.list(ctx, t)
func (m mockBackend) List(ctx context.Context, t FileType, fn func(FileInfo) error) error {
return m.list(ctx, t, fn)
}
var samples = IDs{
@ -28,19 +28,14 @@ func TestPrefixLength(t *testing.T) {
list := samples
m := mockBackend{}
m.list = func(ctx context.Context, t FileType) <-chan string {
ch := make(chan string)
go func() {
defer close(ch)
for _, id := range list {
select {
case ch <- id.String():
case <-ctx.Done():
return
}
m.list = func(ctx context.Context, t FileType, fn func(FileInfo) error) error {
for _, id := range list {
err := fn(FileInfo{Name: id.String()})
if err != nil {
return err
}
}()
return ch
}
return nil
}
l, err := PrefixLength(m, SnapshotFile)

View File

@ -157,15 +157,14 @@ func (l *Lock) checkForOtherLocks(ctx context.Context) error {
}
func eachLock(ctx context.Context, repo Repository, f func(ID, *Lock, error) error) error {
for id := range repo.List(ctx, LockFile) {
return repo.List(ctx, LockFile, func(id ID, size int64) error {
lock, err := LoadLock(ctx, repo, id)
err = f(id, lock, err)
if err != nil {
return err
}
}
return nil
return f(id, lock, err)
})
}
// createLock acquires the lock by creating a file in the repository.

View File

@ -227,21 +227,29 @@ func TestLockRefresh(t *testing.T) {
rtest.OK(t, err)
var lockID *restic.ID
for id := range repo.List(context.TODO(), restic.LockFile) {
err = repo.List(context.TODO(), restic.LockFile, func(id restic.ID, size int64) error {
if lockID != nil {
t.Error("more than one lock found")
}
lockID = &id
return nil
})
if err != nil {
t.Fatal(err)
}
rtest.OK(t, lock.Refresh(context.TODO()))
var lockID2 *restic.ID
for id := range repo.List(context.TODO(), restic.LockFile) {
err = repo.List(context.TODO(), restic.LockFile, func(id restic.ID, size int64) error {
if lockID2 != nil {
t.Error("more than one lock found")
}
lockID2 = &id
return nil
})
if err != nil {
t.Fatal(err)
}
rtest.Assert(t, !lockID.Equal(*lockID2),

View File

@ -26,7 +26,12 @@ type Repository interface {
LookupBlobSize(ID, BlobType) (uint, error)
List(context.Context, FileType) <-chan ID
// List calls the function fn for each file of type t in the repository.
// When an error is returned by fn, processing stops and List() returns the
// error.
//
// The function fn is called in the same Goroutine List() was called from.
List(ctx context.Context, t FileType, fn func(ID, int64) error) error
ListPack(context.Context, ID) ([]Blob, int64, error)
Flush(context.Context) error
@ -46,7 +51,7 @@ type Repository interface {
// Lister allows listing files in a backend.
type Lister interface {
List(context.Context, FileType) <-chan string
List(context.Context, FileType, func(FileInfo) error) error
}
// Index keeps track of the blobs are stored within files.

View File

@ -64,15 +64,21 @@ func LoadSnapshot(ctx context.Context, repo Repository, id ID) (*Snapshot, error
// LoadAllSnapshots returns a list of all snapshots in the repo.
func LoadAllSnapshots(ctx context.Context, repo Repository) (snapshots []*Snapshot, err error) {
for id := range repo.List(ctx, SnapshotFile) {
err = repo.List(ctx, SnapshotFile, func(id ID, size int64) error {
sn, err := LoadSnapshot(ctx, repo, id)
if err != nil {
return nil, err
return err
}
snapshots = append(snapshots, sn)
return nil
})
if err != nil {
return nil, err
}
return
return snapshots, nil
}
func (sn Snapshot) String() string {

View File

@ -20,26 +20,31 @@ func FindLatestSnapshot(ctx context.Context, repo Repository, targets []string,
found bool
)
for snapshotID := range repo.List(ctx, SnapshotFile) {
err := repo.List(ctx, SnapshotFile, func(snapshotID ID, size int64) error {
snapshot, err := LoadSnapshot(ctx, repo, snapshotID)
if err != nil {
return ID{}, errors.Errorf("Error listing snapshot: %v", err)
return errors.Errorf("Error loading snapshot %v: %v", snapshotID.Str(), err)
}
if snapshot.Time.Before(latest) || (hostname != "" && hostname != snapshot.Hostname) {
continue
return nil
}
if !snapshot.HasTagList(tagLists) {
continue
return nil
}
if !snapshot.HasPaths(targets) {
continue
return nil
}
latest = snapshot.Time
latestID = snapshotID
found = true
return nil
})
if err != nil {
return ID{}, err
}
if !found {
@ -64,20 +69,27 @@ func FindSnapshot(repo Repository, s string) (ID, error) {
// FindFilteredSnapshots yields Snapshots filtered from the list of all
// snapshots.
func FindFilteredSnapshots(ctx context.Context, repo Repository, host string, tags []TagList, paths []string) Snapshots {
func FindFilteredSnapshots(ctx context.Context, repo Repository, host string, tags []TagList, paths []string) (Snapshots, error) {
results := make(Snapshots, 0, 20)
for id := range repo.List(ctx, SnapshotFile) {
err := repo.List(ctx, SnapshotFile, func(id ID, size int64) error {
sn, err := LoadSnapshot(ctx, repo, id)
if err != nil {
fmt.Fprintf(os.Stderr, "could not load snapshot %v: %v\n", id.Str(), err)
continue
return nil
}
if (host != "" && host != sn.Hostname) || !sn.HasTagList(tags) || !sn.HasPaths(paths) {
continue
return nil
}
results = append(results, sn)
return nil
})
if err != nil {
return nil, err
}
return results
return results, nil
}

3
vendor/golang.org/x/sync/AUTHORS generated vendored Normal file
View File

@ -0,0 +1,3 @@
# This source code refers to The Go Authors for copyright purposes.
# The master list of authors is in the main Go distribution,
# visible at http://tip.golang.org/AUTHORS.

31
vendor/golang.org/x/sync/CONTRIBUTING.md generated vendored Normal file
View File

@ -0,0 +1,31 @@
# Contributing to Go
Go is an open source project.
It is the work of hundreds of contributors. We appreciate your help!
## Filing issues
When [filing an issue](https://golang.org/issue/new), make sure to answer these five questions:
1. What version of Go are you using (`go version`)?
2. What operating system and processor architecture are you using?
3. What did you do?
4. What did you expect to see?
5. What did you see instead?
General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker.
The gophers there will answer or ask you to file an issue if you've tripped over a bug.
## Contributing code
Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html)
before sending patches.
**We do not accept GitHub pull requests**
(we use [Gerrit](https://code.google.com/p/gerrit/) instead for code review).
Unless otherwise noted, the Go source files are distributed under
the BSD-style license found in the LICENSE file.

3
vendor/golang.org/x/sync/CONTRIBUTORS generated vendored Normal file
View File

@ -0,0 +1,3 @@
# This source code was written by the Go contributors.
# The master list of contributors is in the main Go distribution,
# visible at http://tip.golang.org/CONTRIBUTORS.

27
vendor/golang.org/x/sync/LICENSE generated vendored Normal file
View File

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

22
vendor/golang.org/x/sync/PATENTS generated vendored Normal file
View File

@ -0,0 +1,22 @@
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.

18
vendor/golang.org/x/sync/README.md generated vendored Normal file
View File

@ -0,0 +1,18 @@
# Go Sync
This repository provides Go concurrency primitives in addition to the
ones provided by the language and "sync" and "sync/atomic" packages.
## Download/Install
The easiest way to install is to run `go get -u golang.org/x/sync`. You can
also manually git clone the repository to `$GOPATH/src/golang.org/x/sync`.
## Report Issues / Send Patches
This repository uses Gerrit for code changes. To learn how to submit changes to
this repository, see https://golang.org/doc/contribute.html.
The main issue tracker for the sync repository is located at
https://github.com/golang/go/issues. Prefix your issue with "x/sync:" in the
subject line, so it is easy to find.

1
vendor/golang.org/x/sync/codereview.cfg generated vendored Normal file
View File

@ -0,0 +1 @@
issuerepo: golang/go

67
vendor/golang.org/x/sync/errgroup/errgroup.go generated vendored Normal file
View File

@ -0,0 +1,67 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package errgroup provides synchronization, error propagation, and Context
// cancelation for groups of goroutines working on subtasks of a common task.
package errgroup
import (
"sync"
"golang.org/x/net/context"
)
// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid and does not cancel on error.
type Group struct {
cancel func()
wg sync.WaitGroup
errOnce sync.Once
err error
}
// WithContext returns a new Group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function passed to Go
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return &Group{cancel: cancel}, ctx
}
// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil {
g.cancel()
}
return g.err
}
// Go calls the given function in a new goroutine.
//
// The first call to return a non-nil error cancels the group; its error will be
// returned by Wait.
func (g *Group) Go(f func() error) {
g.wg.Add(1)
go func() {
defer g.wg.Done()
if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel()
}
})
}
}()
}

View File

@ -0,0 +1,101 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package errgroup_test
import (
"crypto/md5"
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
"golang.org/x/net/context"
"golang.org/x/sync/errgroup"
)
// Pipeline demonstrates the use of a Group to implement a multi-stage
// pipeline: a version of the MD5All function with bounded parallelism from
// https://blog.golang.org/pipelines.
func ExampleGroup_pipeline() {
m, err := MD5All(context.Background(), ".")
if err != nil {
log.Fatal(err)
}
for k, sum := range m {
fmt.Printf("%s:\t%x\n", k, sum)
}
}
type result struct {
path string
sum [md5.Size]byte
}
// MD5All reads all the files in the file tree rooted at root and returns a map
// from file path to the MD5 sum of the file's contents. If the directory walk
// fails or any read operation fails, MD5All returns an error.
func MD5All(ctx context.Context, root string) (map[string][md5.Size]byte, error) {
// ctx is canceled when g.Wait() returns. When this version of MD5All returns
// - even in case of error! - we know that all of the goroutines have finished
// and the memory they were using can be garbage-collected.
g, ctx := errgroup.WithContext(ctx)
paths := make(chan string)
g.Go(func() error {
defer close(paths)
return filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.Mode().IsRegular() {
return nil
}
select {
case paths <- path:
case <-ctx.Done():
return ctx.Err()
}
return nil
})
})
// Start a fixed number of goroutines to read and digest files.
c := make(chan result)
const numDigesters = 20
for i := 0; i < numDigesters; i++ {
g.Go(func() error {
for path := range paths {
data, err := ioutil.ReadFile(path)
if err != nil {
return err
}
select {
case c <- result{path, md5.Sum(data)}:
case <-ctx.Done():
return ctx.Err()
}
}
return nil
})
}
go func() {
g.Wait()
close(c)
}()
m := make(map[string][md5.Size]byte)
for r := range c {
m[r.path] = r.sum
}
// Check whether any of the goroutines failed. Since g is accumulating the
// errors, we don't need to send them (or check for them) in the individual
// results sent on the channel.
if err := g.Wait(); err != nil {
return nil, err
}
return m, nil
}

176
vendor/golang.org/x/sync/errgroup/errgroup_test.go generated vendored Normal file
View File

@ -0,0 +1,176 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package errgroup_test
import (
"errors"
"fmt"
"net/http"
"os"
"testing"
"golang.org/x/net/context"
"golang.org/x/sync/errgroup"
)
var (
Web = fakeSearch("web")
Image = fakeSearch("image")
Video = fakeSearch("video")
)
type Result string
type Search func(ctx context.Context, query string) (Result, error)
func fakeSearch(kind string) Search {
return func(_ context.Context, query string) (Result, error) {
return Result(fmt.Sprintf("%s result for %q", kind, query)), nil
}
}
// JustErrors illustrates the use of a Group in place of a sync.WaitGroup to
// simplify goroutine counting and error handling. This example is derived from
// the sync.WaitGroup example at https://golang.org/pkg/sync/#example_WaitGroup.
func ExampleGroup_justErrors() {
var g errgroup.Group
var urls = []string{
"http://www.golang.org/",
"http://www.google.com/",
"http://www.somestupidname.com/",
}
for _, url := range urls {
// Launch a goroutine to fetch the URL.
url := url // https://golang.org/doc/faq#closures_and_goroutines
g.Go(func() error {
// Fetch the URL.
resp, err := http.Get(url)
if err == nil {
resp.Body.Close()
}
return err
})
}
// Wait for all HTTP fetches to complete.
if err := g.Wait(); err == nil {
fmt.Println("Successfully fetched all URLs.")
}
}
// Parallel illustrates the use of a Group for synchronizing a simple parallel
// task: the "Google Search 2.0" function from
// https://talks.golang.org/2012/concurrency.slide#46, augmented with a Context
// and error-handling.
func ExampleGroup_parallel() {
Google := func(ctx context.Context, query string) ([]Result, error) {
g, ctx := errgroup.WithContext(ctx)
searches := []Search{Web, Image, Video}
results := make([]Result, len(searches))
for i, search := range searches {
i, search := i, search // https://golang.org/doc/faq#closures_and_goroutines
g.Go(func() error {
result, err := search(ctx, query)
if err == nil {
results[i] = result
}
return err
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return results, nil
}
results, err := Google(context.Background(), "golang")
if err != nil {
fmt.Fprintln(os.Stderr, err)
return
}
for _, result := range results {
fmt.Println(result)
}
// Output:
// web result for "golang"
// image result for "golang"
// video result for "golang"
}
func TestZeroGroup(t *testing.T) {
err1 := errors.New("errgroup_test: 1")
err2 := errors.New("errgroup_test: 2")
cases := []struct {
errs []error
}{
{errs: []error{}},
{errs: []error{nil}},
{errs: []error{err1}},
{errs: []error{err1, nil}},
{errs: []error{err1, nil, err2}},
}
for _, tc := range cases {
var g errgroup.Group
var firstErr error
for i, err := range tc.errs {
err := err
g.Go(func() error { return err })
if firstErr == nil && err != nil {
firstErr = err
}
if gErr := g.Wait(); gErr != firstErr {
t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+
"g.Wait() = %v; want %v",
g, tc.errs[:i+1], err, firstErr)
}
}
}
}
func TestWithContext(t *testing.T) {
errDoom := errors.New("group_test: doomed")
cases := []struct {
errs []error
want error
}{
{want: nil},
{errs: []error{nil}, want: nil},
{errs: []error{errDoom}, want: errDoom},
{errs: []error{errDoom, nil}, want: errDoom},
}
for _, tc := range cases {
g, ctx := errgroup.WithContext(context.Background())
for _, err := range tc.errs {
err := err
g.Go(func() error { return err })
}
if err := g.Wait(); err != tc.want {
t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+
"g.Wait() = %v; want %v",
g, tc.errs, err, tc.want)
}
canceled := false
select {
case <-ctx.Done():
canceled = true
default:
}
if !canceled {
t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+
"ctx.Done() was not closed",
g, tc.errs)
}
}
}

131
vendor/golang.org/x/sync/semaphore/semaphore.go generated vendored Normal file
View File

@ -0,0 +1,131 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package semaphore provides a weighted semaphore implementation.
package semaphore // import "golang.org/x/sync/semaphore"
import (
"container/list"
"sync"
// Use the old context because packages that depend on this one
// (e.g. cloud.google.com/go/...) must run on Go 1.6.
// TODO(jba): update to "context" when possible.
"golang.org/x/net/context"
)
type waiter struct {
n int64
ready chan<- struct{} // Closed when semaphore acquired.
}
// NewWeighted creates a new weighted semaphore with the given
// maximum combined weight for concurrent access.
func NewWeighted(n int64) *Weighted {
w := &Weighted{size: n}
return w
}
// Weighted provides a way to bound concurrent access to a resource.
// The callers can request access with a given weight.
type Weighted struct {
size int64
cur int64
mu sync.Mutex
waiters list.List
}
// Acquire acquires the semaphore with a weight of n, blocking only until ctx
// is done. On success, returns nil. On failure, returns ctx.Err() and leaves
// the semaphore unchanged.
//
// If ctx is already done, Acquire may still succeed without blocking.
func (s *Weighted) Acquire(ctx context.Context, n int64) error {
s.mu.Lock()
if s.size-s.cur >= n && s.waiters.Len() == 0 {
s.cur += n
s.mu.Unlock()
return nil
}
if n > s.size {
// Don't make other Acquire calls block on one that's doomed to fail.
s.mu.Unlock()
<-ctx.Done()
return ctx.Err()
}
ready := make(chan struct{})
w := waiter{n: n, ready: ready}
elem := s.waiters.PushBack(w)
s.mu.Unlock()
select {
case <-ctx.Done():
err := ctx.Err()
s.mu.Lock()
select {
case <-ready:
// Acquired the semaphore after we were canceled. Rather than trying to
// fix up the queue, just pretend we didn't notice the cancelation.
err = nil
default:
s.waiters.Remove(elem)
}
s.mu.Unlock()
return err
case <-ready:
return nil
}
}
// TryAcquire acquires the semaphore with a weight of n without blocking.
// On success, returns true. On failure, returns false and leaves the semaphore unchanged.
func (s *Weighted) TryAcquire(n int64) bool {
s.mu.Lock()
success := s.size-s.cur >= n && s.waiters.Len() == 0
if success {
s.cur += n
}
s.mu.Unlock()
return success
}
// Release releases the semaphore with a weight of n.
func (s *Weighted) Release(n int64) {
s.mu.Lock()
s.cur -= n
if s.cur < 0 {
s.mu.Unlock()
panic("semaphore: bad release")
}
for {
next := s.waiters.Front()
if next == nil {
break // No more waiters blocked.
}
w := next.Value.(waiter)
if s.size-s.cur < w.n {
// Not enough tokens for the next waiter. We could keep going (to try to
// find a waiter with a smaller request), but under load that could cause
// starvation for large requests; instead, we leave all remaining waiters
// blocked.
//
// Consider a semaphore used as a read-write lock, with N tokens, N
// readers, and one writer. Each reader can Acquire(1) to obtain a read
// lock. The writer can Acquire(N) to obtain a write lock, excluding all
// of the readers. If we allow the readers to jump ahead in the queue,
// the writer will starve — there is always one token available for every
// reader.
break
}
s.cur += w.n
s.waiters.Remove(next)
close(w.ready)
}
s.mu.Unlock()
}

View File

@ -0,0 +1,131 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.7
package semaphore_test
import (
"fmt"
"testing"
"golang.org/x/net/context"
"golang.org/x/sync/semaphore"
)
// weighted is an interface matching a subset of *Weighted. It allows
// alternate implementations for testing and benchmarking.
type weighted interface {
Acquire(context.Context, int64) error
TryAcquire(int64) bool
Release(int64)
}
// semChan implements Weighted using a channel for
// comparing against the condition variable-based implementation.
type semChan chan struct{}
func newSemChan(n int64) semChan {
return semChan(make(chan struct{}, n))
}
func (s semChan) Acquire(_ context.Context, n int64) error {
for i := int64(0); i < n; i++ {
s <- struct{}{}
}
return nil
}
func (s semChan) TryAcquire(n int64) bool {
if int64(len(s))+n > int64(cap(s)) {
return false
}
for i := int64(0); i < n; i++ {
s <- struct{}{}
}
return true
}
func (s semChan) Release(n int64) {
for i := int64(0); i < n; i++ {
<-s
}
}
// acquireN calls Acquire(size) on sem N times and then calls Release(size) N times.
func acquireN(b *testing.B, sem weighted, size int64, N int) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < N; j++ {
sem.Acquire(context.Background(), size)
}
for j := 0; j < N; j++ {
sem.Release(size)
}
}
}
// tryAcquireN calls TryAcquire(size) on sem N times and then calls Release(size) N times.
func tryAcquireN(b *testing.B, sem weighted, size int64, N int) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < N; j++ {
if !sem.TryAcquire(size) {
b.Fatalf("TryAcquire(%v) = false, want true", size)
}
}
for j := 0; j < N; j++ {
sem.Release(size)
}
}
}
func BenchmarkNewSeq(b *testing.B) {
for _, cap := range []int64{1, 128} {
b.Run(fmt.Sprintf("Weighted-%d", cap), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = semaphore.NewWeighted(cap)
}
})
b.Run(fmt.Sprintf("semChan-%d", cap), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = newSemChan(cap)
}
})
}
}
func BenchmarkAcquireSeq(b *testing.B) {
for _, c := range []struct {
cap, size int64
N int
}{
{1, 1, 1},
{2, 1, 1},
{16, 1, 1},
{128, 1, 1},
{2, 2, 1},
{16, 2, 8},
{128, 2, 64},
{2, 1, 2},
{16, 8, 2},
{128, 64, 2},
} {
for _, w := range []struct {
name string
w weighted
}{
{"Weighted", semaphore.NewWeighted(c.cap)},
{"semChan", newSemChan(c.cap)},
} {
b.Run(fmt.Sprintf("%s-acquire-%d-%d-%d", w.name, c.cap, c.size, c.N), func(b *testing.B) {
acquireN(b, w.w, c.size, c.N)
})
b.Run(fmt.Sprintf("%s-tryAcquire-%d-%d-%d", w.name, c.cap, c.size, c.N), func(b *testing.B) {
tryAcquireN(b, w.w, c.size, c.N)
})
}
}
}

View File

@ -0,0 +1,84 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package semaphore_test
import (
"context"
"fmt"
"log"
"runtime"
"golang.org/x/sync/semaphore"
)
// Example_workerPool demonstrates how to use a semaphore to limit the number of
// goroutines working on parallel tasks.
//
// This use of a semaphore mimics a typical “worker pool” pattern, but without
// the need to explicitly shut down idle workers when the work is done.
func Example_workerPool() {
ctx := context.TODO()
var (
maxWorkers = runtime.GOMAXPROCS(0)
sem = semaphore.NewWeighted(int64(maxWorkers))
out = make([]int, 32)
)
// Compute the output using up to maxWorkers goroutines at a time.
for i := range out {
// When maxWorkers goroutines are in flight, Acquire blocks until one of the
// workers finishes.
if err := sem.Acquire(ctx, 1); err != nil {
log.Printf("Failed to acquire semaphore: %v", err)
break
}
go func(i int) {
defer sem.Release(1)
out[i] = collatzSteps(i + 1)
}(i)
}
// Acquire all of the tokens to wait for any remaining workers to finish.
//
// If you are already waiting for the workers by some other means (such as an
// errgroup.Group), you can omit this final Acquire call.
if err := sem.Acquire(ctx, int64(maxWorkers)); err != nil {
log.Printf("Failed to acquire semaphore: %v", err)
}
fmt.Println(out)
// Output:
// [0 1 7 2 5 8 16 3 19 6 14 9 9 17 17 4 12 20 20 7 7 15 15 10 23 10 111 18 18 18 106 5]
}
// collatzSteps computes the number of steps to reach 1 under the Collatz
// conjecture. (See https://en.wikipedia.org/wiki/Collatz_conjecture.)
func collatzSteps(n int) (steps int) {
if n <= 0 {
panic("nonpositive input")
}
for ; n > 1; steps++ {
if steps < 0 {
panic("too many steps")
}
if n%2 == 0 {
n /= 2
continue
}
const maxInt = int(^uint(0) >> 1)
if n > (maxInt-1)/3 {
panic("overflow")
}
n = 3*n + 1
}
return steps
}

171
vendor/golang.org/x/sync/semaphore/semaphore_test.go generated vendored Normal file
View File

@ -0,0 +1,171 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package semaphore_test
import (
"math/rand"
"runtime"
"sync"
"testing"
"time"
"golang.org/x/net/context"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)
const maxSleep = 1 * time.Millisecond
func HammerWeighted(sem *semaphore.Weighted, n int64, loops int) {
for i := 0; i < loops; i++ {
sem.Acquire(context.Background(), n)
time.Sleep(time.Duration(rand.Int63n(int64(maxSleep/time.Nanosecond))) * time.Nanosecond)
sem.Release(n)
}
}
func TestWeighted(t *testing.T) {
t.Parallel()
n := runtime.GOMAXPROCS(0)
loops := 10000 / n
sem := semaphore.NewWeighted(int64(n))
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
i := i
go func() {
defer wg.Done()
HammerWeighted(sem, int64(i), loops)
}()
}
wg.Wait()
}
func TestWeightedPanic(t *testing.T) {
t.Parallel()
defer func() {
if recover() == nil {
t.Fatal("release of an unacquired weighted semaphore did not panic")
}
}()
w := semaphore.NewWeighted(1)
w.Release(1)
}
func TestWeightedTryAcquire(t *testing.T) {
t.Parallel()
ctx := context.Background()
sem := semaphore.NewWeighted(2)
tries := []bool{}
sem.Acquire(ctx, 1)
tries = append(tries, sem.TryAcquire(1))
tries = append(tries, sem.TryAcquire(1))
sem.Release(2)
tries = append(tries, sem.TryAcquire(1))
sem.Acquire(ctx, 1)
tries = append(tries, sem.TryAcquire(1))
want := []bool{true, false, true, false}
for i := range tries {
if tries[i] != want[i] {
t.Errorf("tries[%d]: got %t, want %t", i, tries[i], want[i])
}
}
}
func TestWeightedAcquire(t *testing.T) {
t.Parallel()
ctx := context.Background()
sem := semaphore.NewWeighted(2)
tryAcquire := func(n int64) bool {
ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
defer cancel()
return sem.Acquire(ctx, n) == nil
}
tries := []bool{}
sem.Acquire(ctx, 1)
tries = append(tries, tryAcquire(1))
tries = append(tries, tryAcquire(1))
sem.Release(2)
tries = append(tries, tryAcquire(1))
sem.Acquire(ctx, 1)
tries = append(tries, tryAcquire(1))
want := []bool{true, false, true, false}
for i := range tries {
if tries[i] != want[i] {
t.Errorf("tries[%d]: got %t, want %t", i, tries[i], want[i])
}
}
}
func TestWeightedDoesntBlockIfTooBig(t *testing.T) {
t.Parallel()
const n = 2
sem := semaphore.NewWeighted(n)
{
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go sem.Acquire(ctx, n+1)
}
g, ctx := errgroup.WithContext(context.Background())
for i := n * 3; i > 0; i-- {
g.Go(func() error {
err := sem.Acquire(ctx, 1)
if err == nil {
time.Sleep(1 * time.Millisecond)
sem.Release(1)
}
return err
})
}
if err := g.Wait(); err != nil {
t.Errorf("semaphore.NewWeighted(%v) failed to AcquireCtx(_, 1) with AcquireCtx(_, %v) pending", n, n+1)
}
}
// TestLargeAcquireDoesntStarve times out if a large call to Acquire starves.
// Merely returning from the test function indicates success.
func TestLargeAcquireDoesntStarve(t *testing.T) {
t.Parallel()
ctx := context.Background()
n := int64(runtime.GOMAXPROCS(0))
sem := semaphore.NewWeighted(n)
running := true
var wg sync.WaitGroup
wg.Add(int(n))
for i := n; i > 0; i-- {
sem.Acquire(ctx, 1)
go func() {
defer func() {
sem.Release(1)
wg.Done()
}()
for running {
time.Sleep(1 * time.Millisecond)
sem.Release(1)
sem.Acquire(ctx, 1)
}
}()
}
sem.Acquire(ctx, n)
running = false
sem.Release(n)
wg.Wait()
}

111
vendor/golang.org/x/sync/singleflight/singleflight.go generated vendored Normal file
View File

@ -0,0 +1,111 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package singleflight provides a duplicate function call suppression
// mechanism.
package singleflight // import "golang.org/x/sync/singleflight"
import "sync"
// call is an in-flight or completed singleflight.Do call
type call struct {
wg sync.WaitGroup
// These fields are written once before the WaitGroup is done
// and are only read after the WaitGroup is done.
val interface{}
err error
// These fields are read and written with the singleflight
// mutex held before the WaitGroup is done, and are read but
// not written after the WaitGroup is done.
dups int
chans []chan<- Result
}
// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Result holds the results of Do, so they can be passed
// on a channel.
type Result struct {
Val interface{}
Err error
Shared bool
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value shared indicates whether v was given to multiple callers.
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
g.mu.Unlock()
c.wg.Wait()
return c.val, c.err, true
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
g.doCall(c, key, fn)
return c.val, c.err, c.dups > 0
}
// DoChan is like Do but returns a channel that will receive the
// results when they are ready.
func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
ch := make(chan Result, 1)
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
c.chans = append(c.chans, ch)
g.mu.Unlock()
return ch
}
c := &call{chans: []chan<- Result{ch}}
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
go g.doCall(c, key, fn)
return ch
}
// doCall handles the single call for a key.
func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) {
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
for _, ch := range c.chans {
ch <- Result{c.val, c.err, c.dups > 0}
}
g.mu.Unlock()
}
// Forget tells the singleflight to forget about a key. Future calls
// to Do for this key will call the function rather than waiting for
// an earlier call to complete.
func (g *Group) Forget(key string) {
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
}

View File

@ -0,0 +1,87 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package singleflight
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestDo(t *testing.T) {
var g Group
v, err, _ := g.Do("key", func() (interface{}, error) {
return "bar", nil
})
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
t.Errorf("Do = %v; want %v", got, want)
}
if err != nil {
t.Errorf("Do error = %v", err)
}
}
func TestDoErr(t *testing.T) {
var g Group
someErr := errors.New("Some error")
v, err, _ := g.Do("key", func() (interface{}, error) {
return nil, someErr
})
if err != someErr {
t.Errorf("Do error = %v; want someErr %v", err, someErr)
}
if v != nil {
t.Errorf("unexpected non-nil value %#v", v)
}
}
func TestDoDupSuppress(t *testing.T) {
var g Group
var wg1, wg2 sync.WaitGroup
c := make(chan string, 1)
var calls int32
fn := func() (interface{}, error) {
if atomic.AddInt32(&calls, 1) == 1 {
// First invocation.
wg1.Done()
}
v := <-c
c <- v // pump; make available for any future calls
time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
return v, nil
}
const n = 10
wg1.Add(1)
for i := 0; i < n; i++ {
wg1.Add(1)
wg2.Add(1)
go func() {
defer wg2.Done()
wg1.Done()
v, err, _ := g.Do("key", fn)
if err != nil {
t.Errorf("Do error: %v", err)
return
}
if s, _ := v.(string); s != "bar" {
t.Errorf("Do = %T %v; want %q", v, v, "bar")
}
}()
}
wg1.Wait()
// At least one goroutine is in fn now and all of them have at
// least reached the line before the Do.
c <- "bar"
wg2.Wait()
if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
}
}

372
vendor/golang.org/x/sync/syncmap/map.go generated vendored Normal file
View File

@ -0,0 +1,372 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package syncmap provides a concurrent map implementation.
// It is a prototype for a proposed addition to the sync package
// in the standard library.
// (https://golang.org/issue/18177)
package syncmap
import (
"sync"
"sync/atomic"
"unsafe"
)
// Map is a concurrent map with amortized-constant-time loads, stores, and deletes.
// It is safe for multiple goroutines to call a Map's methods concurrently.
//
// The zero Map is valid and empty.
//
// A Map must not be copied after first use.
type Map struct {
mu sync.Mutex
// read contains the portion of the map's contents that are safe for
// concurrent access (with or without mu held).
//
// The read field itself is always safe to load, but must only be stored with
// mu held.
//
// Entries stored in read may be updated concurrently without mu, but updating
// a previously-expunged entry requires that the entry be copied to the dirty
// map and unexpunged with mu held.
read atomic.Value // readOnly
// dirty contains the portion of the map's contents that require mu to be
// held. To ensure that the dirty map can be promoted to the read map quickly,
// it also includes all of the non-expunged entries in the read map.
//
// Expunged entries are not stored in the dirty map. An expunged entry in the
// clean map must be unexpunged and added to the dirty map before a new value
// can be stored to it.
//
// If the dirty map is nil, the next write to the map will initialize it by
// making a shallow copy of the clean map, omitting stale entries.
dirty map[interface{}]*entry
// misses counts the number of loads since the read map was last updated that
// needed to lock mu to determine whether the key was present.
//
// Once enough misses have occurred to cover the cost of copying the dirty
// map, the dirty map will be promoted to the read map (in the unamended
// state) and the next store to the map will make a new dirty copy.
misses int
}
// readOnly is an immutable struct stored atomically in the Map.read field.
type readOnly struct {
m map[interface{}]*entry
amended bool // true if the dirty map contains some key not in m.
}
// expunged is an arbitrary pointer that marks entries which have been deleted
// from the dirty map.
var expunged = unsafe.Pointer(new(interface{}))
// An entry is a slot in the map corresponding to a particular key.
type entry struct {
// p points to the interface{} value stored for the entry.
//
// If p == nil, the entry has been deleted and m.dirty == nil.
//
// If p == expunged, the entry has been deleted, m.dirty != nil, and the entry
// is missing from m.dirty.
//
// Otherwise, the entry is valid and recorded in m.read.m[key] and, if m.dirty
// != nil, in m.dirty[key].
//
// An entry can be deleted by atomic replacement with nil: when m.dirty is
// next created, it will atomically replace nil with expunged and leave
// m.dirty[key] unset.
//
// An entry's associated value can be updated by atomic replacement, provided
// p != expunged. If p == expunged, an entry's associated value can be updated
// only after first setting m.dirty[key] = e so that lookups using the dirty
// map find the entry.
p unsafe.Pointer // *interface{}
}
func newEntry(i interface{}) *entry {
return &entry{p: unsafe.Pointer(&i)}
}
// Load returns the value stored in the map for a key, or nil if no
// value is present.
// The ok result indicates whether value was found in the map.
func (m *Map) Load(key interface{}) (value interface{}, ok bool) {
read, _ := m.read.Load().(readOnly)
e, ok := read.m[key]
if !ok && read.amended {
m.mu.Lock()
// Avoid reporting a spurious miss if m.dirty got promoted while we were
// blocked on m.mu. (If further loads of the same key will not miss, it's
// not worth copying the dirty map for this key.)
read, _ = m.read.Load().(readOnly)
e, ok = read.m[key]
if !ok && read.amended {
e, ok = m.dirty[key]
// Regardless of whether the entry was present, record a miss: this key
// will take the slow path until the dirty map is promoted to the read
// map.
m.missLocked()
}
m.mu.Unlock()
}
if !ok {
return nil, false
}
return e.load()
}
func (e *entry) load() (value interface{}, ok bool) {
p := atomic.LoadPointer(&e.p)
if p == nil || p == expunged {
return nil, false
}
return *(*interface{})(p), true
}
// Store sets the value for a key.
func (m *Map) Store(key, value interface{}) {
read, _ := m.read.Load().(readOnly)
if e, ok := read.m[key]; ok && e.tryStore(&value) {
return
}
m.mu.Lock()
read, _ = m.read.Load().(readOnly)
if e, ok := read.m[key]; ok {
if e.unexpungeLocked() {
// The entry was previously expunged, which implies that there is a
// non-nil dirty map and this entry is not in it.
m.dirty[key] = e
}
e.storeLocked(&value)
} else if e, ok := m.dirty[key]; ok {
e.storeLocked(&value)
} else {
if !read.amended {
// We're adding the first new key to the dirty map.
// Make sure it is allocated and mark the read-only map as incomplete.
m.dirtyLocked()
m.read.Store(readOnly{m: read.m, amended: true})
}
m.dirty[key] = newEntry(value)
}
m.mu.Unlock()
}
// tryStore stores a value if the entry has not been expunged.
//
// If the entry is expunged, tryStore returns false and leaves the entry
// unchanged.
func (e *entry) tryStore(i *interface{}) bool {
p := atomic.LoadPointer(&e.p)
if p == expunged {
return false
}
for {
if atomic.CompareAndSwapPointer(&e.p, p, unsafe.Pointer(i)) {
return true
}
p = atomic.LoadPointer(&e.p)
if p == expunged {
return false
}
}
}
// unexpungeLocked ensures that the entry is not marked as expunged.
//
// If the entry was previously expunged, it must be added to the dirty map
// before m.mu is unlocked.
func (e *entry) unexpungeLocked() (wasExpunged bool) {
return atomic.CompareAndSwapPointer(&e.p, expunged, nil)
}
// storeLocked unconditionally stores a value to the entry.
//
// The entry must be known not to be expunged.
func (e *entry) storeLocked(i *interface{}) {
atomic.StorePointer(&e.p, unsafe.Pointer(i))
}
// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *Map) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) {
// Avoid locking if it's a clean hit.
read, _ := m.read.Load().(readOnly)
if e, ok := read.m[key]; ok {
actual, loaded, ok := e.tryLoadOrStore(value)
if ok {
return actual, loaded
}
}
m.mu.Lock()
read, _ = m.read.Load().(readOnly)
if e, ok := read.m[key]; ok {
if e.unexpungeLocked() {
m.dirty[key] = e
}
actual, loaded, _ = e.tryLoadOrStore(value)
} else if e, ok := m.dirty[key]; ok {
actual, loaded, _ = e.tryLoadOrStore(value)
m.missLocked()
} else {
if !read.amended {
// We're adding the first new key to the dirty map.
// Make sure it is allocated and mark the read-only map as incomplete.
m.dirtyLocked()
m.read.Store(readOnly{m: read.m, amended: true})
}
m.dirty[key] = newEntry(value)
actual, loaded = value, false
}
m.mu.Unlock()
return actual, loaded
}
// tryLoadOrStore atomically loads or stores a value if the entry is not
// expunged.
//
// If the entry is expunged, tryLoadOrStore leaves the entry unchanged and
// returns with ok==false.
func (e *entry) tryLoadOrStore(i interface{}) (actual interface{}, loaded, ok bool) {
p := atomic.LoadPointer(&e.p)
if p == expunged {
return nil, false, false
}
if p != nil {
return *(*interface{})(p), true, true
}
// Copy the interface after the first load to make this method more amenable
// to escape analysis: if we hit the "load" path or the entry is expunged, we
// shouldn't bother heap-allocating.
ic := i
for {
if atomic.CompareAndSwapPointer(&e.p, nil, unsafe.Pointer(&ic)) {
return i, false, true
}
p = atomic.LoadPointer(&e.p)
if p == expunged {
return nil, false, false
}
if p != nil {
return *(*interface{})(p), true, true
}
}
}
// Delete deletes the value for a key.
func (m *Map) Delete(key interface{}) {
read, _ := m.read.Load().(readOnly)
e, ok := read.m[key]
if !ok && read.amended {
m.mu.Lock()
read, _ = m.read.Load().(readOnly)
e, ok = read.m[key]
if !ok && read.amended {
delete(m.dirty, key)
}
m.mu.Unlock()
}
if ok {
e.delete()
}
}
func (e *entry) delete() (hadValue bool) {
for {
p := atomic.LoadPointer(&e.p)
if p == nil || p == expunged {
return false
}
if atomic.CompareAndSwapPointer(&e.p, p, nil) {
return true
}
}
}
// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot of the Map's
// contents: no key will be visited more than once, but if the value for any key
// is stored or deleted concurrently, Range may reflect any mapping for that key
// from any point during the Range call.
//
// Range may be O(N) with the number of elements in the map even if f returns
// false after a constant number of calls.
func (m *Map) Range(f func(key, value interface{}) bool) {
// We need to be able to iterate over all of the keys that were already
// present at the start of the call to Range.
// If read.amended is false, then read.m satisfies that property without
// requiring us to hold m.mu for a long time.
read, _ := m.read.Load().(readOnly)
if read.amended {
// m.dirty contains keys not in read.m. Fortunately, Range is already O(N)
// (assuming the caller does not break out early), so a call to Range
// amortizes an entire copy of the map: we can promote the dirty copy
// immediately!
m.mu.Lock()
read, _ = m.read.Load().(readOnly)
if read.amended {
read = readOnly{m: m.dirty}
m.read.Store(read)
m.dirty = nil
m.misses = 0
}
m.mu.Unlock()
}
for k, e := range read.m {
v, ok := e.load()
if !ok {
continue
}
if !f(k, v) {
break
}
}
}
func (m *Map) missLocked() {
m.misses++
if m.misses < len(m.dirty) {
return
}
m.read.Store(readOnly{m: m.dirty})
m.dirty = nil
m.misses = 0
}
func (m *Map) dirtyLocked() {
if m.dirty != nil {
return
}
read, _ := m.read.Load().(readOnly)
m.dirty = make(map[interface{}]*entry, len(read.m))
for k, e := range read.m {
if !e.tryExpungeLocked() {
m.dirty[k] = e
}
}
}
func (e *entry) tryExpungeLocked() (isExpunged bool) {
p := atomic.LoadPointer(&e.p)
for p == nil {
if atomic.CompareAndSwapPointer(&e.p, nil, expunged) {
return true
}
p = atomic.LoadPointer(&e.p)
}
return p == expunged
}

216
vendor/golang.org/x/sync/syncmap/map_bench_test.go generated vendored Normal file
View File

@ -0,0 +1,216 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syncmap_test
import (
"fmt"
"reflect"
"sync/atomic"
"testing"
"golang.org/x/sync/syncmap"
)
type bench struct {
setup func(*testing.B, mapInterface)
perG func(b *testing.B, pb *testing.PB, i int, m mapInterface)
}
func benchMap(b *testing.B, bench bench) {
for _, m := range [...]mapInterface{&DeepCopyMap{}, &RWMutexMap{}, &syncmap.Map{}} {
b.Run(fmt.Sprintf("%T", m), func(b *testing.B) {
m = reflect.New(reflect.TypeOf(m).Elem()).Interface().(mapInterface)
if bench.setup != nil {
bench.setup(b, m)
}
b.ResetTimer()
var i int64
b.RunParallel(func(pb *testing.PB) {
id := int(atomic.AddInt64(&i, 1) - 1)
bench.perG(b, pb, id*b.N, m)
})
})
}
}
func BenchmarkLoadMostlyHits(b *testing.B) {
const hits, misses = 1023, 1
benchMap(b, bench{
setup: func(_ *testing.B, m mapInterface) {
for i := 0; i < hits; i++ {
m.LoadOrStore(i, i)
}
// Prime the map to get it into a steady state.
for i := 0; i < hits*2; i++ {
m.Load(i % hits)
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.Load(i % (hits + misses))
}
},
})
}
func BenchmarkLoadMostlyMisses(b *testing.B) {
const hits, misses = 1, 1023
benchMap(b, bench{
setup: func(_ *testing.B, m mapInterface) {
for i := 0; i < hits; i++ {
m.LoadOrStore(i, i)
}
// Prime the map to get it into a steady state.
for i := 0; i < hits*2; i++ {
m.Load(i % hits)
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.Load(i % (hits + misses))
}
},
})
}
func BenchmarkLoadOrStoreBalanced(b *testing.B) {
const hits, misses = 128, 128
benchMap(b, bench{
setup: func(b *testing.B, m mapInterface) {
if _, ok := m.(*DeepCopyMap); ok {
b.Skip("DeepCopyMap has quadratic running time.")
}
for i := 0; i < hits; i++ {
m.LoadOrStore(i, i)
}
// Prime the map to get it into a steady state.
for i := 0; i < hits*2; i++ {
m.Load(i % hits)
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
j := i % (hits + misses)
if j < hits {
if _, ok := m.LoadOrStore(j, i); !ok {
b.Fatalf("unexpected miss for %v", j)
}
} else {
if v, loaded := m.LoadOrStore(i, i); loaded {
b.Fatalf("failed to store %v: existing value %v", i, v)
}
}
}
},
})
}
func BenchmarkLoadOrStoreUnique(b *testing.B) {
benchMap(b, bench{
setup: func(b *testing.B, m mapInterface) {
if _, ok := m.(*DeepCopyMap); ok {
b.Skip("DeepCopyMap has quadratic running time.")
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.LoadOrStore(i, i)
}
},
})
}
func BenchmarkLoadOrStoreCollision(b *testing.B) {
benchMap(b, bench{
setup: func(_ *testing.B, m mapInterface) {
m.LoadOrStore(0, 0)
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.LoadOrStore(0, 0)
}
},
})
}
func BenchmarkRange(b *testing.B) {
const mapSize = 1 << 10
benchMap(b, bench{
setup: func(_ *testing.B, m mapInterface) {
for i := 0; i < mapSize; i++ {
m.Store(i, i)
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.Range(func(_, _ interface{}) bool { return true })
}
},
})
}
// BenchmarkAdversarialAlloc tests performance when we store a new value
// immediately whenever the map is promoted to clean and otherwise load a
// unique, missing key.
//
// This forces the Load calls to always acquire the map's mutex.
func BenchmarkAdversarialAlloc(b *testing.B) {
benchMap(b, bench{
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
var stores, loadsSinceStore int64
for ; pb.Next(); i++ {
m.Load(i)
if loadsSinceStore++; loadsSinceStore > stores {
m.LoadOrStore(i, stores)
loadsSinceStore = 0
stores++
}
}
},
})
}
// BenchmarkAdversarialDelete tests performance when we periodically delete
// one key and add a different one in a large map.
//
// This forces the Load calls to always acquire the map's mutex and periodically
// makes a full copy of the map despite changing only one entry.
func BenchmarkAdversarialDelete(b *testing.B) {
const mapSize = 1 << 10
benchMap(b, bench{
setup: func(_ *testing.B, m mapInterface) {
for i := 0; i < mapSize; i++ {
m.Store(i, i)
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.Load(i)
if i%mapSize == 0 {
m.Range(func(k, _ interface{}) bool {
m.Delete(k)
return false
})
m.Store(i, i)
}
}
},
})
}

151
vendor/golang.org/x/sync/syncmap/map_reference_test.go generated vendored Normal file
View File

@ -0,0 +1,151 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syncmap_test
import (
"sync"
"sync/atomic"
)
// This file contains reference map implementations for unit-tests.
// mapInterface is the interface Map implements.
type mapInterface interface {
Load(interface{}) (interface{}, bool)
Store(key, value interface{})
LoadOrStore(key, value interface{}) (actual interface{}, loaded bool)
Delete(interface{})
Range(func(key, value interface{}) (shouldContinue bool))
}
// RWMutexMap is an implementation of mapInterface using a sync.RWMutex.
type RWMutexMap struct {
mu sync.RWMutex
dirty map[interface{}]interface{}
}
func (m *RWMutexMap) Load(key interface{}) (value interface{}, ok bool) {
m.mu.RLock()
value, ok = m.dirty[key]
m.mu.RUnlock()
return
}
func (m *RWMutexMap) Store(key, value interface{}) {
m.mu.Lock()
if m.dirty == nil {
m.dirty = make(map[interface{}]interface{})
}
m.dirty[key] = value
m.mu.Unlock()
}
func (m *RWMutexMap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) {
m.mu.Lock()
actual, loaded = m.dirty[key]
if !loaded {
actual = value
if m.dirty == nil {
m.dirty = make(map[interface{}]interface{})
}
m.dirty[key] = value
}
m.mu.Unlock()
return actual, loaded
}
func (m *RWMutexMap) Delete(key interface{}) {
m.mu.Lock()
delete(m.dirty, key)
m.mu.Unlock()
}
func (m *RWMutexMap) Range(f func(key, value interface{}) (shouldContinue bool)) {
m.mu.RLock()
keys := make([]interface{}, 0, len(m.dirty))
for k := range m.dirty {
keys = append(keys, k)
}
m.mu.RUnlock()
for _, k := range keys {
v, ok := m.Load(k)
if !ok {
continue
}
if !f(k, v) {
break
}
}
}
// DeepCopyMap is an implementation of mapInterface using a Mutex and
// atomic.Value. It makes deep copies of the map on every write to avoid
// acquiring the Mutex in Load.
type DeepCopyMap struct {
mu sync.Mutex
clean atomic.Value
}
func (m *DeepCopyMap) Load(key interface{}) (value interface{}, ok bool) {
clean, _ := m.clean.Load().(map[interface{}]interface{})
value, ok = clean[key]
return value, ok
}
func (m *DeepCopyMap) Store(key, value interface{}) {
m.mu.Lock()
dirty := m.dirty()
dirty[key] = value
m.clean.Store(dirty)
m.mu.Unlock()
}
func (m *DeepCopyMap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) {
clean, _ := m.clean.Load().(map[interface{}]interface{})
actual, loaded = clean[key]
if loaded {
return actual, loaded
}
m.mu.Lock()
// Reload clean in case it changed while we were waiting on m.mu.
clean, _ = m.clean.Load().(map[interface{}]interface{})
actual, loaded = clean[key]
if !loaded {
dirty := m.dirty()
dirty[key] = value
actual = value
m.clean.Store(dirty)
}
m.mu.Unlock()
return actual, loaded
}
func (m *DeepCopyMap) Delete(key interface{}) {
m.mu.Lock()
dirty := m.dirty()
delete(dirty, key)
m.clean.Store(dirty)
m.mu.Unlock()
}
func (m *DeepCopyMap) Range(f func(key, value interface{}) (shouldContinue bool)) {
clean, _ := m.clean.Load().(map[interface{}]interface{})
for k, v := range clean {
if !f(k, v) {
break
}
}
}
func (m *DeepCopyMap) dirty() map[interface{}]interface{} {
clean, _ := m.clean.Load().(map[interface{}]interface{})
dirty := make(map[interface{}]interface{}, len(clean)+1)
for k, v := range clean {
dirty[k] = v
}
return dirty
}

172
vendor/golang.org/x/sync/syncmap/map_test.go generated vendored Normal file
View File

@ -0,0 +1,172 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syncmap_test
import (
"math/rand"
"reflect"
"runtime"
"sync"
"testing"
"testing/quick"
"golang.org/x/sync/syncmap"
)
type mapOp string
const (
opLoad = mapOp("Load")
opStore = mapOp("Store")
opLoadOrStore = mapOp("LoadOrStore")
opDelete = mapOp("Delete")
)
var mapOps = [...]mapOp{opLoad, opStore, opLoadOrStore, opDelete}
// mapCall is a quick.Generator for calls on mapInterface.
type mapCall struct {
op mapOp
k, v interface{}
}
func (c mapCall) apply(m mapInterface) (interface{}, bool) {
switch c.op {
case opLoad:
return m.Load(c.k)
case opStore:
m.Store(c.k, c.v)
return nil, false
case opLoadOrStore:
return m.LoadOrStore(c.k, c.v)
case opDelete:
m.Delete(c.k)
return nil, false
default:
panic("invalid mapOp")
}
}
type mapResult struct {
value interface{}
ok bool
}
func randValue(r *rand.Rand) interface{} {
b := make([]byte, r.Intn(4))
for i := range b {
b[i] = 'a' + byte(rand.Intn(26))
}
return string(b)
}
func (mapCall) Generate(r *rand.Rand, size int) reflect.Value {
c := mapCall{op: mapOps[rand.Intn(len(mapOps))], k: randValue(r)}
switch c.op {
case opStore, opLoadOrStore:
c.v = randValue(r)
}
return reflect.ValueOf(c)
}
func applyCalls(m mapInterface, calls []mapCall) (results []mapResult, final map[interface{}]interface{}) {
for _, c := range calls {
v, ok := c.apply(m)
results = append(results, mapResult{v, ok})
}
final = make(map[interface{}]interface{})
m.Range(func(k, v interface{}) bool {
final[k] = v
return true
})
return results, final
}
func applyMap(calls []mapCall) ([]mapResult, map[interface{}]interface{}) {
return applyCalls(new(syncmap.Map), calls)
}
func applyRWMutexMap(calls []mapCall) ([]mapResult, map[interface{}]interface{}) {
return applyCalls(new(RWMutexMap), calls)
}
func applyDeepCopyMap(calls []mapCall) ([]mapResult, map[interface{}]interface{}) {
return applyCalls(new(DeepCopyMap), calls)
}
func TestMapMatchesRWMutex(t *testing.T) {
if err := quick.CheckEqual(applyMap, applyRWMutexMap, nil); err != nil {
t.Error(err)
}
}
func TestMapMatchesDeepCopy(t *testing.T) {
if err := quick.CheckEqual(applyMap, applyDeepCopyMap, nil); err != nil {
t.Error(err)
}
}
func TestConcurrentRange(t *testing.T) {
const mapSize = 1 << 10
m := new(syncmap.Map)
for n := int64(1); n <= mapSize; n++ {
m.Store(n, int64(n))
}
done := make(chan struct{})
var wg sync.WaitGroup
defer func() {
close(done)
wg.Wait()
}()
for g := int64(runtime.GOMAXPROCS(0)); g > 0; g-- {
r := rand.New(rand.NewSource(g))
wg.Add(1)
go func(g int64) {
defer wg.Done()
for i := int64(0); ; i++ {
select {
case <-done:
return
default:
}
for n := int64(1); n < mapSize; n++ {
if r.Int63n(mapSize) == 0 {
m.Store(n, n*i*g)
} else {
m.Load(n)
}
}
}
}(g)
}
iters := 1 << 10
if testing.Short() {
iters = 16
}
for n := iters; n > 0; n-- {
seen := make(map[int64]bool, mapSize)
m.Range(func(ki, vi interface{}) bool {
k, v := ki.(int64), vi.(int64)
if v%k != 0 {
t.Fatalf("while Storing multiples of %v, Range saw value %v", k, v)
}
if seen[k] {
t.Fatalf("Range visited key %v twice", k)
}
seen[k] = true
return true
})
if len(seen) != mapSize {
t.Fatalf("Range visited %v elements of %v-element Map", len(seen), mapSize)
}
}
}