Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion workerpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ var (
type WorkerPool struct {
workers chan struct{}
tasks chan *task
done <-chan struct{}
cancel context.CancelFunc
results []Task
wg sync.WaitGroup
Expand Down Expand Up @@ -65,6 +66,7 @@ func NewWithContext(ctx context.Context, n int) *WorkerPool {
}
ctx, cancel := context.WithCancel(ctx)
wp.cancel = cancel
wp.done = ctx.Done()
go wp.run(ctx)
return wp
}
Expand All @@ -82,14 +84,16 @@ func (wp *WorkerPool) Len() int {
// Submit submits f for processing by a worker. The given id is useful for
// identifying the task once it is completed. The task f must return when the
// context ctx is cancelled. The context passed to task f is cancelled when
// Close is called.
// Close is called or when the parent context passed to NewWithContext is done.
//
// Submit blocks until a routine start processing the task.
//
// If a drain operation is in progress, ErrDraining is returned and the task
// is not submitted for processing.
// If the worker pool is closed, ErrClosed is returned and the task is not
// submitted for processing.
// If the parent context is done, context.Canceled is returned and the task is
// not submitted for processing.
func (wp *WorkerPool) Submit(id string, f func(ctx context.Context) error) error {
wp.mu.Lock()
if wp.closed {
Expand All @@ -100,6 +104,12 @@ func (wp *WorkerPool) Submit(id string, f func(ctx context.Context) error) error
wp.mu.Unlock()
return ErrDraining
}
select {
case <-wp.done:
wp.mu.Unlock()
return context.Canceled
default:
}
wp.wg.Add(1)
wp.mu.Unlock()
wp.tasks <- &task{
Expand Down
60 changes: 45 additions & 15 deletions workerpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,26 +459,56 @@ func TestWorkerPoolNewWithContext(t *testing.T) {
wg.Wait()

// Submitting a task once the parent context has been cancelled should
// succeed and give a cancelled context to the task. This is not ideal and
// might change in the future.
wg.Add(1)
id := "last"
err := wp.Submit(id, func(ctx context.Context) error {
defer wg.Done()
select {
case <-ctx.Done():
default:
t.Errorf("last task expected context to be cancelled")
// return context.Canceled and not submit the task. Call Submit twice to
// ensure the mutex is released on this path (a missing unlock would
// deadlock the second call).
for range 2 {
err := wp.Submit("last", nil)
if !errors.Is(err, context.Canceled) {
t.Errorf("submit after parent cancel: got %v, want %v", err, context.Canceled)
}
return nil
})
if err != nil {
t.Errorf("failed to submit task '%s': %v", id, err)
}

wg.Wait()
// Drain should return only the n tasks that completed, not the rejected ones.
results, err := wp.Drain()
if err != nil {
t.Errorf("drain: got '%v', want no error", err)
}
if len(results) != n {
t.Errorf("drain: got %d results, want %d", len(results), n)
}

if err := wp.Close(); err != nil {
t.Errorf("close: got '%v', want no error", err)
}
}

func TestWorkerPoolNewWithCancelledContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel before creating the pool

wp := workerpool.NewWithContext(ctx, runtime.NumCPU())
defer func() {
if err := wp.Close(); err != nil {
t.Errorf("close: got '%v', want no error", err)
}
}()

// Submit should return context.Canceled immediately. Call twice to verify
// the mutex is released on this path.
for range 2 {
err := wp.Submit("task", nil)
if !errors.Is(err, context.Canceled) {
t.Errorf("submit with cancelled context: got %v, want %v", err, context.Canceled)
}
}

// No tasks should have been queued.
results, err := wp.Drain()
if err != nil {
t.Errorf("drain: got '%v', want no error", err)
}
if len(results) != 0 {
t.Errorf("drain: got %d results, want 0", len(results))
}
}
Loading