11# Copyright (c) Microsoft Corporation.
22# Licensed under the MIT License.
33
4+ import asyncio
45import logging
56import threading
67import time
@@ -614,6 +615,14 @@ def __init__(self, *,
614615 )
615616 self ._channel = channel
616617 self ._stub = stubs .TaskHubSidecarServiceStub (channel )
618+ self ._client_failure_tracker = FailureTracker (
619+ self ._resiliency_options .channel_recreate_failure_threshold
620+ )
621+ self ._closing = False
622+ self ._recreate_lock = asyncio .Lock ()
623+ self ._last_recreate_time = 0.0
624+ self ._retired_channels : list [grpc .aio .Channel ] = []
625+ self ._retired_channel_close_tasks : set [asyncio .Task [None ]] = set ()
617626 self ._logger = shared .get_logger ("async_client" , log_handler , log_formatter )
618627 self .default_version = default_version
619628 self ._payload_store = payload_store
@@ -627,6 +636,18 @@ async def close(self) -> None:
627636 it.
628637 """
629638 if self ._owns_channel :
639+ self ._closing = True
640+ async with self ._recreate_lock :
641+ retired_channels = list (self ._retired_channels )
642+ self ._retired_channels .clear ()
643+ close_tasks = list (self ._retired_channel_close_tasks )
644+ self ._retired_channel_close_tasks .clear ()
645+ for close_task in close_tasks :
646+ close_task .cancel ()
647+ if close_tasks :
648+ await asyncio .gather (* close_tasks , return_exceptions = True )
649+ for retired_channel in retired_channels :
650+ await retired_channel .close ()
630651 await self ._channel .close ()
631652
632653 async def __aenter__ (self ):
@@ -635,6 +656,64 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Async context-manager exit hook: close the client.

    The exception details passed in by the ``async with`` machinery are
    ignored, and nothing is returned, so an in-flight exception is not
    suppressed.
    """
    await self.close()
637658
async def _invoke_unary(
        self,
        method_name: str,
        request: Any,
        *,
        timeout: Optional[int] = None) -> Any:
    """Invoke a unary RPC on the current stub and record the outcome.

    The method is looked up on ``self._stub`` at call time, so a stub that
    ``_maybe_recreate_channel`` swapped in is used by subsequent calls.

    Args:
        method_name: Name of the stub attribute to call, e.g. ``"GetInstance"``.
        request: Request message handed to the stub method.
        timeout: Optional timeout forwarded to the stub method as the
            ``timeout`` keyword; the keyword is omitted when ``None``.

    Returns:
        The awaited response of the stub method.

    Raises:
        grpc.aio.AioRpcError: Always re-raised unchanged once the failure
            tracker has been updated, even if channel recreation fails.
    """
    method = getattr(self._stub, method_name)
    try:
        if timeout is None:
            response = await method(request)
        else:
            response = await method(request, timeout=timeout)
    except grpc.aio.AioRpcError as rpc_error:
        if is_client_transport_failure(method_name, rpc_error.code()):
            # record_failure() tells us whether a channel recreation should
            # be attempted now.
            if self._client_failure_tracker.record_failure():
                try:
                    await self._maybe_recreate_channel()
                except Exception:
                    # Recovery is best-effort. A failure here must not
                    # replace the RPC error, because callers handle
                    # gRPC errors specifically.
                    self._logger.warning(
                        f"Failed to recreate the gRPC channel after '{method_name}' failed.",
                        exc_info=True,
                    )
        else:
            # Not classified as a transport failure, so it counts as a
            # success for the tracker.
            self._client_failure_tracker.record_success()
        # Re-raises rpc_error: once the nested handler above has finished,
        # the exception being handled is the original RPC error again.
        raise
    self._client_failure_tracker.record_success()
    return response
682+
async def _maybe_recreate_channel(self) -> None:
    """Swap in a fresh channel and stub, unless closing or rate-limited.

    Does nothing when the client does not own its channel, when ``close()``
    has started, or when the previous recreation happened less than
    ``min_recreate_interval_seconds`` ago. Otherwise the current channel is
    replaced, kept in ``_retired_channels``, and a background task is
    scheduled to close it later.
    """
    if not self._owns_channel or self._closing:
        return
    async with self._recreate_lock:
        # Re-check under the lock: close() sets the flag before it waits
        # for this lock, so the flag may have flipped while we were queued.
        if self._closing:
            return
        now = time.monotonic()
        last_recreate = self._last_recreate_time
        # 0.0 is the initial "never recreated" value. The reference point of
        # time.monotonic() is undefined, so comparing against that initial
        # value could wrongly suppress the very first recreation when the
        # clock reading is still smaller than the minimum interval.
        if last_recreate != 0.0 and (
                now - last_recreate
                < self._resiliency_options.min_recreate_interval_seconds):
            return
        old_channel = self._channel
        self._channel = shared.get_async_grpc_channel(
            host_address=self._host_address,
            secure_channel=self._secure_channel,
            interceptors=self._interceptors,
            channel_options=self._channel_options,
        )
        self._stub = stubs.TaskHubSidecarServiceStub(self._channel)
        self._last_recreate_time = now
        # Start the tracker afresh for the new channel.
        self._client_failure_tracker.record_success()
        self._retired_channels.append(old_channel)
        close_task = asyncio.create_task(self._close_retired_channel(old_channel))
        # Keep a reference so close() can cancel the task; the done callback
        # drops that reference once the task finishes.
        self._retired_channel_close_tasks.add(close_task)
        close_task.add_done_callback(self._retired_channel_close_tasks.discard)
706+
async def _close_retired_channel(
        self,
        channel: grpc.aio.Channel,
        *,
        grace_period_seconds: float = 30.0) -> None:
    """Close a replaced channel after a delay and stop tracking it.

    Args:
        channel: The retired channel to close.
        grace_period_seconds: How long to sleep before closing. Defaults to
            the previously hard-coded 30 seconds; presumably this gives
            calls still using the old channel time to finish (TODO confirm).
    """
    try:
        await asyncio.sleep(grace_period_seconds)
        await channel.close()
    finally:
        # Runs on cancellation as well. close() clears the list before it
        # cancels these tasks, so the channel may already be gone.
        try:
            self._retired_channels.remove(channel)
        except ValueError:
            pass
716+
638717 async def schedule_new_orchestration (self , orchestrator : Union [task .Orchestrator [TInput , TOutput ], str ], * ,
639718 input : Optional [TInput ] = None ,
640719 instance_id : Optional [str ] = None ,
@@ -665,13 +744,13 @@ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator
665744 await payload_helpers .externalize_payloads_async (
666745 req , self ._payload_store , instance_id = req .instanceId ,
667746 )
668- res : pb .CreateInstanceResponse = await self ._stub . StartInstance ( req )
747+ res : pb .CreateInstanceResponse = await self ._invoke_unary ( "StartInstance" , req )
669748 return res .instanceId
670749
671750 async def get_orchestration_state (self , instance_id : str , * ,
672751 fetch_payloads : bool = True ) -> Optional [OrchestrationState ]:
673752 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
674- res : pb .GetInstanceResponse = await self ._stub . GetInstance ( req )
753+ res : pb .GetInstanceResponse = await self ._invoke_unary ( "GetInstance" , req )
675754 if self ._payload_store is not None and res .exists :
676755 await payload_helpers .deexternalize_payloads_async (res , self ._payload_store )
677756 return new_orchestration_state (req .instanceId , res )
@@ -710,7 +789,7 @@ async def list_instance_ids(self,
710789 f"page_size={ page_size } , "
711790 f"continuation_token={ continuation_token } "
712791 )
713- resp : pb .ListInstanceIdsResponse = await self ._stub . ListInstanceIds ( req )
792+ resp : pb .ListInstanceIdsResponse = await self ._invoke_unary ( "ListInstanceIds" , req )
714793 next_token = resp .lastInstanceKey .value if resp .HasField ("lastInstanceKey" ) else None
715794 return Page (items = list (resp .instanceIds ), continuation_token = next_token )
716795
@@ -727,7 +806,7 @@ async def get_all_orchestration_states(self,
727806
728807 while True :
729808 req = build_query_instances_req (orchestration_query , _continuation_token )
730- resp : pb .QueryInstancesResponse = await self ._stub . QueryInstances ( req )
809+ resp : pb .QueryInstancesResponse = await self ._invoke_unary ( "QueryInstances" , req )
731810 if self ._payload_store is not None :
732811 await payload_helpers .deexternalize_payloads_async (resp , self ._payload_store )
733812 states += [parse_orchestration_state (res ) for res in resp .orchestrationState ]
@@ -744,7 +823,11 @@ async def wait_for_orchestration_start(self, instance_id: str, *,
744823 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
745824 try :
746825 self ._logger .info (f"Waiting up to { timeout } s for instance '{ instance_id } ' to start." )
747- res : pb .GetInstanceResponse = await self ._stub .WaitForInstanceStart (req , timeout = timeout )
826+ res : pb .GetInstanceResponse = await self ._invoke_unary (
827+ "WaitForInstanceStart" ,
828+ req ,
829+ timeout = timeout ,
830+ )
748831 if self ._payload_store is not None and res .exists :
749832 await payload_helpers .deexternalize_payloads_async (res , self ._payload_store )
750833 return new_orchestration_state (req .instanceId , res )
@@ -760,7 +843,11 @@ async def wait_for_orchestration_completion(self, instance_id: str, *,
760843 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
761844 try :
762845 self ._logger .info (f"Waiting { timeout } s for instance '{ instance_id } ' to complete." )
763- res : pb .GetInstanceResponse = await self ._stub .WaitForInstanceCompletion (req , timeout = timeout )
846+ res : pb .GetInstanceResponse = await self ._invoke_unary (
847+ "WaitForInstanceCompletion" ,
848+ req ,
849+ timeout = timeout ,
850+ )
764851 if self ._payload_store is not None and res .exists :
765852 await payload_helpers .deexternalize_payloads_async (res , self ._payload_store )
766853 state = new_orchestration_state (req .instanceId , res )
@@ -781,7 +868,7 @@ async def raise_orchestration_event(self, instance_id: str, event_name: str, *,
781868 await payload_helpers .externalize_payloads_async (
782869 req , self ._payload_store , instance_id = instance_id ,
783870 )
784- await self ._stub . RaiseEvent ( req )
871+ await self ._invoke_unary ( "RaiseEvent" , req )
785872
786873 async def terminate_orchestration (self , instance_id : str , * ,
787874 output : Optional [Any ] = None ,
@@ -793,17 +880,17 @@ async def terminate_orchestration(self, instance_id: str, *,
793880 await payload_helpers .externalize_payloads_async (
794881 req , self ._payload_store , instance_id = instance_id ,
795882 )
796- await self ._stub . TerminateInstance ( req )
883+ await self ._invoke_unary ( "TerminateInstance" , req )
797884
798885 async def suspend_orchestration (self , instance_id : str ) -> None :
799886 req = pb .SuspendRequest (instanceId = instance_id )
800887 self ._logger .info (f"Suspending instance '{ instance_id } '." )
801- await self ._stub . SuspendInstance ( req )
888+ await self ._invoke_unary ( "SuspendInstance" , req )
802889
803890 async def resume_orchestration (self , instance_id : str ) -> None :
804891 req = pb .ResumeRequest (instanceId = instance_id )
805892 self ._logger .info (f"Resuming instance '{ instance_id } '." )
806- await self ._stub . ResumeInstance ( req )
893+ await self ._invoke_unary ( "ResumeInstance" , req )
807894
808895 async def restart_orchestration (self , instance_id : str , * ,
809896 restart_with_new_instance_id : bool = False ) -> str :
@@ -822,13 +909,13 @@ async def restart_orchestration(self, instance_id: str, *,
822909 restartWithNewInstanceId = restart_with_new_instance_id )
823910
824911 self ._logger .info (f"Restarting instance '{ instance_id } '." )
825- res : pb .RestartInstanceResponse = await self ._stub . RestartInstance ( req )
912+ res : pb .RestartInstanceResponse = await self ._invoke_unary ( "RestartInstance" , req )
826913 return res .instanceId
827914
828915 async def purge_orchestration (self , instance_id : str , recursive : bool = True ) -> PurgeInstancesResult :
829916 req = pb .PurgeInstancesRequest (instanceId = instance_id , recursive = recursive )
830917 self ._logger .info (f"Purging instance '{ instance_id } '." )
831- resp : pb .PurgeInstancesResponse = await self ._stub . PurgeInstances ( req )
918+ resp : pb .PurgeInstancesResponse = await self ._invoke_unary ( "PurgeInstances" , req )
832919 return PurgeInstancesResult (resp .deletedInstanceCount , resp .isComplete .value )
833920
834921 async def purge_orchestrations_by (self ,
@@ -842,7 +929,7 @@ async def purge_orchestrations_by(self,
842929 f"runtime_status={ [str (status ) for status in runtime_status ] if runtime_status else None } , "
843930 f"recursive={ recursive } " )
844931 req = build_purge_by_filter_req (created_time_from , created_time_to , runtime_status , recursive )
845- resp : pb .PurgeInstancesResponse = await self ._stub . PurgeInstances ( req )
932+ resp : pb .PurgeInstancesResponse = await self ._invoke_unary ( "PurgeInstances" , req )
846933 return PurgeInstancesResult (resp .deletedInstanceCount , resp .isComplete .value )
847934
848935 async def signal_entity (self ,
@@ -855,15 +942,15 @@ async def signal_entity(self,
855942 await payload_helpers .externalize_payloads_async (
856943 req , self ._payload_store , instance_id = str (entity_instance_id ),
857944 )
858- await self ._stub . SignalEntity ( req , None )
945+ await self ._invoke_unary ( "SignalEntity" , req )
859946
860947 async def get_entity (self ,
861948 entity_instance_id : EntityInstanceId ,
862949 include_state : bool = True
863950 ) -> Optional [EntityMetadata ]:
864951 req = pb .GetEntityRequest (instanceId = str (entity_instance_id ), includeState = include_state )
865952 self ._logger .info (f"Getting entity '{ entity_instance_id } '." )
866- res : pb .GetEntityResponse = await self ._stub . GetEntity ( req )
953+ res : pb .GetEntityResponse = await self ._invoke_unary ( "GetEntity" , req )
867954 if not res .exists :
868955 return None
869956 if self ._payload_store is not None :
@@ -882,7 +969,7 @@ async def get_all_entities(self,
882969
883970 while True :
884971 query_request = build_query_entities_req (entity_query , _continuation_token )
885- resp : pb .QueryEntitiesResponse = await self ._stub . QueryEntities ( query_request )
972+ resp : pb .QueryEntitiesResponse = await self ._invoke_unary ( "QueryEntities" , query_request )
886973 if self ._payload_store is not None :
887974 await payload_helpers .deexternalize_payloads_async (resp , self ._payload_store )
888975 entities += [EntityMetadata .from_entity_metadata (entity , query_request .query .includeState ) for entity in resp .entities ]
@@ -908,7 +995,7 @@ async def clean_entity_storage(self,
908995 releaseOrphanedLocks = release_orphaned_locks ,
909996 continuationToken = _continuation_token
910997 )
911- resp : pb .CleanEntityStorageResponse = await self ._stub . CleanEntityStorage ( req )
998+ resp : pb .CleanEntityStorageResponse = await self ._invoke_unary ( "CleanEntityStorage" , req )
912999 empty_entities_removed += resp .emptyEntitiesRemoved
9131000 orphaned_locks_released += resp .orphanedLocksReleased
9141001
0 commit comments