2121import time
2222import traceback
2323import uuid
24- from collections .abc import Callable , Iterable
24+ from collections .abc import Awaitable , Callable , Iterable
2525from typing import Any
2626
2727from . import serializer
3535 WorkflowExecution ,
3636)
3737from .errors import ActivityCancelled , AvroNotInstalledError , NonRetryableError , QueryFailed
38+ from .interceptors import (
39+ ActivityInterceptorContext ,
40+ QueryTaskInterceptorContext ,
41+ WorkerInterceptor ,
42+ WorkflowTaskInterceptorContext ,
43+ )
3844from .metrics import (
3945 NOOP_METRICS ,
4046 WORKER_POLL_DURATION_SECONDS ,
@@ -151,6 +157,7 @@ def __init__(
151157 max_concurrent_activity_tasks : int = 10 ,
152158 shutdown_timeout : float = 30.0 ,
153159 metrics : MetricsRecorder | None = None ,
160+ interceptors : Iterable [WorkerInterceptor ] = (),
154161 ) -> None :
155162 self .client = client
156163 self .task_queue = task_queue
@@ -166,6 +173,7 @@ def __init__(
166173 self ._query_tasks_supported = False
167174 configured_metrics = metrics if metrics is not None else getattr (client , "metrics" , NOOP_METRICS )
168175 self .metrics : MetricsRecorder = configured_metrics or NOOP_METRICS
176+ self .interceptors = tuple (interceptors )
169177
170178 def _record_poll_metrics (self , task_kind : str , outcome : str , duration : float ) -> None :
171179 tags = {"task_kind" : task_kind , "task_queue" : self .task_queue , "outcome" : outcome }
@@ -201,6 +209,35 @@ async def _register(self) -> None:
201209 log .info ("worker %s registered on %s" , self .worker_id , self .task_queue )
202210
203211 async def _run_workflow_task (self , task : dict [str , Any ]) -> list [dict [str , Any ]] | None :
212+ context = WorkflowTaskInterceptorContext (
213+ worker_id = self .worker_id ,
214+ task_queue = self .task_queue ,
215+ task = task ,
216+ )
217+
218+ async def call_core (ctx : WorkflowTaskInterceptorContext ) -> list [dict [str , Any ]] | None :
219+ return await self ._run_workflow_task_core (ctx .task )
220+
221+ handler = call_core
222+ for interceptor in reversed (self .interceptors ):
223+ next_handler = handler
224+
225+ async def call_interceptor (
226+ ctx : WorkflowTaskInterceptorContext ,
227+ * ,
228+ interceptor : WorkerInterceptor = interceptor ,
229+ next_handler : Callable [
230+ [WorkflowTaskInterceptorContext ],
231+ Awaitable [list [dict [str , Any ]] | None ],
232+ ] = next_handler ,
233+ ) -> list [dict [str , Any ]] | None :
234+ return await interceptor .execute_workflow_task (ctx , next_handler )
235+
236+ handler = call_interceptor
237+
238+ return await handler (context )
239+
240+ async def _run_workflow_task_core (self , task : dict [str , Any ]) -> list [dict [str , Any ]] | None :
204241 import json as _json
205242
206243 log .debug ("workflow task payload: %s" , _json .dumps (task , default = str )[:2000 ])
@@ -446,7 +483,7 @@ async def _run_activity_task(self, task: dict[str, Any]) -> str:
446483 )
447484 _set_context (act_ctx )
448485 try :
449- result = fn ( * args ) if not asyncio . iscoroutinefunction ( fn ) else await fn ( * args )
486+ result = await self . _execute_activity_callable ( task , activity_type , tuple ( args ), fn )
450487 except ActivityCancelled :
451488 log .info ("activity %s cancelled via heartbeat" , task_id )
452489 try :
@@ -508,7 +545,70 @@ async def _run_activity_task(self, task: dict[str, Any]) -> str:
508545 return "complete_error"
509546 return "completed"
510547
548+ async def _execute_activity_callable (
549+ self ,
550+ task : dict [str , Any ],
551+ activity_type : str ,
552+ args : tuple [Any , ...],
553+ fn : Callable [..., Any ],
554+ ) -> Any :
555+ context = ActivityInterceptorContext (
556+ worker_id = self .worker_id ,
557+ task_queue = self .task_queue ,
558+ task = task ,
559+ activity_type = activity_type ,
560+ args = args ,
561+ )
562+
563+ async def call_activity (ctx : ActivityInterceptorContext ) -> Any :
564+ result = fn (* ctx .args )
565+ if asyncio .iscoroutine (result ):
566+ return await result
567+ return result
568+
569+ handler = call_activity
570+ for interceptor in reversed (self .interceptors ):
571+ next_handler = handler
572+
573+ async def call_interceptor (
574+ ctx : ActivityInterceptorContext ,
575+ * ,
576+ interceptor : WorkerInterceptor = interceptor ,
577+ next_handler : Callable [[ActivityInterceptorContext ], Awaitable [Any ]] = next_handler ,
578+ ) -> Any :
579+ return await interceptor .execute_activity (ctx , next_handler )
580+
581+ handler = call_interceptor
582+
583+ return await handler (context )
584+
511585 async def _run_query_task (self , task : dict [str , Any ]) -> str :
586+ context = QueryTaskInterceptorContext (
587+ worker_id = self .worker_id ,
588+ task_queue = self .task_queue ,
589+ task = task ,
590+ )
591+
592+ async def call_core (ctx : QueryTaskInterceptorContext ) -> str :
593+ return await self ._run_query_task_core (ctx .task )
594+
595+ handler = call_core
596+ for interceptor in reversed (self .interceptors ):
597+ next_handler = handler
598+
599+ async def call_interceptor (
600+ ctx : QueryTaskInterceptorContext ,
601+ * ,
602+ interceptor : WorkerInterceptor = interceptor ,
603+ next_handler : Callable [[QueryTaskInterceptorContext ], Awaitable [str ]] = next_handler ,
604+ ) -> str :
605+ return await interceptor .execute_query_task (ctx , next_handler )
606+
607+ handler = call_interceptor
608+
609+ return await handler (context )
610+
611+ async def _run_query_task_core (self , task : dict [str , Any ]) -> str :
512612 query_task_id : str = task ["query_task_id" ]
513613 attempt : int = task .get ("query_task_attempt" , 1 )
514614 wf_type : str = task .get ("workflow_type" , "" )
0 commit comments