2
2
mirror of https://github.com/octoleo/restic.git synced 2025-01-25 16:18:34 +00:00

Make worker pools input/output chans symmetric

Input and output channel are now both of type `chan Job`, this makes it
possible to chain multiple worker pools together.
This commit is contained in:
Alexander Neumann 2016-02-05 22:22:24 +01:00
parent e5ee4eba53
commit ee422110c8
2 changed files with 26 additions and 27 deletions

View File

@ -1,14 +1,14 @@
package worker package worker
import "sync" import (
"fmt"
"sync"
)
// Job is one unit of work. // Job is one unit of work. It is given to a Func, and the returned result and
type Job interface{} // error are stored in Result and Error.
type Job struct {
// Result is something the worker function returned, including the original job Data interface{}
// and an (eventual) error.
type Result struct {
Job Job
Result interface{} Result interface{}
Error error Error error
} }
@ -22,12 +22,12 @@ type Pool struct {
done chan struct{} done chan struct{}
wg *sync.WaitGroup wg *sync.WaitGroup
jobCh <-chan Job jobCh <-chan Job
resCh chan<- Result resCh chan<- Job
} }
// New returns a new worker pool with n goroutines, each running the function // New returns a new worker pool with n goroutines, each running the function
// f. The workers are started immediately. // f. The workers are started immediately.
func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Result) *Pool { func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool {
p := &Pool{ p := &Pool{
f: f, f: f,
done: make(chan struct{}), done: make(chan struct{}),
@ -52,10 +52,9 @@ func (p *Pool) runWorker(numWorker int) {
// enable the input channel when starting up a new goroutine // enable the input channel when starting up a new goroutine
inCh = p.jobCh inCh = p.jobCh
// but do not enable the output channel until we have a result // but do not enable the output channel until we have a result
outCh chan<- Result outCh chan<- Job
job Job job Job
res Result
ok bool ok bool
) )
@ -66,16 +65,15 @@ func (p *Pool) runWorker(numWorker int) {
case job, ok = <-inCh: case job, ok = <-inCh:
if !ok { if !ok {
fmt.Printf("in channel closed, worker exiting\n")
return return
} }
r, err := p.f(job, p.done) job.Result, job.Error = p.f(job, p.done)
res = Result{Job: job, Result: r, Error: err}
inCh = nil inCh = nil
outCh = p.resCh outCh = p.resCh
case outCh <- res: case outCh <- job:
outCh = nil outCh = nil
inCh = p.jobCh inCh = p.jobCh
} }

View File

@ -13,16 +13,16 @@ const concurrency = 10
var errTooLarge = errors.New("too large") var errTooLarge = errors.New("too large")
func square(job worker.Job, done <-chan struct{}) (interface{}, error) { func square(job worker.Job, done <-chan struct{}) (interface{}, error) {
n := job.(int) n := job.Data.(int)
if n > 2000 { if n > 2000 {
return nil, errTooLarge return nil, errTooLarge
} }
return n * n, nil return n * n, nil
} }
func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Result, *worker.Pool) { func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) {
inCh := make(chan worker.Job, bufsize) inCh := make(chan worker.Job, bufsize)
outCh := make(chan worker.Result, bufsize) outCh := make(chan worker.Job, bufsize)
return inCh, outCh, worker.New(n, f, inCh, outCh) return inCh, outCh, worker.New(n, f, inCh, outCh)
} }
@ -31,7 +31,7 @@ func TestPool(t *testing.T) {
inCh, outCh, p := newBufferedPool(200, concurrency, square) inCh, outCh, p := newBufferedPool(200, concurrency, square)
for i := 0; i < 150; i++ { for i := 0; i < 150; i++ {
inCh <- i inCh <- worker.Job{Data: i}
} }
close(inCh) close(inCh)
@ -39,10 +39,11 @@ func TestPool(t *testing.T) {
for res := range outCh { for res := range outCh {
if res.Error != nil { if res.Error != nil {
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error) t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error)
continue
} }
n := res.Job.(int) n := res.Data.(int)
m := res.Result.(int) m := res.Result.(int)
if m != n*n { if m != n*n {
@ -55,14 +56,14 @@ func TestPoolErrors(t *testing.T) {
inCh, outCh, p := newBufferedPool(200, concurrency, square) inCh, outCh, p := newBufferedPool(200, concurrency, square)
for i := 0; i < 150; i++ { for i := 0; i < 150; i++ {
inCh <- i + 1900 inCh <- worker.Job{Data: i + 1900}
} }
close(inCh) close(inCh)
p.Wait() p.Wait()
for res := range outCh { for res := range outCh {
n := res.Job.(int) n := res.Data.(int)
if n > 2000 { if n > 2000 {
if res.Error == nil { if res.Error == nil {
@ -77,7 +78,7 @@ func TestPoolErrors(t *testing.T) {
continue continue
} else { } else {
if res.Error != nil { if res.Error != nil {
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error) t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error)
continue continue
} }
} }
@ -92,7 +93,7 @@ func TestPoolErrors(t *testing.T) {
var errCancelled = errors.New("cancelled") var errCancelled = errors.New("cancelled")
func wait(job worker.Job, done <-chan struct{}) (interface{}, error) { func wait(job worker.Job, done <-chan struct{}) (interface{}, error) {
d := job.(time.Duration) d := job.Data.(time.Duration)
select { select {
case <-time.After(d): case <-time.After(d):
return time.Now(), nil return time.Now(), nil
@ -105,7 +106,7 @@ func TestPoolCancel(t *testing.T) {
jobCh, resCh, p := newBufferedPool(20, concurrency, wait) jobCh, resCh, p := newBufferedPool(20, concurrency, wait)
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
jobCh <- 10 * time.Millisecond jobCh <- worker.Job{Data: 10 * time.Millisecond}
} }
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)