11import asyncio
22import grpc
3- from threading import Event
3+ from threading import Event , Timer
44from unittest .mock import MagicMock
55
66import pytest
77
88from durabletask .grpc_options import GrpcWorkerResiliencyOptions
99from durabletask .internal import orchestrator_service_pb2 as pb
10- from durabletask .worker import TaskHubGrpcWorker , _WorkItemStreamOutcome
10+ from durabletask .worker import (
11+ _AsyncWorkerManager ,
12+ ConcurrencyOptions ,
13+ TaskHubGrpcWorker ,
14+ _WorkItemStreamOutcome ,
15+ )
1116
1217
1318class FakeRpcError (grpc .RpcError ):
@@ -62,6 +67,9 @@ def __init__(self):
6267 self ._shutdown_event = asyncio .Event ()
6368 self .submissions : list [tuple [str , tuple ]] = []
6469
def prepare_for_run(self):
    """Install a fresh, unset ``asyncio.Event`` as the shutdown event.

    NOTE(review): presumably called before each ``run()`` so that a
    previously set shutdown event does not end the next run immediately --
    confirm against the caller.
    """
    fresh_event = asyncio.Event()
    self._shutdown_event = fresh_event
72+
async def run(self):
    """Block until the instance's shutdown event has been set."""
    shutdown_signal = self._shutdown_event
    await shutdown_signal.wait()
6775
@@ -88,16 +96,58 @@ def _complete_activity_request(req, stub, completion_token):
8896 )
8997
9098
def _make_activity_work_item(
    task_id: int = 1,
    completion_token: str = "token",
    instance_id: str = "instance-id",
) -> pb.WorkItem:
    """Build a work item that wraps a ``test_activity`` activity request.

    Args:
        task_id: Value stored in the activity request's ``taskId`` field.
        completion_token: Value stored in the work item's ``completionToken``.
        instance_id: Instance id placed on the request's orchestration instance.

    Returns:
        A ``pb.WorkItem`` whose ``activityRequest`` field is populated.
    """
    orchestration = pb.OrchestrationInstance(instanceId=instance_id)
    request = pb.ActivityRequest(
        name="test_activity",
        taskId=task_id,
        orchestrationInstance=orchestration,
    )
    return pb.WorkItem(activityRequest=request, completionToken=completion_token)
112+
113+
114+ async def _wait_for_condition (predicate , * , timeout : float = 2.0 ):
115+ loop = asyncio .get_running_loop ()
116+ deadline = loop .time () + timeout
117+ while not predicate ():
118+ if loop .time () >= deadline :
119+ raise AssertionError ("condition was not met before timeout" )
120+ await asyncio .sleep (0.01 )
121+
122+
@pytest.mark.asyncio
async def test_async_worker_manager_honors_shutdown_requested_before_run():
    """``run()`` must finish within the time bound when shutdown came first."""
    options = ConcurrencyOptions(maximum_thread_pool_workers=1)
    manager = _AsyncWorkerManager(options, MagicMock())

    # Shutdown is requested before the manager has ever been run.
    manager.shutdown()

    # wait_for raises a timeout error if run() does not return within 1 second.
    await asyncio.wait_for(manager.run(), timeout=1.0)
132+
133+
def test_worker_start_clears_prior_shutdown_request():
    """``start()`` must clear a shutdown flag that was set before the call.

    The worker's ``_shutdown`` event is set up front and its run loop is
    replaced by a stub that only records that it was invoked. After
    ``start()`` the stub must have run and ``_shutdown`` must be cleared.
    """
    worker = TaskHubGrpcWorker()
    worker._shutdown.set()
    run_started = Event()

    async def fake_run_loop():
        run_started.set()

    worker._async_run_loop = fake_run_loop
    worker.start()
    try:
        # Give the stubbed run loop a bounded amount of time to finish.
        worker._runLoop.join(timeout=1.0)

        assert run_started.is_set() is True
        assert worker._shutdown.is_set() is False
    finally:
        # Always stop the started worker, even when an assertion above fails,
        # so a failing run does not leave it running for later tests.
        worker.stop()
150+
101151
102152def test_worker_classifies_graceful_close_before_first_message ():
103153 worker = TaskHubGrpcWorker (
@@ -399,6 +449,148 @@ def create_stub(channel):
399449 created_channels [0 ].close .assert_called_once ()
400450
401451
@pytest.mark.asyncio
async def test_worker_shutdown_drains_real_manager_work_before_closing_retired_sdk_channel(monkeypatch):
    """Queued activity work must complete before the first channel is closed.

    The first stub's work-item stream yields two activity work items and the
    second stub's ``GetWorkItems`` raises ``CANCELLED``. Completion of task 1
    is held back for about 0.2 seconds. The test then checks that both tasks
    completed in order, that the first channel had not been closed when
    task 2 completed, and that each created channel was closed exactly once.
    """
    worker = TaskHubGrpcWorker(
        concurrency_options=ConcurrencyOptions(
            maximum_concurrent_activity_work_items=1,
            maximum_thread_pool_workers=1,
        )
    )
    worker._execute_activity = _complete_activity_request
    # Make every wait on the shutdown event return False immediately.
    # NOTE(review): presumably this skips the delay between stream attempts -- confirm.
    monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False)

    created_channels = []

    def get_grpc_channel(*args, **kwargs):
        channel = MagicMock(name=f"channel-{len(created_channels) + 1}")
        created_channels.append(channel)
        return channel

    allow_first_completion = Event()
    first_completion_started = Event()
    completed_task_ids = []
    # close() call counts of the first channel, sampled when task 2 completes.
    close_calls_seen_by_second_completion = []
    first_stub = MagicMock()
    first_stub.GetWorkItems.return_value = FakeResponseStream(items=[
        _make_activity_work_item(task_id=1, completion_token="token-1"),
        _make_activity_work_item(task_id=2, completion_token="token-2"),
    ])

    def complete_activity(response):
        completed_task_ids.append(response.taskId)
        if response.taskId == 1:
            first_completion_started.set()
            # Hold the first completion open briefly so the second work item
            # is still pending while the first one is in flight.
            Timer(0.2, allow_first_completion.set).start()
            assert allow_first_completion.wait(timeout=5.0)
        elif response.taskId == 2:
            # Record instead of asserting here: an exception raised on the
            # worker's execution path may be absorbed by the worker rather
            # than failing the test, so the check is made in the test body.
            close_calls_seen_by_second_completion.append(
                created_channels[0].close.call_count
            )

    first_stub.CompleteActivityTask.side_effect = complete_activity

    second_stub = MagicMock()
    second_stub.GetWorkItems.side_effect = FakeRpcError(
        grpc.StatusCode.CANCELLED,
        "stop",
    )

    stubs = [first_stub, second_stub]
    stub_channels = []

    def create_stub(channel):
        stub_channels.append(channel)
        return stubs.pop(0)

    monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel)
    monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub)

    run_task = asyncio.create_task(worker._async_run_loop())
    await asyncio.wait_for(run_task, timeout=2.0)

    assert first_completion_started.is_set() is True
    assert len(created_channels) == 2
    assert stub_channels == created_channels
    assert completed_task_ids == [1, 2]
    # The first channel must still have been open when task 2 completed.
    assert close_calls_seen_by_second_completion == [0]
    created_channels[0].close.assert_called_once()
    created_channels[1].close.assert_called_once()
515+
516+
@pytest.mark.asyncio
async def test_worker_shutdown_runs_real_manager_cancellation_wrapper_before_closing_retired_sdk_channel(monkeypatch):
    """The cancellation hook must run before the first channel is closed.

    The first stub's work-item stream yields two activity work items and the
    second stub's ``GetWorkItems`` raises ``CANCELLED``. Task 1 completes
    normally after being held back for about 0.2 seconds, and executing
    task 2 raises. The test then checks that only task 1 completed, that the
    cancellation hook ran for task 2 while the first channel was still
    unclosed, and that each created channel was closed exactly once.
    """
    worker = TaskHubGrpcWorker(
        concurrency_options=ConcurrencyOptions(
            maximum_concurrent_activity_work_items=1,
            maximum_thread_pool_workers=1,
        )
    )
    # Make every wait on the shutdown event return False immediately.
    # NOTE(review): presumably this skips the delay between stream attempts -- confirm.
    monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False)

    created_channels = []

    def get_grpc_channel(*args, **kwargs):
        channel = MagicMock(name=f"channel-{len(created_channels) + 1}")
        created_channels.append(channel)
        return channel

    allow_first_completion = Event()
    first_completion_started = Event()
    completed_task_ids = []
    cancelled_task_ids = []
    # close() call counts of the first channel, sampled when the hook runs.
    close_calls_seen_by_cancellation = []

    def execute_activity(req, stub, completion_token):
        if req.taskId == 1:
            _complete_activity_request(req, stub, completion_token)
        else:
            # Fail every other task so the worker's failure handling is exercised.
            raise RuntimeError("boom")

    def cancel_activity(req, stub, completion_token):
        cancelled_task_ids.append(req.taskId)
        # Record instead of asserting here: an exception raised inside the
        # cancellation hook may be absorbed by the worker rather than failing
        # the test, so the check is made in the test body.
        close_calls_seen_by_cancellation.append(created_channels[0].close.call_count)

    worker._execute_activity = execute_activity
    worker._cancel_activity = cancel_activity

    first_stub = MagicMock()
    first_stub.GetWorkItems.return_value = FakeResponseStream(items=[
        _make_activity_work_item(task_id=1, completion_token="token-1"),
        _make_activity_work_item(task_id=2, completion_token="token-2"),
    ])

    def complete_activity(response):
        completed_task_ids.append(response.taskId)
        # Hold the completion open briefly so the second work item is still
        # pending while the first one is in flight.
        Timer(0.2, allow_first_completion.set).start()
        first_completion_started.set()
        assert allow_first_completion.wait(timeout=5.0)

    first_stub.CompleteActivityTask.side_effect = complete_activity

    second_stub = MagicMock()
    second_stub.GetWorkItems.side_effect = FakeRpcError(
        grpc.StatusCode.CANCELLED,
        "stop",
    )

    stubs = [first_stub, second_stub]
    stub_channels = []

    def create_stub(channel):
        stub_channels.append(channel)
        return stubs.pop(0)

    monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel)
    monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub)

    run_task = asyncio.create_task(worker._async_run_loop())
    await asyncio.wait_for(run_task, timeout=2.0)

    assert first_completion_started.is_set() is True
    assert len(created_channels) == 2
    assert stub_channels == created_channels
    assert completed_task_ids == [1]
    assert cancelled_task_ids == [2]
    # The first channel must still have been open when the hook ran.
    assert close_calls_seen_by_cancellation == [0]
    created_channels[0].close.assert_called_once()
    created_channels[1].close.assert_called_once()
592+
593+
402594@pytest .mark .asyncio
403595async def test_worker_never_closes_caller_owned_channel_after_graceful_reset (monkeypatch ):
404596 provided_channel = MagicMock (name = "provided-channel" )
0 commit comments