Skip to content

Commit 72191f1

Browse files
Bernd VerstCopilot
andcommitted
Fix worker channel retirement for in-flight completions
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent bbaf2b0 commit 72191f1

3 files changed

Lines changed: 249 additions & 10 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ FIXED
3333
so configured hello timeouts apply on fresh connections, received work resets
3434
failure tracking, SDK-owned channels are refreshed and cleaned up safely, and
3535
caller-owned channels are never recreated or closed during reconnects.
36+
- Fixed `TaskHubGrpcWorker` so in-flight work item completions can finish after
37+
a graceful gRPC stream reset before the worker retires an SDK-owned channel.
3638
- Improved sync and async gRPC clients so repeated transport failures recreate
3739
SDK-owned channels, while long-poll deadlines, successful replies, and
3840
application-level RPC errors do not trigger unnecessary channel replacement.

durabletask/worker.py

Lines changed: 124 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from concurrent.futures import ThreadPoolExecutor
1111
from dataclasses import dataclass, field
1212
from datetime import datetime, timedelta, timezone
13-
from threading import Event, Thread
13+
from threading import Event, Lock, Thread
1414
from types import GeneratorType
1515
from enum import Enum
1616
from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union, overload
@@ -130,6 +130,73 @@ class _WorkItemStreamOutcome(Enum):
130130
SILENT_DISCONNECT = "silent_disconnect"
131131

132132

133+
@dataclass
134+
class _TrackedChannelState:
135+
channel: Any
136+
ref_count: int = 0
137+
close_when_released: bool = False
138+
139+
140+
class _InFlightChannelTracker:
141+
def __init__(self):
142+
self._lock = Lock()
143+
self._states: dict[int, _TrackedChannelState] = {}
144+
145+
def acquire(self, channel: Any):
146+
channel_key = id(channel)
147+
with self._lock:
148+
state = self._states.get(channel_key)
149+
if state is None:
150+
state = _TrackedChannelState(channel=channel)
151+
self._states[channel_key] = state
152+
state.ref_count += 1
153+
154+
released = False
155+
156+
def release() -> None:
157+
nonlocal released
158+
if released:
159+
return
160+
released = True
161+
162+
channel_to_close = None
163+
with self._lock:
164+
state = self._states.get(channel_key)
165+
if state is None:
166+
return
167+
168+
state.ref_count -= 1
169+
if state.ref_count == 0:
170+
if state.close_when_released:
171+
channel_to_close = state.channel
172+
del self._states[channel_key]
173+
174+
if channel_to_close is not None:
175+
self._close_channel(channel_to_close)
176+
177+
return release
178+
179+
def retire(self, channel: Any) -> None:
180+
channel_key = id(channel)
181+
channel_to_close = None
182+
with self._lock:
183+
state = self._states.get(channel_key)
184+
if state is None:
185+
channel_to_close = channel
186+
else:
187+
state.close_when_released = True
188+
189+
if channel_to_close is not None:
190+
self._close_channel(channel_to_close)
191+
192+
@staticmethod
193+
def _close_channel(channel: Any) -> None:
194+
try:
195+
channel.close()
196+
except Exception:
197+
pass
198+
199+
133200
class VersioningOptions:
134201
"""Configuration options for orchestrator and activity versioning.
135202
@@ -642,6 +709,7 @@ async def _async_run_loop(self):
642709
failure_tracker = FailureTracker(
643710
threshold=self._resiliency_options.channel_recreate_failure_threshold,
644711
)
712+
in_flight_channel_tracker = _InFlightChannelTracker()
645713

646714
def get_reconnect_delay_seconds() -> float:
647715
return get_full_jitter_delay_seconds(
@@ -671,6 +739,45 @@ def create_fresh_connection():
671739
current_stub = None
672740
raise
673741

742+
def wrap_execution(handler, release):
743+
def wrapped(*args, **kwargs):
744+
result = handler(*args, **kwargs)
745+
release()
746+
return result
747+
748+
return wrapped
749+
750+
def wrap_cancellation(handler, release):
751+
def wrapped(*args, **kwargs):
752+
try:
753+
return handler(*args, **kwargs)
754+
finally:
755+
release()
756+
757+
return wrapped
758+
759+
def submit_work_item(
760+
submit_func,
761+
handler,
762+
cancellation_handler,
763+
request,
764+
stub,
765+
completion_token,
766+
channel,
767+
):
768+
release = in_flight_channel_tracker.acquire(channel)
769+
try:
770+
submit_func(
771+
wrap_execution(handler, release),
772+
wrap_cancellation(cancellation_handler, release),
773+
request,
774+
stub,
775+
completion_token,
776+
)
777+
except Exception:
778+
release()
779+
raise
780+
674781
def invalidate_connection(
675782
*,
676783
recreate_channel: bool = False,
@@ -700,10 +807,7 @@ def invalidate_connection(
700807
and self._can_recreate_channel()
701808
and (recreate_channel or close_channel)
702809
):
703-
try:
704-
current_channel.close()
705-
except Exception:
706-
pass
810+
in_flight_channel_tracker.retire(current_channel)
707811
current_channel = None
708812
current_stub = None
709813

@@ -742,7 +846,9 @@ def should_invalidate_connection(rpc_error):
742846
continue
743847
try:
744848
assert current_stub is not None
849+
assert current_channel is not None
745850
stub = current_stub
851+
channel = current_channel
746852
capabilities = []
747853
if self._payload_store is not None:
748854
capabilities.append(pb.WORKER_CAPABILITY_LARGE_PAYLOADS)
@@ -822,36 +928,44 @@ def stream_reader():
822928

823929
failure_tracker.record_success()
824930
if work_item.HasField("orchestratorRequest"):
825-
self._async_worker_manager.submit_orchestration(
931+
submit_work_item(
932+
self._async_worker_manager.submit_orchestration,
826933
self._execute_orchestrator,
827934
self._cancel_orchestrator,
828935
work_item.orchestratorRequest,
829936
stub,
830937
work_item.completionToken,
938+
channel,
831939
)
832940
elif work_item.HasField("activityRequest"):
833-
self._async_worker_manager.submit_activity(
941+
submit_work_item(
942+
self._async_worker_manager.submit_activity,
834943
self._execute_activity,
835944
self._cancel_activity,
836945
work_item.activityRequest,
837946
stub,
838947
work_item.completionToken,
948+
channel,
839949
)
840950
elif work_item.HasField("entityRequest"):
841-
self._async_worker_manager.submit_entity_batch(
951+
submit_work_item(
952+
self._async_worker_manager.submit_entity_batch,
842953
self._execute_entity_batch,
843954
self._cancel_entity_batch,
844955
work_item.entityRequest,
845956
stub,
846957
work_item.completionToken,
958+
channel,
847959
)
848960
elif work_item.HasField("entityRequestV2"):
849-
self._async_worker_manager.submit_entity_batch(
961+
submit_work_item(
962+
self._async_worker_manager.submit_entity_batch,
850963
self._execute_entity_batch,
851964
self._cancel_entity_batch,
852965
work_item.entityRequestV2,
853966
stub,
854-
work_item.completionToken
967+
work_item.completionToken,
968+
channel,
855969
)
856970
else:
857971
self._logger.warning(

tests/durabletask/test_worker_resiliency.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,16 @@ def shutdown(self):
7878
self._shutdown_event.set()
7979

8080

81+
def _complete_activity_request(req, stub, completion_token):
82+
stub.CompleteActivityTask(
83+
pb.ActivityResponse(
84+
instanceId=req.orchestrationInstance.instanceId,
85+
taskId=req.taskId,
86+
completionToken=completion_token,
87+
)
88+
)
89+
90+
8191
def _make_activity_work_item() -> pb.WorkItem:
8292
return pb.WorkItem(
8393
activityRequest=pb.ActivityRequest(
@@ -331,6 +341,119 @@ def create_stub(channel):
331341
created_channels[1].close.assert_called_once()
332342

333343

344+
@pytest.mark.asyncio
345+
async def test_worker_defers_sdk_owned_channel_close_until_inflight_completion_finishes(monkeypatch):
346+
worker = TaskHubGrpcWorker()
347+
worker_manager = DummyWorkerManager()
348+
worker._async_worker_manager = worker_manager
349+
worker._execute_activity = _complete_activity_request
350+
monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False)
351+
352+
created_channels = []
353+
354+
def get_grpc_channel(*args, **kwargs):
355+
channel = MagicMock(name=f"channel-{len(created_channels) + 1}")
356+
created_channels.append(channel)
357+
return channel
358+
359+
completed_responses = []
360+
first_stub = MagicMock()
361+
first_stub.GetWorkItems.return_value = FakeResponseStream(items=[_make_activity_work_item()])
362+
363+
def complete_activity(response):
364+
assert created_channels[0].close.call_count == 0
365+
completed_responses.append(response)
366+
367+
first_stub.CompleteActivityTask.side_effect = complete_activity
368+
369+
second_stub = MagicMock()
370+
second_stub.GetWorkItems.side_effect = FakeRpcError(
371+
grpc.StatusCode.CANCELLED,
372+
"stop",
373+
)
374+
375+
stubs = [first_stub, second_stub]
376+
stub_channels = []
377+
378+
def create_stub(channel):
379+
stub_channels.append(channel)
380+
return stubs.pop(0)
381+
382+
monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel)
383+
monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub)
384+
385+
await worker._async_run_loop()
386+
387+
assert len(worker_manager.submissions) == 1
388+
assert len(created_channels) == 2
389+
assert stub_channels == created_channels
390+
created_channels[0].close.assert_not_called()
391+
created_channels[1].close.assert_called_once()
392+
393+
_, submission = worker_manager.submissions[0]
394+
func, _, req, stub, completion_token = submission
395+
func(req, stub, completion_token)
396+
397+
assert len(completed_responses) == 1
398+
assert completed_responses[0].completionToken == "token"
399+
created_channels[0].close.assert_called_once()
400+
401+
402+
@pytest.mark.asyncio
403+
async def test_worker_never_closes_caller_owned_channel_after_graceful_reset(monkeypatch):
404+
provided_channel = MagicMock(name="provided-channel")
405+
worker = TaskHubGrpcWorker(channel=provided_channel)
406+
worker_manager = DummyWorkerManager()
407+
worker._async_worker_manager = worker_manager
408+
worker._execute_activity = _complete_activity_request
409+
monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False)
410+
411+
completed_responses = []
412+
first_stub = MagicMock()
413+
first_stub.GetWorkItems.return_value = FakeResponseStream(items=[_make_activity_work_item()])
414+
415+
def complete_activity(response):
416+
assert provided_channel.close.call_count == 0
417+
completed_responses.append(response)
418+
419+
first_stub.CompleteActivityTask.side_effect = complete_activity
420+
421+
second_stub = MagicMock()
422+
second_stub.GetWorkItems.side_effect = FakeRpcError(
423+
grpc.StatusCode.CANCELLED,
424+
"stop",
425+
)
426+
427+
stubs = [first_stub, second_stub]
428+
stub_channels = []
429+
430+
def create_stub(channel):
431+
stub_channels.append(channel)
432+
return stubs.pop(0)
433+
434+
monkeypatch.setattr(
435+
"durabletask.worker.shared.get_grpc_channel",
436+
lambda *args, **kwargs: pytest.fail(
437+
"SDK channel factory should not run for caller-owned channels"
438+
),
439+
)
440+
monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub)
441+
442+
await worker._async_run_loop()
443+
444+
assert len(worker_manager.submissions) == 1
445+
assert stub_channels == [provided_channel, provided_channel]
446+
provided_channel.close.assert_not_called()
447+
448+
_, submission = worker_manager.submissions[0]
449+
func, _, req, stub, completion_token = submission
450+
func(req, stub, completion_token)
451+
452+
assert len(completed_responses) == 1
453+
assert completed_responses[0].completionToken == "token"
454+
provided_channel.close.assert_not_called()
455+
456+
334457
@pytest.mark.asyncio
335458
async def test_worker_uses_reconnect_backoff_helper_after_connection_failure(monkeypatch):
336459
worker = TaskHubGrpcWorker(

0 commit comments

Comments
 (0)