Skip to content

Commit a49f5ce

Browse files
Bernd VerstCopilot
andcommitted
Fix worker shutdown channel draining
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 72191f1 commit a49f5ce

5 files changed

Lines changed: 231 additions & 20 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ FIXED
3333
so configured hello timeouts apply on fresh connections, received work resets
3434
failure tracking, SDK-owned channels are refreshed and cleaned up safely, and
3535
caller-owned channels are never recreated or closed during reconnects.
36-
- Fixed `TaskHubGrpcWorker` so in-flight work item completions can finish after
37-
a graceful gRPC stream reset before the worker retires an SDK-owned channel.
36+
- Fixed `TaskHubGrpcWorker` so in-flight and queued work item completions keep
37+
draining across graceful gRPC stream resets and worker shutdown before the
38+
worker retires an SDK-owned channel.
3839
- Improved sync and async gRPC clients so repeated transport failures recreate
3940
SDK-owned channels, while long-poll deadlines, successful replies, and
4041
application-level RPC errors do not trigger unnecessary channel replacement.

durabletask/worker.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,8 @@ def start(self):
690690
if self._auto_generate_work_item_filters:
691691
self._work_item_filters = WorkItemFilters._from_registry(self._registry)
692692

693+
self._shutdown.clear()
694+
693695
def run_loop():
694696
loop = asyncio.new_event_loop()
695697
asyncio.set_event_loop(loop)
@@ -701,6 +703,7 @@ def run_loop():
701703
self._is_running = True
702704

703705
async def _async_run_loop(self):
706+
self._async_worker_manager.prepare_for_run()
704707
worker_task = asyncio.create_task(self._async_worker_manager.run())
705708
current_channel = self._channel
706709
current_stub = None
@@ -1060,6 +1063,11 @@ def stop(self):
10601063
self._response_stream.cancel()
10611064
if self._runLoop is not None:
10621065
self._runLoop.join(timeout=30)
1066+
if self._runLoop.is_alive():
1067+
self._logger.info(
1068+
"Waiting for pending work items to finish before completing shutdown..."
1069+
)
1070+
self._runLoop.join()
10631071
self._async_worker_manager.shutdown()
10641072
self._logger.info("Worker shutdown completed")
10651073
self._is_running = False
@@ -2883,11 +2891,22 @@ def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logg
28832891
self._pending_activity_work: list = []
28842892
self._pending_orchestration_work: list = []
28852893
self._pending_entity_batch_work: list = []
2886-
self.thread_pool = ThreadPoolExecutor(
2887-
max_workers=concurrency_options.maximum_thread_pool_workers,
2894+
self.thread_pool = self._create_thread_pool()
2895+
self._shutdown = False
2896+
2897+
def _create_thread_pool(self) -> ThreadPoolExecutor:
2898+
return ThreadPoolExecutor(
2899+
max_workers=self.concurrency_options.maximum_thread_pool_workers,
28882900
thread_name_prefix="DurableTask",
28892901
)
2902+
2903+
def _ensure_thread_pool(self) -> None:
2904+
if getattr(self.thread_pool, "_shutdown", False):
2905+
self.thread_pool = self._create_thread_pool()
2906+
2907+
def prepare_for_run(self) -> None:
28902908
self._shutdown = False
2909+
self._ensure_thread_pool()
28912910

28922911
def _ensure_queues_for_current_loop(self):
28932912
"""Ensure queues are bound to the current event loop."""
@@ -2962,8 +2981,7 @@ def _ensure_queues_for_current_loop(self):
29622981
self._pending_entity_batch_work.clear()
29632982

29642983
async def run(self):
2965-
# Reset shutdown flag in case this manager is being reused
2966-
self._shutdown = False
2984+
self._ensure_thread_pool()
29672985

29682986
# Ensure queues are properly bound to the current event loop
29692987
self._ensure_queues_for_current_loop()
@@ -3025,6 +3043,9 @@ async def run(self):
30253043
except Exception as cancellation_exception:
30263044
self._logger.error(f"Uncaught error while cancelling entity batch work item: {cancellation_exception}")
30273045
self.shutdown()
3046+
finally:
3047+
if not getattr(self.thread_pool, "_shutdown", False):
3048+
self.thread_pool.shutdown(wait=True)
30283049

30293050
async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
30303051
# List to track running tasks
@@ -3068,12 +3089,7 @@ async def _run_func(self, func, *args, **kwargs):
30683089
return await func(*args, **kwargs)
30693090
else:
30703091
loop = asyncio.get_running_loop()
3071-
# Avoid submitting to executor after shutdown
3072-
if (
3073-
getattr(self, "_shutdown", False) and getattr(self, "thread_pool", None) and getattr(
3074-
self.thread_pool, "_shutdown", False)
3075-
):
3076-
return None
3092+
self._ensure_thread_pool()
30773093
return await loop.run_in_executor(
30783094
self.thread_pool, lambda: func(*args, **kwargs)
30793095
)
@@ -3113,11 +3129,10 @@ def submit_entity_batch(self, func, cancellation_func, *args, **kwargs):
31133129

31143130
def shutdown(self):
31153131
self._shutdown = True
3116-
self.thread_pool.shutdown(wait=True)
31173132

31183133
async def reset_for_new_run(self):
31193134
"""Reset the manager state for a new run."""
3120-
self._shutdown = False
3135+
self.prepare_for_run()
31213136
# Clear any existing queues - they'll be recreated when needed
31223137
if self.activity_queue is not None:
31233138
# Clear existing queue by creating a new one

tests/durabletask/test_worker_concurrency_loop.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def cancel_dummy_activity(req, stub, completionToken):
7373

7474
async def run_test():
7575
# Start the worker manager's run loop in the background
76+
worker._async_worker_manager.prepare_for_run()
7677
worker_task = asyncio.create_task(worker._async_worker_manager.run())
7778
for req in orchestrator_requests:
7879
worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken())
@@ -133,6 +134,7 @@ def fn(*args, **kwargs):
133134

134135
# Run the manager loop in a thread (sync context)
135136
def run_manager():
137+
manager.prepare_for_run()
136138
asyncio.run(manager.run())
137139

138140
t = threading.Thread(target=run_manager)

tests/durabletask/test_worker_concurrency_loop_async.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ async def cancel_dummy_activity(req, stub, completionToken):
7272
async def run_test():
7373
# Clear stub state before each run
7474
stub.completed.clear()
75+
grpc_worker._async_worker_manager.prepare_for_run()
7576
worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run())
7677
# Need to yield to that thread in order to let it start up on the second run
7778
startup_attempts = 0

tests/durabletask/test_worker_resiliency.py

Lines changed: 198 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import asyncio
22
import grpc
3-
from threading import Event
3+
from threading import Event, Timer
44
from unittest.mock import MagicMock
55

66
import pytest
77

88
from durabletask.grpc_options import GrpcWorkerResiliencyOptions
99
from durabletask.internal import orchestrator_service_pb2 as pb
10-
from durabletask.worker import TaskHubGrpcWorker, _WorkItemStreamOutcome
10+
from durabletask.worker import (
11+
_AsyncWorkerManager,
12+
ConcurrencyOptions,
13+
TaskHubGrpcWorker,
14+
_WorkItemStreamOutcome,
15+
)
1116

1217

1318
class FakeRpcError(grpc.RpcError):
@@ -62,6 +67,9 @@ def __init__(self):
6267
self._shutdown_event = asyncio.Event()
6368
self.submissions: list[tuple[str, tuple]] = []
6469

70+
def prepare_for_run(self):
71+
self._shutdown_event = asyncio.Event()
72+
6573
async def run(self):
6674
await self._shutdown_event.wait()
6775

@@ -88,16 +96,58 @@ def _complete_activity_request(req, stub, completion_token):
8896
)
8997

9098

91-
def _make_activity_work_item() -> pb.WorkItem:
99+
def _make_activity_work_item(
100+
task_id: int = 1,
101+
completion_token: str = "token",
102+
instance_id: str = "instance-id",
103+
) -> pb.WorkItem:
92104
return pb.WorkItem(
93105
activityRequest=pb.ActivityRequest(
94106
name="test_activity",
95-
taskId=1,
96-
orchestrationInstance=pb.OrchestrationInstance(instanceId="instance-id"),
107+
taskId=task_id,
108+
orchestrationInstance=pb.OrchestrationInstance(instanceId=instance_id),
97109
),
98-
completionToken="token",
110+
completionToken=completion_token,
111+
)
112+
113+
114+
async def _wait_for_condition(predicate, *, timeout: float = 2.0):
115+
loop = asyncio.get_running_loop()
116+
deadline = loop.time() + timeout
117+
while not predicate():
118+
if loop.time() >= deadline:
119+
raise AssertionError("condition was not met before timeout")
120+
await asyncio.sleep(0.01)
121+
122+
123+
@pytest.mark.asyncio
124+
async def test_async_worker_manager_honors_shutdown_requested_before_run():
125+
manager = _AsyncWorkerManager(
126+
ConcurrencyOptions(maximum_thread_pool_workers=1),
127+
MagicMock(),
99128
)
100129

130+
manager.shutdown()
131+
await asyncio.wait_for(manager.run(), timeout=1.0)
132+
133+
134+
def test_worker_start_clears_prior_shutdown_request():
135+
worker = TaskHubGrpcWorker()
136+
worker._shutdown.set()
137+
run_started = Event()
138+
139+
async def fake_run_loop():
140+
run_started.set()
141+
142+
worker._async_run_loop = fake_run_loop
143+
worker.start()
144+
worker._runLoop.join(timeout=1.0)
145+
146+
assert run_started.is_set() is True
147+
assert worker._shutdown.is_set() is False
148+
149+
worker.stop()
150+
101151

102152
def test_worker_classifies_graceful_close_before_first_message():
103153
worker = TaskHubGrpcWorker(
@@ -399,6 +449,148 @@ def create_stub(channel):
399449
created_channels[0].close.assert_called_once()
400450

401451

452+
@pytest.mark.asyncio
453+
async def test_worker_shutdown_drains_real_manager_work_before_closing_retired_sdk_channel(monkeypatch):
454+
worker = TaskHubGrpcWorker(
455+
concurrency_options=ConcurrencyOptions(
456+
maximum_concurrent_activity_work_items=1,
457+
maximum_thread_pool_workers=1,
458+
)
459+
)
460+
worker._execute_activity = _complete_activity_request
461+
monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False)
462+
463+
created_channels = []
464+
465+
def get_grpc_channel(*args, **kwargs):
466+
channel = MagicMock(name=f"channel-{len(created_channels) + 1}")
467+
created_channels.append(channel)
468+
return channel
469+
470+
allow_first_completion = Event()
471+
first_completion_started = Event()
472+
completed_task_ids = []
473+
first_stub = MagicMock()
474+
first_stub.GetWorkItems.return_value = FakeResponseStream(items=[
475+
_make_activity_work_item(task_id=1, completion_token="token-1"),
476+
_make_activity_work_item(task_id=2, completion_token="token-2"),
477+
])
478+
479+
def complete_activity(response):
480+
completed_task_ids.append(response.taskId)
481+
if response.taskId == 1:
482+
first_completion_started.set()
483+
Timer(0.2, allow_first_completion.set).start()
484+
assert allow_first_completion.wait(timeout=5.0)
485+
elif response.taskId == 2:
486+
assert created_channels[0].close.call_count == 0
487+
488+
first_stub.CompleteActivityTask.side_effect = complete_activity
489+
490+
second_stub = MagicMock()
491+
second_stub.GetWorkItems.side_effect = FakeRpcError(
492+
grpc.StatusCode.CANCELLED,
493+
"stop",
494+
)
495+
496+
stubs = [first_stub, second_stub]
497+
stub_channels = []
498+
499+
def create_stub(channel):
500+
stub_channels.append(channel)
501+
return stubs.pop(0)
502+
503+
monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel)
504+
monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub)
505+
506+
run_task = asyncio.create_task(worker._async_run_loop())
507+
await asyncio.wait_for(run_task, timeout=2.0)
508+
509+
assert first_completion_started.is_set() is True
510+
assert len(created_channels) == 2
511+
assert stub_channels == created_channels
512+
assert completed_task_ids == [1, 2]
513+
created_channels[0].close.assert_called_once()
514+
created_channels[1].close.assert_called_once()
515+
516+
517+
@pytest.mark.asyncio
518+
async def test_worker_shutdown_runs_real_manager_cancellation_wrapper_before_closing_retired_sdk_channel(monkeypatch):
519+
worker = TaskHubGrpcWorker(
520+
concurrency_options=ConcurrencyOptions(
521+
maximum_concurrent_activity_work_items=1,
522+
maximum_thread_pool_workers=1,
523+
)
524+
)
525+
monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False)
526+
527+
created_channels = []
528+
529+
def get_grpc_channel(*args, **kwargs):
530+
channel = MagicMock(name=f"channel-{len(created_channels) + 1}")
531+
created_channels.append(channel)
532+
return channel
533+
534+
allow_first_completion = Event()
535+
first_completion_started = Event()
536+
completed_task_ids = []
537+
cancelled_task_ids = []
538+
539+
def execute_activity(req, stub, completion_token):
540+
if req.taskId == 1:
541+
_complete_activity_request(req, stub, completion_token)
542+
else:
543+
raise RuntimeError("boom")
544+
545+
def cancel_activity(req, stub, completion_token):
546+
cancelled_task_ids.append(req.taskId)
547+
assert created_channels[0].close.call_count == 0
548+
549+
worker._execute_activity = execute_activity
550+
worker._cancel_activity = cancel_activity
551+
552+
first_stub = MagicMock()
553+
first_stub.GetWorkItems.return_value = FakeResponseStream(items=[
554+
_make_activity_work_item(task_id=1, completion_token="token-1"),
555+
_make_activity_work_item(task_id=2, completion_token="token-2"),
556+
])
557+
558+
def complete_activity(response):
559+
completed_task_ids.append(response.taskId)
560+
Timer(0.2, allow_first_completion.set).start()
561+
first_completion_started.set()
562+
assert allow_first_completion.wait(timeout=5.0)
563+
564+
first_stub.CompleteActivityTask.side_effect = complete_activity
565+
566+
second_stub = MagicMock()
567+
second_stub.GetWorkItems.side_effect = FakeRpcError(
568+
grpc.StatusCode.CANCELLED,
569+
"stop",
570+
)
571+
572+
stubs = [first_stub, second_stub]
573+
stub_channels = []
574+
575+
def create_stub(channel):
576+
stub_channels.append(channel)
577+
return stubs.pop(0)
578+
579+
monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel)
580+
monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub)
581+
582+
run_task = asyncio.create_task(worker._async_run_loop())
583+
await asyncio.wait_for(run_task, timeout=2.0)
584+
585+
assert first_completion_started.is_set() is True
586+
assert len(created_channels) == 2
587+
assert stub_channels == created_channels
588+
assert completed_task_ids == [1]
589+
assert cancelled_task_ids == [2]
590+
created_channels[0].close.assert_called_once()
591+
created_channels[1].close.assert_called_once()
592+
593+
402594
@pytest.mark.asyncio
403595
async def test_worker_never_closes_caller_owned_channel_after_graceful_reset(monkeypatch):
404596
provided_channel = MagicMock(name="provided-channel")

0 commit comments

Comments
 (0)