mirror of
https://github.com/octoleo/restic.git
synced 2025-01-10 18:04:38 +00:00
Add worker pool
A worker pool is needed whenever something should be done concurrently. This small library makes it easy to create a worker pool by specifying channels, concurrency and a function that should be executed for each job and returns a result and an error.
This commit is contained in:
parent
1e0b7dbdd2
commit
e5ee4eba53
2
src/restic/worker/doc.go
Normal file
2
src/restic/worker/doc.go
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
// Package worker implements a worker pool.
|
||||||
|
package worker
|
95
src/restic/worker/pool.go
Normal file
95
src/restic/worker/pool.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package worker
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// Job is one unit of work.
|
||||||
|
type Job interface{}
|
||||||
|
|
||||||
|
// Result is something the worker function returned, including the original job
|
||||||
|
// and an (eventual) error.
|
||||||
|
type Result struct {
|
||||||
|
Job Job
|
||||||
|
Result interface{}
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Func does the actual work within a Pool.
|
||||||
|
type Func func(job Job, done <-chan struct{}) (result interface{}, err error)
|
||||||
|
|
||||||
|
// Pool implements a worker pool.
|
||||||
|
type Pool struct {
|
||||||
|
f Func
|
||||||
|
done chan struct{}
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
jobCh <-chan Job
|
||||||
|
resCh chan<- Result
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new worker pool with n goroutines, each running the function
|
||||||
|
// f. The workers are started immediately.
|
||||||
|
func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Result) *Pool {
|
||||||
|
p := &Pool{
|
||||||
|
f: f,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
wg: &sync.WaitGroup{},
|
||||||
|
jobCh: jobChan,
|
||||||
|
resCh: resultChan,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
p.wg.Add(1)
|
||||||
|
go p.runWorker(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// runWorker runs a worker function.
|
||||||
|
func (p *Pool) runWorker(numWorker int) {
|
||||||
|
defer p.wg.Done()
|
||||||
|
|
||||||
|
var (
|
||||||
|
// enable the input channel when starting up a new goroutine
|
||||||
|
inCh = p.jobCh
|
||||||
|
// but do not enable the output channel until we have a result
|
||||||
|
outCh chan<- Result
|
||||||
|
|
||||||
|
job Job
|
||||||
|
res Result
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
|
||||||
|
case job, ok = <-inCh:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := p.f(job, p.done)
|
||||||
|
res = Result{Job: job, Result: r, Error: err}
|
||||||
|
|
||||||
|
inCh = nil
|
||||||
|
outCh = p.resCh
|
||||||
|
|
||||||
|
case outCh <- res:
|
||||||
|
outCh = nil
|
||||||
|
inCh = p.jobCh
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel signals termination to all worker goroutines.
|
||||||
|
func (p *Pool) Cancel() {
|
||||||
|
close(p.done)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait waits for all worker goroutines to terminate, afterwards the output
|
||||||
|
// channel is closed.
|
||||||
|
func (p *Pool) Wait() {
|
||||||
|
p.wg.Wait()
|
||||||
|
close(p.resCh)
|
||||||
|
}
|
134
src/restic/worker/pool_test.go
Normal file
134
src/restic/worker/pool_test.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
package worker_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"restic/worker"
|
||||||
|
)
|
||||||
|
|
||||||
|
const concurrency = 10
|
||||||
|
|
||||||
|
var errTooLarge = errors.New("too large")
|
||||||
|
|
||||||
|
func square(job worker.Job, done <-chan struct{}) (interface{}, error) {
|
||||||
|
n := job.(int)
|
||||||
|
if n > 2000 {
|
||||||
|
return nil, errTooLarge
|
||||||
|
}
|
||||||
|
return n * n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Result, *worker.Pool) {
|
||||||
|
inCh := make(chan worker.Job, bufsize)
|
||||||
|
outCh := make(chan worker.Result, bufsize)
|
||||||
|
|
||||||
|
return inCh, outCh, worker.New(n, f, inCh, outCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPool(t *testing.T) {
|
||||||
|
inCh, outCh, p := newBufferedPool(200, concurrency, square)
|
||||||
|
|
||||||
|
for i := 0; i < 150; i++ {
|
||||||
|
inCh <- i
|
||||||
|
}
|
||||||
|
|
||||||
|
close(inCh)
|
||||||
|
p.Wait()
|
||||||
|
|
||||||
|
for res := range outCh {
|
||||||
|
if res.Error != nil {
|
||||||
|
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
n := res.Job.(int)
|
||||||
|
m := res.Result.(int)
|
||||||
|
|
||||||
|
if m != n*n {
|
||||||
|
t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolErrors(t *testing.T) {
|
||||||
|
inCh, outCh, p := newBufferedPool(200, concurrency, square)
|
||||||
|
|
||||||
|
for i := 0; i < 150; i++ {
|
||||||
|
inCh <- i + 1900
|
||||||
|
}
|
||||||
|
|
||||||
|
close(inCh)
|
||||||
|
p.Wait()
|
||||||
|
|
||||||
|
for res := range outCh {
|
||||||
|
n := res.Job.(int)
|
||||||
|
|
||||||
|
if n > 2000 {
|
||||||
|
if res.Error == nil {
|
||||||
|
t.Errorf("expected error not found, result is %v", res)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Error != errTooLarge {
|
||||||
|
t.Errorf("unexpected error found, result is %v", res)
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
if res.Error != nil {
|
||||||
|
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m := res.Result.(int)
|
||||||
|
if m != n*n {
|
||||||
|
t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errCancelled = errors.New("cancelled")
|
||||||
|
|
||||||
|
func wait(job worker.Job, done <-chan struct{}) (interface{}, error) {
|
||||||
|
d := job.(time.Duration)
|
||||||
|
select {
|
||||||
|
case <-time.After(d):
|
||||||
|
return time.Now(), nil
|
||||||
|
case <-done:
|
||||||
|
return nil, errCancelled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolCancel(t *testing.T) {
|
||||||
|
jobCh, resCh, p := newBufferedPool(20, concurrency, wait)
|
||||||
|
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
jobCh <- 10 * time.Millisecond
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
p.Cancel()
|
||||||
|
p.Wait()
|
||||||
|
|
||||||
|
foundResult := false
|
||||||
|
foundCancelError := false
|
||||||
|
for res := range resCh {
|
||||||
|
if res.Error == nil {
|
||||||
|
foundResult = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Error == errCancelled {
|
||||||
|
foundCancelError = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundResult {
|
||||||
|
t.Error("did not find one expected result")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundCancelError {
|
||||||
|
t.Error("did not find one expected cancel error")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user