Created October 15, 2020 16:53
-
-
Save FZambia/8fc2b3f79d463e28c3d3c32e462aebff to your computer and use it in GitHub Desktop.
Goroutine (worker) pool for the Go language
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package gpool | |
import "context" | |
// Job is a unit of work handed to the pool. A worker goroutine calls it
// with no arguments; a panic inside a Job is not recovered by the worker.
type Job func()

// worker owns a single goroutine that executes Jobs received from its jobs
// channel until it is told to stop.
type worker struct {
	jobs chan Job      // source of work; NewPool hands the same channel to every worker
	stop chan struct{} // shutdown request (capacity 1)
	done chan struct{} // shutdown acknowledgement (capacity 1)
}

// newWorker builds a worker that receives from jobs. The goroutine is not
// launched here; call start for that.
func newWorker(jobs chan Job) *worker {
	w := new(worker)
	w.jobs = jobs
	w.stop = make(chan struct{}, 1)
	w.done = make(chan struct{}, 1)
	return w
}

// start launches the worker goroutine and returns immediately. The goroutine
// runs jobs one at a time. A stop request is only noticed between jobs. When
// it is noticed, the goroutine writes to done and exits. If a job and a stop
// request are both ready, select chooses between them at random.
func (w *worker) start() {
	loop := func() {
		for {
			select {
			case <-w.stop:
				// done has capacity 1 and is written only here, just before
				// returning, so this send does not block.
				w.done <- struct{}{}
				return
			case job := <-w.jobs:
				job()
			}
		}
	}
	go loop()
}
// Pool of worker goroutines. | |
type Pool struct { | |
workers []*worker | |
Jobs chan Job | |
} | |
// NewPool will make a pool of worker goroutines. | |
// Returned object contains Jobs to send a job for execution. | |
func NewPool(numWorkers int) *Pool { | |
jobs := make(chan Job, 0) | |
workers := make([]*worker, 0, numWorkers) | |
for i := 0; i < numWorkers; i++ { | |
worker := newWorker(jobs) | |
worker.start() | |
workers = append(workers, worker) | |
} | |
return &Pool{ | |
Jobs: jobs, | |
workers: workers, | |
} | |
} | |
// Close will release resources used by a pool. | |
func (p *Pool) Close(ctx context.Context) error { | |
for i := 0; i < len(p.workers); i++ { | |
worker := p.workers[i] | |
select { | |
case <-ctx.Done(): | |
return ctx.Err() | |
case worker.stop <- struct{}{}: | |
} | |
} | |
for i := 0; i < len(p.workers); i++ { | |
worker := p.workers[i] | |
select { | |
case <-ctx.Done(): | |
return ctx.Err() | |
case <-worker.done: | |
} | |
} | |
return nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package gpool | |
import ( | |
"context" | |
"sync" | |
"sync/atomic" | |
"testing" | |
"time" | |
"github.com/stretchr/testify/require" | |
) | |
func TestWorker_New(t *testing.T) { | |
jobQueue := make(chan Job) | |
worker := newWorker(jobQueue) | |
worker.start() | |
require.NotNil(t, worker) | |
called := false | |
done := make(chan bool) | |
job := func() { | |
called = true | |
done <- true | |
} | |
worker.jobs <- job | |
<-done | |
require.Equal(t, true, called) | |
} | |
func TestPool_New(t *testing.T) { | |
pool := NewPool(1000) | |
defer func() { _ = pool.Close(context.Background()) }() | |
numJobs := 10000 | |
var wg sync.WaitGroup | |
wg.Add(numJobs) | |
var counter uint64 | |
for i := 0; i < numJobs; i++ { | |
arg := uint64(1) | |
job := func() { | |
defer wg.Done() | |
atomic.AddUint64(&counter, arg) | |
require.Equal(t, uint64(1), arg) | |
} | |
pool.Jobs <- job | |
} | |
wg.Wait() | |
require.Equal(t, uint64(numJobs), atomic.LoadUint64(&counter)) | |
} | |
func TestPool_Close(t *testing.T) { | |
pool := NewPool(100) | |
numJobs := 1000 | |
for i := 0; i < numJobs; i++ { | |
job := func() {} | |
pool.Jobs <- job | |
} | |
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | |
defer cancel() | |
_ = pool.Close(ctx) | |
} | |
func TestPool_CloseContext(t *testing.T) { | |
pool := NewPool(1) | |
pool.Jobs <- func() { | |
time.Sleep(5 * time.Second) | |
} | |
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) | |
defer cancel() | |
err := pool.Close(ctx) | |
require.Equal(t, context.DeadlineExceeded, err) | |
} | |
func BenchmarkPool_RawPerformance(b *testing.B) { | |
pool := NewPool(1) | |
defer func() { _ = pool.Close(context.Background()) }() | |
ch := make(chan struct{}, 1) | |
b.ResetTimer() | |
for n := 0; n < b.N; n++ { | |
pool.Jobs <- func() { | |
ch <- struct{}{} | |
} | |
<-ch | |
} | |
} | |
func BenchmarkPool_Sequential(b *testing.B) { | |
pool := NewPool(16) | |
defer func() { _ = pool.Close(context.Background()) }() | |
for n := 0; n < b.N; n++ { | |
var wg sync.WaitGroup | |
wg.Add(1) | |
pool.Jobs <- func() { | |
time.Sleep(100 * time.Millisecond) | |
wg.Done() | |
} | |
wg.Wait() | |
} | |
} | |
func BenchmarkPool_Parallel(b *testing.B) { | |
pool := NewPool(4096) | |
defer func() { _ = pool.Close(context.Background()) }() | |
b.SetParallelism(4096) | |
b.RunParallel(func(pb *testing.PB) { | |
for pb.Next() { | |
var wg sync.WaitGroup | |
wg.Add(1) | |
pool.Jobs <- func() { | |
time.Sleep(100 * time.Millisecond) | |
wg.Done() | |
} | |
wg.Wait() | |
} | |
}) | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment