22# Licensed under the MIT License.
33
44import logging
5+ import threading
6+ import time
57import uuid
68from dataclasses import dataclass
79from datetime import datetime
2527import durabletask .internal .shared as shared
2628import durabletask .internal .tracing as tracing
2729from durabletask import task
30+ from durabletask .internal .grpc_resiliency import (
31+ FailureTracker ,
32+ is_client_transport_failure ,
33+ )
2834from durabletask .internal .client_helpers import (
2935 build_query_entities_req ,
3036 build_query_instances_req ,
@@ -201,10 +207,61 @@ def __init__(self, *,
201207 )
202208 self ._channel = channel
203209 self ._stub = stubs .TaskHubSidecarServiceStub (channel )
210+ self ._client_failure_tracker = FailureTracker (
211+ self ._resiliency_options .channel_recreate_failure_threshold
212+ )
213+ self ._last_recreate_time = 0.0
214+ self ._recreate_lock = threading .Lock ()
204215 self ._logger = shared .get_logger ("client" , log_handler , log_formatter )
205216 self .default_version = default_version
206217 self ._payload_store = payload_store
207218
def _invoke_unary(
        self,
        method_name: str,
        request: Any,
        *,
        timeout: Optional[int] = None):
    """Invoke a unary RPC on the current stub and track transport health.

    The stub method is looked up by name on every call, so a stub that was
    swapped in by ``_maybe_recreate_channel`` is picked up by the next call.

    Args:
        method_name: Name of the stub method to call (e.g. ``"GetInstance"``).
        request: The request message passed to the stub method.
        timeout: Optional per-call timeout forwarded to the stub method. When
            ``None`` the method is called without a ``timeout`` argument.

    Returns:
        Whatever the stub method returns.

    Raises:
        grpc.RpcError: Always re-raised unchanged after the failure tracker
            has been updated.
    """
    method = getattr(self._stub, method_name)
    try:
        if timeout is None:
            response = method(request)
        else:
            response = method(request, timeout=timeout)
    except grpc.RpcError as rpc_error:
        status_code = rpc_error.code()
        if is_client_transport_failure(method_name, status_code):
            # record_failure() reports whether the channel should now be rebuilt.
            if self._client_failure_tracker.record_failure():
                try:
                    self._maybe_recreate_channel()
                except Exception:
                    # Recreation is best-effort: a failure here must not
                    # replace the RpcError that callers expect to catch.
                    self._logger.warning(
                        f"Failed to recreate gRPC channel after '{method_name}' failure.",
                        exc_info=True,
                    )
        elif status_code != grpc.StatusCode.DEADLINE_EXCEEDED:
            # Any other status resets the tracker; a deadline that is not
            # classified as a transport failure leaves the tracker untouched.
            self._client_failure_tracker.record_success()
        raise
    else:
        self._client_failure_tracker.record_success()
        return response
243+
def _maybe_recreate_channel(self) -> None:
    """Replace the gRPC channel and stub, subject to ownership and a cooldown.

    Does nothing when this client does not own its channel, or when the last
    recreation happened less than ``min_recreate_interval_seconds`` ago.
    Otherwise a new channel and stub are created from the stored connection
    settings, the failure tracker is reset, and the previous channel is
    closed later by a daemon timer instead of immediately.
    """
    if not self._owns_channel:
        return
    with self._recreate_lock:
        now = time.monotonic()
        last_recreate = self._last_recreate_time
        cooldown = self._resiliency_options.min_recreate_interval_seconds
        # 0.0 is the "never recreated" initial value. time.monotonic() has an
        # undefined reference point, so comparing against that initial value
        # could wrongly suppress the very first recreation when the clock
        # reading is still small.
        if last_recreate and now - last_recreate < cooldown:
            return
        old_channel = self._channel
        self._channel = shared.get_grpc_channel(
            host_address=self._host_address,
            secure_channel=self._secure_channel,
            interceptors=self._interceptors,
            channel_options=self._channel_options,
        )
        self._stub = stubs.TaskHubSidecarServiceStub(self._channel)
        self._last_recreate_time = now
        self._client_failure_tracker.record_success()
        # Close the old channel after a delay rather than right away; the
        # timer thread is a daemon so it never blocks interpreter shutdown.
        close_timer = threading.Timer(30.0, old_channel.close)
        close_timer.daemon = True
        close_timer.start()
264+
208265 def close (self ) -> None :
209266 """Closes the underlying gRPC channel.
210267
@@ -249,12 +306,12 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
249306 payload_helpers .externalize_payloads (
250307 req , self ._payload_store , instance_id = req .instanceId ,
251308 )
252- res : pb .CreateInstanceResponse = self ._stub . StartInstance ( req )
309+ res : pb .CreateInstanceResponse = self ._invoke_unary ( "StartInstance" , req )
253310 return res .instanceId
254311
255312 def get_orchestration_state (self , instance_id : str , * , fetch_payloads : bool = True ) -> Optional [OrchestrationState ]:
256313 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
257- res : pb .GetInstanceResponse = self ._stub . GetInstance ( req )
314+ res : pb .GetInstanceResponse = self ._invoke_unary ( "GetInstance" , req )
258315 # De-externalize any large-payload tokens in the response
259316 if self ._payload_store is not None and res .exists :
260317 payload_helpers .deexternalize_payloads (res , self ._payload_store )
@@ -294,7 +351,7 @@ def list_instance_ids(self,
294351 f"page_size={ page_size } , "
295352 f"continuation_token={ continuation_token } "
296353 )
297- resp : pb .ListInstanceIdsResponse = self ._stub . ListInstanceIds ( req )
354+ resp : pb .ListInstanceIdsResponse = self ._invoke_unary ( "ListInstanceIds" , req )
298355 next_token = resp .lastInstanceKey .value if resp .HasField ("lastInstanceKey" ) else None
299356 return Page (items = list (resp .instanceIds ), continuation_token = next_token )
300357
@@ -311,7 +368,7 @@ def get_all_orchestration_states(self,
311368
312369 while True :
313370 req = build_query_instances_req (orchestration_query , _continuation_token )
314- resp : pb .QueryInstancesResponse = self ._stub . QueryInstances ( req )
371+ resp : pb .QueryInstancesResponse = self ._invoke_unary ( "QueryInstances" , req )
315372 if self ._payload_store is not None :
316373 payload_helpers .deexternalize_payloads (resp , self ._payload_store )
317374 states += [parse_orchestration_state (res ) for res in resp .orchestrationState ]
@@ -328,7 +385,11 @@ def wait_for_orchestration_start(self, instance_id: str, *,
328385 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
329386 try :
330387 self ._logger .info (f"Waiting up to { timeout } s for instance '{ instance_id } ' to start." )
331- res : pb .GetInstanceResponse = self ._stub .WaitForInstanceStart (req , timeout = timeout )
388+ res : pb .GetInstanceResponse = self ._invoke_unary (
389+ "WaitForInstanceStart" ,
390+ req ,
391+ timeout = timeout ,
392+ )
332393 if self ._payload_store is not None and res .exists :
333394 payload_helpers .deexternalize_payloads (res , self ._payload_store )
334395 return new_orchestration_state (req .instanceId , res )
@@ -345,7 +406,11 @@ def wait_for_orchestration_completion(self, instance_id: str, *,
345406 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
346407 try :
347408 self ._logger .info (f"Waiting { timeout } s for instance '{ instance_id } ' to complete." )
348- res : pb .GetInstanceResponse = self ._stub .WaitForInstanceCompletion (req , timeout = timeout )
409+ res : pb .GetInstanceResponse = self ._invoke_unary (
410+ "WaitForInstanceCompletion" ,
411+ req ,
412+ timeout = timeout ,
413+ )
349414 if self ._payload_store is not None and res .exists :
350415 payload_helpers .deexternalize_payloads (res , self ._payload_store )
351416 state = new_orchestration_state (req .instanceId , res )
@@ -366,7 +431,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
366431 payload_helpers .externalize_payloads (
367432 req , self ._payload_store , instance_id = instance_id ,
368433 )
369- self ._stub . RaiseEvent ( req )
434+ self ._invoke_unary ( "RaiseEvent" , req )
370435
371436 def terminate_orchestration (self , instance_id : str , * ,
372437 output : Optional [Any ] = None ,
@@ -378,17 +443,17 @@ def terminate_orchestration(self, instance_id: str, *,
378443 payload_helpers .externalize_payloads (
379444 req , self ._payload_store , instance_id = instance_id ,
380445 )
381- self ._stub . TerminateInstance ( req )
446+ self ._invoke_unary ( "TerminateInstance" , req )
382447
def suspend_orchestration(self, instance_id: str) -> None:
    """Send a ``SuspendInstance`` request for the given orchestration instance.

    Args:
        instance_id: ID of the orchestration instance to suspend.
    """
    suspend_request = pb.SuspendRequest(instanceId=instance_id)
    self._logger.info(f"Suspending instance '{instance_id}'.")
    # Routed through _invoke_unary rather than calling the stub directly.
    self._invoke_unary("SuspendInstance", suspend_request)
387452
def resume_orchestration(self, instance_id: str) -> None:
    """Send a ``ResumeInstance`` request for the given orchestration instance.

    Args:
        instance_id: ID of the orchestration instance to resume.
    """
    resume_request = pb.ResumeRequest(instanceId=instance_id)
    self._logger.info(f"Resuming instance '{instance_id}'.")
    # Routed through _invoke_unary rather than calling the stub directly.
    self._invoke_unary("ResumeInstance", resume_request)
392457
393458 def restart_orchestration (self , instance_id : str , * ,
394459 restart_with_new_instance_id : bool = False ) -> str :
@@ -407,13 +472,13 @@ def restart_orchestration(self, instance_id: str, *,
407472 restartWithNewInstanceId = restart_with_new_instance_id )
408473
409474 self ._logger .info (f"Restarting instance '{ instance_id } '." )
410- res : pb .RestartInstanceResponse = self ._stub . RestartInstance ( req )
475+ res : pb .RestartInstanceResponse = self ._invoke_unary ( "RestartInstance" , req )
411476 return res .instanceId
412477
def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
    """Send a ``PurgeInstances`` request for a single orchestration instance.

    Args:
        instance_id: ID of the orchestration instance to purge.
        recursive: Forwarded as the ``recursive`` field of the request.

    Returns:
        A ``PurgeInstancesResult`` built from the response's
        ``deletedInstanceCount`` and ``isComplete.value``.
    """
    purge_request = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
    self._logger.info(f"Purging instance '{instance_id}'.")
    purge_response: pb.PurgeInstancesResponse = self._invoke_unary("PurgeInstances", purge_request)
    deleted_count = purge_response.deletedInstanceCount
    is_complete = purge_response.isComplete.value
    return PurgeInstancesResult(deleted_count, is_complete)
418483
419484 def purge_orchestrations_by (self ,
@@ -427,7 +492,7 @@ def purge_orchestrations_by(self,
427492 f"runtime_status={ [str (status ) for status in runtime_status ] if runtime_status else None } , "
428493 f"recursive={ recursive } " )
429494 req = build_purge_by_filter_req (created_time_from , created_time_to , runtime_status , recursive )
430- resp : pb .PurgeInstancesResponse = self ._stub . PurgeInstances ( req )
495+ resp : pb .PurgeInstancesResponse = self ._invoke_unary ( "PurgeInstances" , req )
431496 return PurgeInstancesResult (resp .deletedInstanceCount , resp .isComplete .value )
432497
433498 def signal_entity (self ,
@@ -440,15 +505,15 @@ def signal_entity(self,
440505 payload_helpers .externalize_payloads (
441506 req , self ._payload_store , instance_id = str (entity_instance_id ),
442507 )
443- self ._stub . SignalEntity ( req , None ) # TODO: Cancellation timeout?
508+ self ._invoke_unary ( "SignalEntity" , req ) # TODO: Cancellation timeout?
444509
445510 def get_entity (self ,
446511 entity_instance_id : EntityInstanceId ,
447512 include_state : bool = True
448513 ) -> Optional [EntityMetadata ]:
449514 req = pb .GetEntityRequest (instanceId = str (entity_instance_id ), includeState = include_state )
450515 self ._logger .info (f"Getting entity '{ entity_instance_id } '." )
451- res : pb .GetEntityResponse = self ._stub . GetEntity ( req )
516+ res : pb .GetEntityResponse = self ._invoke_unary ( "GetEntity" , req )
452517 if not res .exists :
453518 return None
454519 if self ._payload_store is not None :
@@ -467,7 +532,7 @@ def get_all_entities(self,
467532
468533 while True :
469534 query_request = build_query_entities_req (entity_query , _continuation_token )
470- resp : pb .QueryEntitiesResponse = self ._stub . QueryEntities ( query_request )
535+ resp : pb .QueryEntitiesResponse = self ._invoke_unary ( "QueryEntities" , query_request )
471536 if self ._payload_store is not None :
472537 payload_helpers .deexternalize_payloads (resp , self ._payload_store )
473538 entities += [EntityMetadata .from_entity_metadata (entity , query_request .query .includeState ) for entity in resp .entities ]
@@ -493,7 +558,7 @@ def clean_entity_storage(self,
493558 releaseOrphanedLocks = release_orphaned_locks ,
494559 continuationToken = _continuation_token
495560 )
496- resp : pb .CleanEntityStorageResponse = self ._stub . CleanEntityStorage ( req )
561+ resp : pb .CleanEntityStorageResponse = self ._invoke_unary ( "CleanEntityStorage" , req )
497562 empty_entities_removed += resp .emptyEntitiesRemoved
498563 orphaned_locks_released += resp .orphanedLocksReleased
499564
0 commit comments