diff --git a/src/restic/worker/doc.go b/src/restic/worker/doc.go new file mode 100644 index 000000000..602bb5037 --- /dev/null +++ b/src/restic/worker/doc.go @@ -0,0 +1,2 @@ +// Package worker implements a worker pool. +package worker diff --git a/src/restic/worker/pool.go b/src/restic/worker/pool.go new file mode 100644 index 000000000..e94b9fe20 --- /dev/null +++ b/src/restic/worker/pool.go @@ -0,0 +1,95 @@ +package worker + +import "sync" + +// Job is one unit of work. +type Job interface{} + +// Result is something the worker function returned, including the original job +// and an (eventual) error. +type Result struct { + Job Job + Result interface{} + Error error +} + +// Func does the actual work within a Pool. +type Func func(job Job, done <-chan struct{}) (result interface{}, err error) + +// Pool implements a worker pool. +type Pool struct { + f Func + done chan struct{} + wg *sync.WaitGroup + jobCh <-chan Job + resCh chan<- Result +} + +// New returns a new worker pool with n goroutines, each running the function +// f. The workers are started immediately. +func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Result) *Pool { + p := &Pool{ + f: f, + done: make(chan struct{}), + wg: &sync.WaitGroup{}, + jobCh: jobChan, + resCh: resultChan, + } + + for i := 0; i < n; i++ { + p.wg.Add(1) + go p.runWorker(i) + } + + return p +} + +// runWorker runs a worker function. +func (p *Pool) runWorker(numWorker int) { + defer p.wg.Done() + + var ( + // enable the input channel when starting up a new goroutine + inCh = p.jobCh + // but do not enable the output channel until we have a result + outCh chan<- Result + + job Job + res Result + ok bool + ) + + for { + select { + case <-p.done: + return + + case job, ok = <-inCh: + if !ok { + return + } + + r, err := p.f(job, p.done) + res = Result{Job: job, Result: r, Error: err} + + inCh = nil + outCh = p.resCh + + case outCh <- res: + outCh = nil + inCh = p.jobCh + } + } +} + +// Cancel signals termination to all worker goroutines. +func (p *Pool) Cancel() { + close(p.done) +} + +// Wait waits for all worker goroutines to terminate, afterwards the output +// channel is closed. +func (p *Pool) Wait() { + p.wg.Wait() + close(p.resCh) +} diff --git a/src/restic/worker/pool_test.go b/src/restic/worker/pool_test.go new file mode 100644 index 000000000..1b66c2ad7 --- /dev/null +++ b/src/restic/worker/pool_test.go @@ -0,0 +1,134 @@ +package worker_test + +import ( + "errors" + "testing" + "time" + + "restic/worker" +) + +const concurrency = 10 + +var errTooLarge = errors.New("too large") + +func square(job worker.Job, done <-chan struct{}) (interface{}, error) { + n := job.(int) + if n > 2000 { + return nil, errTooLarge + } + return n * n, nil +} + +func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Result, *worker.Pool) { + inCh := make(chan worker.Job, bufsize) + outCh := make(chan worker.Result, bufsize) + + return inCh, outCh, worker.New(n, f, inCh, outCh) +} + +func TestPool(t *testing.T) { + inCh, outCh, p := newBufferedPool(200, concurrency, square) + + for i := 0; i < 150; i++ { + inCh <- i + } + + close(inCh) + p.Wait() + + for res := range outCh { + if res.Error != nil { + t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error) + } + + n := res.Job.(int) + m := res.Result.(int) + + if m != n*n { + t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m) + } + } +} + +func TestPoolErrors(t *testing.T) { + inCh, outCh, p := newBufferedPool(200, concurrency, square) + + for i := 0; i < 150; i++ { + inCh <- i + 1900 + } + + close(inCh) + p.Wait() + + for res := range outCh { + n := res.Job.(int) + + if n > 2000 { + if res.Error == nil { + t.Errorf("expected error not found, result is %v", res) + continue + } + + if res.Error != errTooLarge { + t.Errorf("unexpected error found, result is %v", res) + } + + continue + } else { + if res.Error != nil { + t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error) + continue + } + } + + m := res.Result.(int) + if m != n*n { + t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m) + } + } +} + +var errCancelled = errors.New("cancelled") + +func wait(job worker.Job, done <-chan struct{}) (interface{}, error) { + d := job.(time.Duration) + select { + case <-time.After(d): + return time.Now(), nil + case <-done: + return nil, errCancelled + } +} + +func TestPoolCancel(t *testing.T) { + jobCh, resCh, p := newBufferedPool(20, concurrency, wait) + + for i := 0; i < 20; i++ { + jobCh <- 10 * time.Millisecond + } + + time.Sleep(20 * time.Millisecond) + p.Cancel() + p.Wait() + + foundResult := false + foundCancelError := false + for res := range resCh { + if res.Error == nil { + foundResult = true + } + + if res.Error == errCancelled { + foundCancelError = true + } + } + + if !foundResult { + t.Error("did not find one expected result") + } + + if !foundCancelError { + t.Error("did not find one expected cancel error") + } +}