2
2
mirror of https://github.com/octoleo/restic.git synced 2025-01-25 16:18:34 +00:00

Add worker pool

A worker pool is needed whenever something should be done concurrently.
This small library makes it easy to create a worker pool by specifying
channels, concurrency and a function that should be executed for each
job and returns a result and an error.
This commit is contained in:
Alexander Neumann 2016-02-04 22:33:11 +01:00
parent 1e0b7dbdd2
commit e5ee4eba53
3 changed files with 231 additions and 0 deletions

2
src/restic/worker/doc.go Normal file
View File

@ -0,0 +1,2 @@
// Package worker implements a worker pool.
package worker

95
src/restic/worker/pool.go Normal file
View File

@ -0,0 +1,95 @@
package worker
import "sync"
// Job is one unit of work.
type Job interface{}
// Result is something the worker function returned, including the original job
// and an (eventual) error.
type Result struct {
Job Job
Result interface{}
Error error
}
// Func does the actual work within a Pool.
type Func func(job Job, done <-chan struct{}) (result interface{}, err error)
// Pool implements a worker pool.
type Pool struct {
f Func
done chan struct{}
wg *sync.WaitGroup
jobCh <-chan Job
resCh chan<- Result
}
// New returns a new worker pool with n goroutines, each running the function
// f. The workers are started immediately.
func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Result) *Pool {
p := &Pool{
f: f,
done: make(chan struct{}),
wg: &sync.WaitGroup{},
jobCh: jobChan,
resCh: resultChan,
}
for i := 0; i < n; i++ {
p.wg.Add(1)
go p.runWorker(i)
}
return p
}
// runWorker runs a worker function.
func (p *Pool) runWorker(numWorker int) {
defer p.wg.Done()
var (
// enable the input channel when starting up a new goroutine
inCh = p.jobCh
// but do not enable the output channel until we have a result
outCh chan<- Result
job Job
res Result
ok bool
)
for {
select {
case <-p.done:
return
case job, ok = <-inCh:
if !ok {
return
}
r, err := p.f(job, p.done)
res = Result{Job: job, Result: r, Error: err}
inCh = nil
outCh = p.resCh
case outCh <- res:
outCh = nil
inCh = p.jobCh
}
}
}
// Cancel signals termination to all worker goroutines.
func (p *Pool) Cancel() {
close(p.done)
}
// Wait waits for all worker goroutines to terminate, afterwards the output
// channel is closed.
func (p *Pool) Wait() {
p.wg.Wait()
close(p.resCh)
}

View File

@ -0,0 +1,134 @@
package worker_test
import (
"errors"
"testing"
"time"
"restic/worker"
)
const concurrency = 10
var errTooLarge = errors.New("too large")
func square(job worker.Job, done <-chan struct{}) (interface{}, error) {
n := job.(int)
if n > 2000 {
return nil, errTooLarge
}
return n * n, nil
}
func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Result, *worker.Pool) {
inCh := make(chan worker.Job, bufsize)
outCh := make(chan worker.Result, bufsize)
return inCh, outCh, worker.New(n, f, inCh, outCh)
}
func TestPool(t *testing.T) {
inCh, outCh, p := newBufferedPool(200, concurrency, square)
for i := 0; i < 150; i++ {
inCh <- i
}
close(inCh)
p.Wait()
for res := range outCh {
if res.Error != nil {
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error)
}
n := res.Job.(int)
m := res.Result.(int)
if m != n*n {
t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m)
}
}
}
func TestPoolErrors(t *testing.T) {
inCh, outCh, p := newBufferedPool(200, concurrency, square)
for i := 0; i < 150; i++ {
inCh <- i + 1900
}
close(inCh)
p.Wait()
for res := range outCh {
n := res.Job.(int)
if n > 2000 {
if res.Error == nil {
t.Errorf("expected error not found, result is %v", res)
continue
}
if res.Error != errTooLarge {
t.Errorf("unexpected error found, result is %v", res)
}
continue
} else {
if res.Error != nil {
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error)
continue
}
}
m := res.Result.(int)
if m != n*n {
t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m)
}
}
}
var errCancelled = errors.New("cancelled")
func wait(job worker.Job, done <-chan struct{}) (interface{}, error) {
d := job.(time.Duration)
select {
case <-time.After(d):
return time.Now(), nil
case <-done:
return nil, errCancelled
}
}
func TestPoolCancel(t *testing.T) {
jobCh, resCh, p := newBufferedPool(20, concurrency, wait)
for i := 0; i < 20; i++ {
jobCh <- 10 * time.Millisecond
}
time.Sleep(20 * time.Millisecond)
p.Cancel()
p.Wait()
foundResult := false
foundCancelError := false
for res := range resCh {
if res.Error == nil {
foundResult = true
}
if res.Error == errCancelled {
foundCancelError = true
}
}
if !foundResult {
t.Error("did not find one expected result")
}
if !foundCancelError {
t.Error("did not find one expected cancel error")
}
}