Skip to content

Commit 4f4b3a3

Browse files
Bernd Verst and Copilot committed
Add async client gRPC channel recreation
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 2834177 commit 4f4b3a3

3 files changed

Lines changed: 272 additions & 17 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ FIXED
3838
recreated after repeated transport failures while long-poll timeout
3939
deadlines, successful replies, and application-level RPC errors reset the
4040
failure tracker.
41+
- Fixed async `AsyncTaskHubGrpcClient` transport resiliency so SDK-owned
42+
channels are recreated after repeated transport failures while long-poll
43+
timeout deadlines, successful replies, and application-level RPC errors
44+
reset the failure tracker.
4145

4246
## v1.4.0
4347

durabletask/client.py

Lines changed: 104 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4+
import asyncio
45
import logging
56
import threading
67
import time
@@ -614,6 +615,14 @@ def __init__(self, *,
614615
)
615616
self._channel = channel
616617
self._stub = stubs.TaskHubSidecarServiceStub(channel)
618+
self._client_failure_tracker = FailureTracker(
619+
self._resiliency_options.channel_recreate_failure_threshold
620+
)
621+
self._closing = False
622+
self._recreate_lock = asyncio.Lock()
623+
self._last_recreate_time = 0.0
624+
self._retired_channels: list[grpc.aio.Channel] = []
625+
self._retired_channel_close_tasks: set[asyncio.Task[None]] = set()
617626
self._logger = shared.get_logger("async_client", log_handler, log_formatter)
618627
self.default_version = default_version
619628
self._payload_store = payload_store
@@ -627,6 +636,18 @@ async def close(self) -> None:
627636
it.
628637
"""
629638
if self._owns_channel:
639+
self._closing = True
640+
async with self._recreate_lock:
641+
retired_channels = list(self._retired_channels)
642+
self._retired_channels.clear()
643+
close_tasks = list(self._retired_channel_close_tasks)
644+
self._retired_channel_close_tasks.clear()
645+
for close_task in close_tasks:
646+
close_task.cancel()
647+
if close_tasks:
648+
await asyncio.gather(*close_tasks, return_exceptions=True)
649+
for retired_channel in retired_channels:
650+
await retired_channel.close()
630651
await self._channel.close()
631652

632653
async def __aenter__(self):
@@ -635,6 +656,64 @@ async def __aenter__(self):
635656
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Close the client when an ``async with`` block is left.

    The exception details are not inspected and ``None`` is returned, so an
    exception raised inside the block keeps propagating to the caller.
    """
    await self.close()
    return None
637658

659+
async def _invoke_unary(
        self,
        method_name: str,
        request: Any,
        *,
        timeout: Optional[int] = None):
    """Call a unary stub method by name and report the outcome to the failure tracker.

    Args:
        method_name: Attribute name of the RPC on ``self._stub``.
        request: Request message handed to the RPC.
        timeout: Optional per-call timeout. When ``None`` the RPC is invoked
            without a ``timeout`` argument at all.

    Returns:
        The awaited response of the RPC.

    Raises:
        grpc.aio.AioRpcError: Re-raised unchanged after the failure tracker
            has been updated.
    """
    rpc = getattr(self._stub, method_name)
    # Only forward ``timeout`` when the caller actually supplied one.
    call_kwargs = {} if timeout is None else {"timeout": timeout}
    try:
        reply = await rpc(request, **call_kwargs)
    except grpc.aio.AioRpcError as rpc_error:
        if not is_client_transport_failure(method_name, rpc_error.code()):
            # Not classified as a transport failure: counts as a success for
            # the tracker even though the call itself failed.
            self._client_failure_tracker.record_success()
        elif self._client_failure_tracker.record_failure():
            # record_failure() returned a truthy value: attempt to swap in a
            # new channel before surfacing the error.
            await self._maybe_recreate_channel()
        raise
    self._client_failure_tracker.record_success()
    return reply
682+
683+
async def _maybe_recreate_channel(self) -> None:
    """Swap in a new gRPC channel and stub, retiring the current channel.

    Nothing happens when the client does not own its channel, when it is
    closing, or when the previous recreation happened less than
    ``min_recreate_interval_seconds`` ago. Otherwise the old channel is added
    to ``_retired_channels`` and a background task is scheduled to close it.
    """
    if not self._owns_channel or self._closing:
        return
    async with self._recreate_lock:
        # Re-check under the lock: close() may have started while waiting.
        if self._closing:
            return
        started = time.monotonic()
        min_gap = self._resiliency_options.min_recreate_interval_seconds
        if started - self._last_recreate_time < min_gap:
            return

        retiring = self._channel
        replacement = shared.get_async_grpc_channel(
            host_address=self._host_address,
            secure_channel=self._secure_channel,
            interceptors=self._interceptors,
            channel_options=self._channel_options,
        )
        self._channel = replacement
        self._stub = stubs.TaskHubSidecarServiceStub(replacement)
        self._last_recreate_time = started
        self._client_failure_tracker.record_success()

        # Track the old channel and the task that closes it so that close()
        # can find both.
        self._retired_channels.append(retiring)
        closer = asyncio.create_task(self._close_retired_channel(retiring))
        pending = self._retired_channel_close_tasks
        pending.add(closer)
        closer.add_done_callback(pending.discard)
706+
707+
async def _close_retired_channel(
        self,
        channel: grpc.aio.Channel,
        *,
        grace_period_seconds: float = 30.0) -> None:
    """Close a retired channel after a grace period and stop tracking it.

    Args:
        channel: The retired channel to close.
        grace_period_seconds: How long to wait before closing the channel.
            Defaults to 30 seconds, the previously hard-coded delay
            (presumably chosen so calls still in flight on the old channel
            can finish; confirm before lowering it in production).

    The channel is removed from ``_retired_channels`` whether the close
    succeeds, fails, or the task is cancelled while waiting.
    """
    try:
        await asyncio.sleep(grace_period_seconds)
        await channel.close()
    finally:
        # The list may no longer contain the channel (close() clears it
        # before cancelling these tasks), so a missing entry is expected.
        try:
            self._retired_channels.remove(channel)
        except ValueError:
            pass
716+
638717
async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
639718
input: Optional[TInput] = None,
640719
instance_id: Optional[str] = None,
@@ -665,13 +744,13 @@ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator
665744
await payload_helpers.externalize_payloads_async(
666745
req, self._payload_store, instance_id=req.instanceId,
667746
)
668-
res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
747+
res: pb.CreateInstanceResponse = await self._invoke_unary("StartInstance", req)
669748
return res.instanceId
670749

671750
async def get_orchestration_state(self, instance_id: str, *,
672751
fetch_payloads: bool = True) -> Optional[OrchestrationState]:
673752
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
674-
res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
753+
res: pb.GetInstanceResponse = await self._invoke_unary("GetInstance", req)
675754
if self._payload_store is not None and res.exists:
676755
await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
677756
return new_orchestration_state(req.instanceId, res)
@@ -710,7 +789,7 @@ async def list_instance_ids(self,
710789
f"page_size={page_size}, "
711790
f"continuation_token={continuation_token}"
712791
)
713-
resp: pb.ListInstanceIdsResponse = await self._stub.ListInstanceIds(req)
792+
resp: pb.ListInstanceIdsResponse = await self._invoke_unary("ListInstanceIds", req)
714793
next_token = resp.lastInstanceKey.value if resp.HasField("lastInstanceKey") else None
715794
return Page(items=list(resp.instanceIds), continuation_token=next_token)
716795

@@ -727,7 +806,7 @@ async def get_all_orchestration_states(self,
727806

728807
while True:
729808
req = build_query_instances_req(orchestration_query, _continuation_token)
730-
resp: pb.QueryInstancesResponse = await self._stub.QueryInstances(req)
809+
resp: pb.QueryInstancesResponse = await self._invoke_unary("QueryInstances", req)
731810
if self._payload_store is not None:
732811
await payload_helpers.deexternalize_payloads_async(resp, self._payload_store)
733812
states += [parse_orchestration_state(res) for res in resp.orchestrationState]
@@ -744,7 +823,11 @@ async def wait_for_orchestration_start(self, instance_id: str, *,
744823
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
745824
try:
746825
self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
747-
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=timeout)
826+
res: pb.GetInstanceResponse = await self._invoke_unary(
827+
"WaitForInstanceStart",
828+
req,
829+
timeout=timeout,
830+
)
748831
if self._payload_store is not None and res.exists:
749832
await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
750833
return new_orchestration_state(req.instanceId, res)
@@ -760,7 +843,11 @@ async def wait_for_orchestration_completion(self, instance_id: str, *,
760843
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
761844
try:
762845
self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
763-
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=timeout)
846+
res: pb.GetInstanceResponse = await self._invoke_unary(
847+
"WaitForInstanceCompletion",
848+
req,
849+
timeout=timeout,
850+
)
764851
if self._payload_store is not None and res.exists:
765852
await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
766853
state = new_orchestration_state(req.instanceId, res)
@@ -781,7 +868,7 @@ async def raise_orchestration_event(self, instance_id: str, event_name: str, *,
781868
await payload_helpers.externalize_payloads_async(
782869
req, self._payload_store, instance_id=instance_id,
783870
)
784-
await self._stub.RaiseEvent(req)
871+
await self._invoke_unary("RaiseEvent", req)
785872

786873
async def terminate_orchestration(self, instance_id: str, *,
787874
output: Optional[Any] = None,
@@ -793,17 +880,17 @@ async def terminate_orchestration(self, instance_id: str, *,
793880
await payload_helpers.externalize_payloads_async(
794881
req, self._payload_store, instance_id=instance_id,
795882
)
796-
await self._stub.TerminateInstance(req)
883+
await self._invoke_unary("TerminateInstance", req)
797884

798885
async def suspend_orchestration(self, instance_id: str) -> None:
799886
req = pb.SuspendRequest(instanceId=instance_id)
800887
self._logger.info(f"Suspending instance '{instance_id}'.")
801-
await self._stub.SuspendInstance(req)
888+
await self._invoke_unary("SuspendInstance", req)
802889

803890
async def resume_orchestration(self, instance_id: str) -> None:
804891
req = pb.ResumeRequest(instanceId=instance_id)
805892
self._logger.info(f"Resuming instance '{instance_id}'.")
806-
await self._stub.ResumeInstance(req)
893+
await self._invoke_unary("ResumeInstance", req)
807894

808895
async def restart_orchestration(self, instance_id: str, *,
809896
restart_with_new_instance_id: bool = False) -> str:
@@ -822,13 +909,13 @@ async def restart_orchestration(self, instance_id: str, *,
822909
restartWithNewInstanceId=restart_with_new_instance_id)
823910

824911
self._logger.info(f"Restarting instance '{instance_id}'.")
825-
res: pb.RestartInstanceResponse = await self._stub.RestartInstance(req)
912+
res: pb.RestartInstanceResponse = await self._invoke_unary("RestartInstance", req)
826913
return res.instanceId
827914

828915
async def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
829916
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
830917
self._logger.info(f"Purging instance '{instance_id}'.")
831-
resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
918+
resp: pb.PurgeInstancesResponse = await self._invoke_unary("PurgeInstances", req)
832919
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
833920

834921
async def purge_orchestrations_by(self,
@@ -842,7 +929,7 @@ async def purge_orchestrations_by(self,
842929
f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
843930
f"recursive={recursive}")
844931
req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
845-
resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
932+
resp: pb.PurgeInstancesResponse = await self._invoke_unary("PurgeInstances", req)
846933
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
847934

848935
async def signal_entity(self,
@@ -855,15 +942,15 @@ async def signal_entity(self,
855942
await payload_helpers.externalize_payloads_async(
856943
req, self._payload_store, instance_id=str(entity_instance_id),
857944
)
858-
await self._stub.SignalEntity(req, None)
945+
await self._invoke_unary("SignalEntity", req)
859946

860947
async def get_entity(self,
861948
entity_instance_id: EntityInstanceId,
862949
include_state: bool = True
863950
) -> Optional[EntityMetadata]:
864951
req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)
865952
self._logger.info(f"Getting entity '{entity_instance_id}'.")
866-
res: pb.GetEntityResponse = await self._stub.GetEntity(req)
953+
res: pb.GetEntityResponse = await self._invoke_unary("GetEntity", req)
867954
if not res.exists:
868955
return None
869956
if self._payload_store is not None:
@@ -882,7 +969,7 @@ async def get_all_entities(self,
882969

883970
while True:
884971
query_request = build_query_entities_req(entity_query, _continuation_token)
885-
resp: pb.QueryEntitiesResponse = await self._stub.QueryEntities(query_request)
972+
resp: pb.QueryEntitiesResponse = await self._invoke_unary("QueryEntities", query_request)
886973
if self._payload_store is not None:
887974
await payload_helpers.deexternalize_payloads_async(resp, self._payload_store)
888975
entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
@@ -908,7 +995,7 @@ async def clean_entity_storage(self,
908995
releaseOrphanedLocks=release_orphaned_locks,
909996
continuationToken=_continuation_token
910997
)
911-
resp: pb.CleanEntityStorageResponse = await self._stub.CleanEntityStorage(req)
998+
resp: pb.CleanEntityStorageResponse = await self._invoke_unary("CleanEntityStorage", req)
912999
empty_entities_removed += resp.emptyEntitiesRemoved
9131000
orphaned_locks_released += resp.orphanedLocksReleased
9141001

0 commit comments

Comments
 (0)