Skip to content

Commit 09697bb

Browse files
Bernd Verst and Copilot committed
Add sync client gRPC channel recreation
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 950892b commit 09697bb

3 files changed

Lines changed: 206 additions & 17 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ FIXED
3131
apply on fresh connections, received work items reset failure tracking,
3232
SDK-owned channels are cleaned up on shutdown and full resets, and
3333
caller-owned channels are never recreated or closed during worker reconnects.
34+
- Fixed sync `TaskHubGrpcClient` transport resiliency so SDK-owned channels are
35+
recreated after repeated transport failures without counting long-poll
36+
timeout deadlines against the recreation threshold.
3437

3538
## v1.4.0
3639

durabletask/client.py

Lines changed: 82 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Licensed under the MIT License.
33

44
import logging
5+
import threading
6+
import time
57
import uuid
68
from dataclasses import dataclass
79
from datetime import datetime
@@ -25,6 +27,10 @@
2527
import durabletask.internal.shared as shared
2628
import durabletask.internal.tracing as tracing
2729
from durabletask import task
30+
from durabletask.internal.grpc_resiliency import (
31+
FailureTracker,
32+
is_client_transport_failure,
33+
)
2834
from durabletask.internal.client_helpers import (
2935
build_query_entities_req,
3036
build_query_instances_req,
@@ -201,10 +207,61 @@ def __init__(self, *,
201207
)
202208
self._channel = channel
203209
self._stub = stubs.TaskHubSidecarServiceStub(channel)
210+
self._client_failure_tracker = FailureTracker(
211+
self._resiliency_options.channel_recreate_failure_threshold
212+
)
213+
self._last_recreate_time = 0.0
214+
self._recreate_lock = threading.Lock()
204215
self._logger = shared.get_logger("client", log_handler, log_formatter)
205216
self.default_version = default_version
206217
self._payload_store = payload_store
207218

219+
def _invoke_unary(
        self,
        method_name: str,
        request: Any,
        *,
        timeout: Optional[float] = None) -> Any:
    """Invoke a unary RPC on the current stub and track transport health.

    Args:
        method_name: Name of the stub method to call (e.g. ``"GetInstance"``).
        request: Request message passed to the stub method.
        timeout: Optional per-call deadline in seconds. When ``None`` the
            stub method is called without a ``timeout`` argument.

    Returns:
        The response returned by the stub method.

    Raises:
        grpc.RpcError: Always re-raised unchanged after failure tracking has
            been updated, so callers keep their existing error handling.
    """
    # Pin the stub used for this call: another thread may swap ``self._stub``
    # (via ``_maybe_recreate_channel``) while this RPC is still in flight.
    stub = self._stub
    method = getattr(stub, method_name)
    try:
        if timeout is None:
            response = method(request)
        else:
            response = method(request, timeout=timeout)
    except grpc.RpcError as rpc_error:
        status_code = rpc_error.code()
        if is_client_transport_failure(method_name, status_code):
            # A failure observed on a stub that has since been replaced
            # describes the retired channel (for example an in-flight call
            # that fails when the old channel is closed). Counting it would
            # push the fresh channel toward another needless recreation.
            if stub is self._stub and self._client_failure_tracker.record_failure():
                self._maybe_recreate_channel()
        elif status_code != grpc.StatusCode.DEADLINE_EXCEEDED:
            # Any other status is treated as proof the transport delivered a
            # reply, so the consecutive-failure count is reset.
            self._client_failure_tracker.record_success()
        # DEADLINE_EXCEEDED that is not classified as a transport failure
        # leaves the tracker untouched: neither a failure nor a success.
        raise
    else:
        self._client_failure_tracker.record_success()
        return response
243+
244+
def _maybe_recreate_channel(self) -> None:
    """Replace the SDK-owned gRPC channel and stub, rate-limited.

    Does nothing when the channel was supplied by the caller, or when the
    previous recreation happened less than
    ``min_recreate_interval_seconds`` ago. On success the failure tracker is
    reset and the retired channel is closed after a delay.

    This is a best-effort recovery step: if the replacement channel or stub
    cannot be built, the current channel and stub stay in place and the
    problem is logged rather than raised, so the RPC error that triggered
    the attempt is not masked.
    """
    if not self._owns_channel:
        return

    # Delay before the retired channel is closed, rather than closing it
    # immediately at swap time.
    old_channel_close_delay_seconds = 30.0

    with self._recreate_lock:
        now = time.monotonic()
        if now - self._last_recreate_time < self._resiliency_options.min_recreate_interval_seconds:
            return

        # Build the channel and stub completely before touching any state so
        # the client never ends up with a new channel paired with an old stub.
        new_channel = None
        try:
            new_channel = shared.get_grpc_channel(
                host_address=self._host_address,
                secure_channel=self._secure_channel,
                interceptors=self._interceptors,
                channel_options=self._channel_options,
            )
            new_stub = stubs.TaskHubSidecarServiceStub(new_channel)
        except Exception:
            # Deliberately broad: nothing raised here may replace the
            # original RPC error the caller is about to re-raise.
            self._logger.warning(
                "Failed to recreate the gRPC channel; keeping the existing channel.",
                exc_info=True,
            )
            if new_channel is not None:
                # Stub construction failed after the channel was created;
                # release the half-built channel instead of leaking it.
                new_channel.close()
            return

        old_channel = self._channel
        self._channel = new_channel
        self._stub = new_stub
        self._last_recreate_time = now
        self._client_failure_tracker.record_success()

    # Daemon timer so a pending close does not keep the interpreter alive.
    close_timer = threading.Timer(old_channel_close_delay_seconds, old_channel.close)
    close_timer.daemon = True
    close_timer.start()
264+
208265
def close(self) -> None:
209266
"""Closes the underlying gRPC channel.
210267
@@ -249,12 +306,12 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
249306
payload_helpers.externalize_payloads(
250307
req, self._payload_store, instance_id=req.instanceId,
251308
)
252-
res: pb.CreateInstanceResponse = self._stub.StartInstance(req)
309+
res: pb.CreateInstanceResponse = self._invoke_unary("StartInstance", req)
253310
return res.instanceId
254311

255312
def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]:
256313
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
257-
res: pb.GetInstanceResponse = self._stub.GetInstance(req)
314+
res: pb.GetInstanceResponse = self._invoke_unary("GetInstance", req)
258315
# De-externalize any large-payload tokens in the response
259316
if self._payload_store is not None and res.exists:
260317
payload_helpers.deexternalize_payloads(res, self._payload_store)
@@ -294,7 +351,7 @@ def list_instance_ids(self,
294351
f"page_size={page_size}, "
295352
f"continuation_token={continuation_token}"
296353
)
297-
resp: pb.ListInstanceIdsResponse = self._stub.ListInstanceIds(req)
354+
resp: pb.ListInstanceIdsResponse = self._invoke_unary("ListInstanceIds", req)
298355
next_token = resp.lastInstanceKey.value if resp.HasField("lastInstanceKey") else None
299356
return Page(items=list(resp.instanceIds), continuation_token=next_token)
300357

@@ -311,7 +368,7 @@ def get_all_orchestration_states(self,
311368

312369
while True:
313370
req = build_query_instances_req(orchestration_query, _continuation_token)
314-
resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req)
371+
resp: pb.QueryInstancesResponse = self._invoke_unary("QueryInstances", req)
315372
if self._payload_store is not None:
316373
payload_helpers.deexternalize_payloads(resp, self._payload_store)
317374
states += [parse_orchestration_state(res) for res in resp.orchestrationState]
@@ -328,7 +385,11 @@ def wait_for_orchestration_start(self, instance_id: str, *,
328385
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
329386
try:
330387
self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
331-
res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=timeout)
388+
res: pb.GetInstanceResponse = self._invoke_unary(
389+
"WaitForInstanceStart",
390+
req,
391+
timeout=timeout,
392+
)
332393
if self._payload_store is not None and res.exists:
333394
payload_helpers.deexternalize_payloads(res, self._payload_store)
334395
return new_orchestration_state(req.instanceId, res)
@@ -345,7 +406,11 @@ def wait_for_orchestration_completion(self, instance_id: str, *,
345406
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
346407
try:
347408
self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
348-
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout)
409+
res: pb.GetInstanceResponse = self._invoke_unary(
410+
"WaitForInstanceCompletion",
411+
req,
412+
timeout=timeout,
413+
)
349414
if self._payload_store is not None and res.exists:
350415
payload_helpers.deexternalize_payloads(res, self._payload_store)
351416
state = new_orchestration_state(req.instanceId, res)
@@ -366,7 +431,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
366431
payload_helpers.externalize_payloads(
367432
req, self._payload_store, instance_id=instance_id,
368433
)
369-
self._stub.RaiseEvent(req)
434+
self._invoke_unary("RaiseEvent", req)
370435

371436
def terminate_orchestration(self, instance_id: str, *,
372437
output: Optional[Any] = None,
@@ -378,17 +443,17 @@ def terminate_orchestration(self, instance_id: str, *,
378443
payload_helpers.externalize_payloads(
379444
req, self._payload_store, instance_id=instance_id,
380445
)
381-
self._stub.TerminateInstance(req)
446+
self._invoke_unary("TerminateInstance", req)
382447

383448
def suspend_orchestration(self, instance_id: str) -> None:
384449
req = pb.SuspendRequest(instanceId=instance_id)
385450
self._logger.info(f"Suspending instance '{instance_id}'.")
386-
self._stub.SuspendInstance(req)
451+
self._invoke_unary("SuspendInstance", req)
387452

388453
def resume_orchestration(self, instance_id: str) -> None:
389454
req = pb.ResumeRequest(instanceId=instance_id)
390455
self._logger.info(f"Resuming instance '{instance_id}'.")
391-
self._stub.ResumeInstance(req)
456+
self._invoke_unary("ResumeInstance", req)
392457

393458
def restart_orchestration(self, instance_id: str, *,
394459
restart_with_new_instance_id: bool = False) -> str:
@@ -407,13 +472,13 @@ def restart_orchestration(self, instance_id: str, *,
407472
restartWithNewInstanceId=restart_with_new_instance_id)
408473

409474
self._logger.info(f"Restarting instance '{instance_id}'.")
410-
res: pb.RestartInstanceResponse = self._stub.RestartInstance(req)
475+
res: pb.RestartInstanceResponse = self._invoke_unary("RestartInstance", req)
411476
return res.instanceId
412477

413478
def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
414479
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
415480
self._logger.info(f"Purging instance '{instance_id}'.")
416-
resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
481+
resp: pb.PurgeInstancesResponse = self._invoke_unary("PurgeInstances", req)
417482
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
418483

419484
def purge_orchestrations_by(self,
@@ -427,7 +492,7 @@ def purge_orchestrations_by(self,
427492
f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
428493
f"recursive={recursive}")
429494
req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
430-
resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
495+
resp: pb.PurgeInstancesResponse = self._invoke_unary("PurgeInstances", req)
431496
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
432497

433498
def signal_entity(self,
@@ -440,15 +505,15 @@ def signal_entity(self,
440505
payload_helpers.externalize_payloads(
441506
req, self._payload_store, instance_id=str(entity_instance_id),
442507
)
443-
self._stub.SignalEntity(req, None) # TODO: Cancellation timeout?
508+
self._invoke_unary("SignalEntity", req) # TODO: Cancellation timeout?
444509

445510
def get_entity(self,
446511
entity_instance_id: EntityInstanceId,
447512
include_state: bool = True
448513
) -> Optional[EntityMetadata]:
449514
req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)
450515
self._logger.info(f"Getting entity '{entity_instance_id}'.")
451-
res: pb.GetEntityResponse = self._stub.GetEntity(req)
516+
res: pb.GetEntityResponse = self._invoke_unary("GetEntity", req)
452517
if not res.exists:
453518
return None
454519
if self._payload_store is not None:
@@ -467,7 +532,7 @@ def get_all_entities(self,
467532

468533
while True:
469534
query_request = build_query_entities_req(entity_query, _continuation_token)
470-
resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request)
535+
resp: pb.QueryEntitiesResponse = self._invoke_unary("QueryEntities", query_request)
471536
if self._payload_store is not None:
472537
payload_helpers.deexternalize_payloads(resp, self._payload_store)
473538
entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
@@ -493,7 +558,7 @@ def clean_entity_storage(self,
493558
releaseOrphanedLocks=release_orphaned_locks,
494559
continuationToken=_continuation_token
495560
)
496-
resp: pb.CleanEntityStorageResponse = self._stub.CleanEntityStorage(req)
561+
resp: pb.CleanEntityStorageResponse = self._invoke_unary("CleanEntityStorage", req)
497562
empty_entities_removed += resp.emptyEntitiesRemoved
498563
orphaned_locks_released += resp.orphanedLocksReleased
499564

tests/durabletask/test_client.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import grpc
23
import pytest
34
from datetime import datetime, timezone
45
from unittest.mock import ANY, AsyncMock, MagicMock, patch
@@ -32,6 +33,15 @@
3233
INTERCEPTORS = [DefaultClientInterceptorImpl(METADATA)]
3334

3435

36+
class FakeRpcError(grpc.RpcError):
    """Test double for ``grpc.RpcError`` that reports a fixed status code."""

    def __init__(self, status_code: grpc.StatusCode):
        # Remember the code first; the base initializer takes no arguments.
        self._status_code = status_code
        super().__init__()

    def code(self):
        """Return the status code given at construction."""
        return self._status_code
43+
44+
3545
class FakePayloadStore(PayloadStore):
3646
TOKEN_PREFIX = 'fake://'
3747

@@ -317,6 +327,117 @@ def test_client_stores_resiliency_options_for_recreation():
317327
assert client._interceptors == interceptors
318328

319329

330+
def test_sync_client_recreates_sdk_owned_channel_after_repeated_unavailable():
    """An UNAVAILABLE failure at threshold 1 swaps in a new SDK-owned channel."""
    old_channel = MagicMock(name="first-channel")
    new_channel = MagicMock(name="second-channel")

    # The first stub always fails with UNAVAILABLE; the replacement succeeds.
    failing_stub = MagicMock()
    failing_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE)
    healthy_stub = MagicMock()
    healthy_stub.GetInstance.return_value = MagicMock(exists=False)

    fake_timer = MagicMock()

    options = GrpcClientResiliencyOptions(
        channel_recreate_failure_threshold=1,
        min_recreate_interval_seconds=0.0,
    )

    with patch("durabletask.client.shared.get_grpc_channel", side_effect=[old_channel, new_channel]):
        with patch("durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[failing_stub, healthy_stub]):
            with patch("threading.Timer", return_value=fake_timer) as mock_timer:
                client = TaskHubGrpcClient(
                    host_address="localhost:4001",
                    resiliency_options=options,
                )
                # The first call fails and triggers recreation; the second
                # call goes through the replacement stub.
                with pytest.raises(FakeRpcError):
                    client.get_orchestration_state("abc")
                client.get_orchestration_state("abc")

    assert client._channel is new_channel
    # The retired channel is closed later by a started daemon timer.
    mock_timer.assert_called_once_with(30.0, old_channel.close)
    assert fake_timer.daemon is True
    fake_timer.start.assert_called_once_with()
360+
361+
362+
def test_sync_client_does_not_count_long_poll_deadline():
    """A deadline on WaitForInstanceStart leaves the failure count unchanged."""
    fake_stub = MagicMock()
    fake_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE)
    fake_stub.WaitForInstanceStart.side_effect = FakeRpcError(grpc.StatusCode.DEADLINE_EXCEEDED)

    options = GrpcClientResiliencyOptions(channel_recreate_failure_threshold=2)

    with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()):
        with patch("durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=fake_stub):
            client = TaskHubGrpcClient(resiliency_options=options)

            # One genuine transport failure is recorded...
            with pytest.raises(FakeRpcError):
                client.get_orchestration_state("abc")
            # ...then the wait call times out, surfacing as TimeoutError.
            with pytest.raises(TimeoutError):
                client.wait_for_orchestration_start("abc")

    # Only the UNAVAILABLE failure was counted.
    assert client._client_failure_tracker.consecutive_failures == 1
378+
379+
380+
def test_sync_client_does_not_recreate_caller_owned_channel():
    """A caller-supplied channel is kept even after repeated UNAVAILABLE errors."""
    caller_channel = MagicMock(name="provided-channel")
    fake_stub = MagicMock()
    fake_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE)

    options = GrpcClientResiliencyOptions(channel_recreate_failure_threshold=1)

    with patch("durabletask.client.shared.get_grpc_channel") as mock_get_channel:
        with patch("durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=fake_stub) as mock_stub:
            client = TaskHubGrpcClient(
                channel=caller_channel,
                resiliency_options=options,
            )
            # Each failure reaches the threshold of 1, yet nothing is rebuilt.
            for _ in range(2):
                with pytest.raises(FakeRpcError):
                    client.get_orchestration_state("abc")

    assert client._channel is caller_channel
    # No channel factory call, and the stub was built once, at construction.
    mock_get_channel.assert_not_called()
    mock_stub.assert_called_once_with(caller_channel)
400+
401+
402+
def test_sync_client_resets_failure_tracking_after_success():
    """A successful call after a transport failure zeroes the failure count."""
    # First call raises UNAVAILABLE; second returns a "does not exist" reply.
    outcomes = [
        FakeRpcError(grpc.StatusCode.UNAVAILABLE),
        MagicMock(exists=False),
    ]
    fake_stub = MagicMock()
    fake_stub.GetInstance.side_effect = outcomes

    options = GrpcClientResiliencyOptions(channel_recreate_failure_threshold=2)

    with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()):
        with patch("durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=fake_stub):
            client = TaskHubGrpcClient(resiliency_options=options)

            with pytest.raises(FakeRpcError):
                client.get_orchestration_state("abc")
            assert client.get_orchestration_state("abc") is None

    assert client._client_failure_tracker.consecutive_failures == 0
419+
420+
421+
def test_sync_client_resets_failure_tracking_after_application_error():
    """An INVALID_ARGUMENT error after UNAVAILABLE zeroes the failure count."""
    # Transport failure first, then an application-level rejection.
    outcomes = [
        FakeRpcError(grpc.StatusCode.UNAVAILABLE),
        FakeRpcError(grpc.StatusCode.INVALID_ARGUMENT),
    ]
    fake_stub = MagicMock()
    fake_stub.GetInstance.side_effect = outcomes

    options = GrpcClientResiliencyOptions(channel_recreate_failure_threshold=2)

    with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()):
        with patch("durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=fake_stub):
            client = TaskHubGrpcClient(resiliency_options=options)

            # Both calls raise, but only the first is a transport failure.
            for _ in range(2):
                with pytest.raises(FakeRpcError):
                    client.get_orchestration_state("abc")

    assert client._client_failure_tracker.consecutive_failures == 0
439+
440+
320441
def test_async_client_stores_resolved_transport_inputs():
321442
resiliency = GrpcClientResiliencyOptions()
322443
channel_options = GrpcChannelOptions(max_send_message_length=4321)

0 commit comments

Comments
 (0)