Skip to content

Commit 9d1dbe8

Browse files
Add Python worker interceptors
Issue: zorporation/durable-workflow#451 Loop-ID: build-03
1 parent 10f4995 commit 9d1dbe8

5 files changed

Lines changed: 411 additions & 2 deletions

File tree

README.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ advertise support for the target workflow type.
141141
- **HTTP/JSON protocol**: No gRPC, no protobuf dependencies
142142
- **Codec envelopes**: Avro payloads by default, with JSON decode compatibility for existing history
143143
- **Payload-size warnings**: Structured warnings before oversized workflow, activity, signal, update, query, or search-attribute payloads reach the server
144+
- **Worker interceptors**: Typed hooks around workflow tasks, activity calls, and query tasks for tracing, logging, and custom metrics
144145
- **Metrics hooks**: Pluggable counters and histograms, with an optional Prometheus adapter
145146

146147
## Payload-size warnings
@@ -207,6 +208,44 @@ client = Client("http://server:8080", token="dev-token-123", metrics=metrics)
207208

208209
Custom recorders implement `increment(name, value=1.0, tags=None)` and `record(name, value, tags=None)`.
209210

211+
## Worker interceptors
212+
213+
Use `Worker(interceptors=[...])` when instrumentation needs the task payload,
214+
result, or exception around worker execution instead of only aggregate counters.
215+
Interceptors run in list order; the first interceptor is the outermost wrapper, and each interceptor should `await next(context)` to continue the chain.
216+
217+
```python
218+
from durable_workflow import (
219+
ActivityInterceptorContext,
220+
ActivityHandler,
221+
PassthroughWorkerInterceptor,
222+
Worker,
223+
)
224+
225+
class LoggingInterceptor(PassthroughWorkerInterceptor):
    """Print a line when an activity starts, fails, or completes."""

    async def execute_activity(
        self,
        context: ActivityInterceptorContext,
        next: ActivityHandler,
    ) -> object:
        activity = context.activity_type
        print("activity started", activity)
        try:
            result = await next(context)
        except Exception:
            # Report the failure, then let the worker see the original error.
            print("activity failed", activity)
            raise
        else:
            print("activity completed", activity)
            return result
239+
240+
worker = Worker(
241+
client,
242+
task_queue="python-workers",
243+
workflows=[GreeterWorkflow],
244+
activities=[greet],
245+
interceptors=[LoggingInterceptor()],
246+
)
247+
```
248+
210249
## Documentation
211250

212251
Full documentation is available at

src/durable_workflow/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@
4040
WorkflowNotFound,
4141
WorkflowTerminated,
4242
)
43+
from .interceptors import (
44+
ActivityHandler,
45+
ActivityInterceptorContext,
46+
PassthroughWorkerInterceptor,
47+
QueryTaskHandler,
48+
QueryTaskInterceptorContext,
49+
WorkerInterceptor,
50+
WorkflowTaskHandler,
51+
WorkflowTaskInterceptorContext,
52+
)
4353
from .metrics import (
4454
InMemoryMetrics,
4555
MetricsRecorder,
@@ -55,7 +65,9 @@
5565
"__version__",
5666
"ActivityCancelled",
5767
"ActivityContext",
68+
"ActivityHandler",
5869
"ActivityInfo",
70+
"ActivityInterceptorContext",
5971
"ActivityRetryPolicy",
6072
"ChildWorkflowRetryPolicy",
6173
"ChildWorkflowFailed",
@@ -73,7 +85,10 @@
7385
"ScheduleTriggerResult",
7486
"StartChildWorkflow",
7587
"Worker",
88+
"WorkerInterceptor",
7689
"WorkflowExecution",
90+
"WorkflowTaskHandler",
91+
"WorkflowTaskInterceptorContext",
7792
"WorkflowHandle",
7893
"WorkflowList",
7994
"activity",
@@ -90,6 +105,9 @@
90105
"PrometheusMetrics",
91106
"PayloadSizeWarningConfig",
92107
"PayloadSizeWarningContext",
108+
"PassthroughWorkerInterceptor",
109+
"QueryTaskHandler",
110+
"QueryTaskInterceptorContext",
93111
"RetryPolicy",
94112
"ServerError",
95113
"TransportRetryPolicy",
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Worker interceptor protocols for SDK instrumentation hooks."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Awaitable, Callable
6+
from dataclasses import dataclass
7+
from typing import Any, Protocol
8+
9+
10+
@dataclass(frozen=True)
class WorkflowTaskInterceptorContext:
    """Read-only bundle handed to ``execute_workflow_task`` interceptor hooks.

    The dataclass is frozen, so attributes cannot be rebound after
    construction. The ``task`` dict is stored by reference, not copied.
    """

    # Identifier of the worker that is executing the task.
    worker_id: str
    # Task queue the worker is serving.
    task_queue: str
    # Workflow task dict that is forwarded to the worker's core handler.
    task: dict[str, Any]
17+
18+
19+
@dataclass(frozen=True)
class ActivityInterceptorContext:
    """Read-only bundle handed to ``execute_activity`` interceptor hooks.

    The dataclass is frozen, so attributes cannot be rebound after
    construction. The ``task`` dict is stored by reference, not copied.
    """

    # Identifier of the worker that is executing the activity.
    worker_id: str
    # Task queue the worker is serving.
    task_queue: str
    # Activity task dict the call belongs to.
    task: dict[str, Any]
    # Name of the activity being invoked.
    activity_type: str
    # Positional arguments the activity callable is invoked with.
    args: tuple[Any, ...]
28+
29+
30+
@dataclass(frozen=True)
class QueryTaskInterceptorContext:
    """Read-only bundle handed to ``execute_query_task`` interceptor hooks.

    The dataclass is frozen, so attributes cannot be rebound after
    construction. The ``task`` dict is stored by reference, not copied.
    """

    # Identifier of the worker that is answering the query.
    worker_id: str
    # Task queue the worker is serving.
    task_queue: str
    # Query task dict that is forwarded to the worker's core handler.
    task: dict[str, Any]
37+
38+
39+
# Signatures of the ``next`` callables handed to interceptor hooks: each takes
# the hook's context object and returns an awaitable of the hook's result type.
WorkflowTaskHandler = Callable[
    [WorkflowTaskInterceptorContext],
    Awaitable[list[dict[str, Any]] | None],
]
ActivityHandler = Callable[[ActivityInterceptorContext], Awaitable[Any]]
QueryTaskHandler = Callable[[QueryTaskInterceptorContext], Awaitable[str]]
42+
43+
44+
class WorkerInterceptor(Protocol):
    """Structural interface for objects that wrap worker task execution.

    A worker applies its interceptors in the order they were given to
    ``Worker(..., interceptors=[...])``: the first one is the outermost
    wrapper. Each hook receives a context object plus a ``next`` callable and
    should await ``next`` to hand control to the rest of the chain.
    """

    async def execute_workflow_task(
        self,
        context: WorkflowTaskInterceptorContext,
        next: WorkflowTaskHandler,
    ) -> list[dict[str, Any]] | None:
        """Run around the execution of one workflow task."""

    async def execute_activity(
        self,
        context: ActivityInterceptorContext,
        next: ActivityHandler,
    ) -> Any:
        """Run around the call to the registered activity callable."""

    async def execute_query_task(
        self,
        context: QueryTaskInterceptorContext,
        next: QueryTaskHandler,
    ) -> str:
        """Run around the execution of one query task."""
75+
76+
77+
class PassthroughWorkerInterceptor:
    """Interceptor whose hooks all forward straight to ``next``.

    Intended as a base class: override only the hooks you need and the
    remaining ones keep delegating unchanged.
    """

    async def execute_workflow_task(
        self,
        context: WorkflowTaskInterceptorContext,
        next: WorkflowTaskHandler,
    ) -> list[dict[str, Any]] | None:
        """Forward the workflow task to the rest of the chain untouched."""
        outcome = await next(context)
        return outcome

    async def execute_activity(
        self,
        context: ActivityInterceptorContext,
        next: ActivityHandler,
    ) -> Any:
        """Forward the activity call to the rest of the chain untouched."""
        outcome = await next(context)
        return outcome

    async def execute_query_task(
        self,
        context: QueryTaskInterceptorContext,
        next: QueryTaskHandler,
    ) -> str:
        """Forward the query task to the rest of the chain untouched."""
        outcome = await next(context)
        return outcome

src/durable_workflow/worker.py

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import time
2222
import traceback
2323
import uuid
24-
from collections.abc import Callable, Iterable
24+
from collections.abc import Awaitable, Callable, Iterable
2525
from typing import Any
2626

2727
from . import serializer
@@ -35,6 +35,12 @@
3535
WorkflowExecution,
3636
)
3737
from .errors import ActivityCancelled, AvroNotInstalledError, NonRetryableError, QueryFailed
38+
from .interceptors import (
39+
ActivityInterceptorContext,
40+
QueryTaskInterceptorContext,
41+
WorkerInterceptor,
42+
WorkflowTaskInterceptorContext,
43+
)
3844
from .metrics import (
3945
NOOP_METRICS,
4046
WORKER_POLL_DURATION_SECONDS,
@@ -151,6 +157,7 @@ def __init__(
151157
max_concurrent_activity_tasks: int = 10,
152158
shutdown_timeout: float = 30.0,
153159
metrics: MetricsRecorder | None = None,
160+
interceptors: Iterable[WorkerInterceptor] = (),
154161
) -> None:
155162
self.client = client
156163
self.task_queue = task_queue
@@ -166,6 +173,7 @@ def __init__(
166173
self._query_tasks_supported = False
167174
configured_metrics = metrics if metrics is not None else getattr(client, "metrics", NOOP_METRICS)
168175
self.metrics: MetricsRecorder = configured_metrics or NOOP_METRICS
176+
self.interceptors = tuple(interceptors)
169177

170178
def _record_poll_metrics(self, task_kind: str, outcome: str, duration: float) -> None:
171179
tags = {"task_kind": task_kind, "task_queue": self.task_queue, "outcome": outcome}
@@ -201,6 +209,35 @@ async def _register(self) -> None:
201209
log.info("worker %s registered on %s", self.worker_id, self.task_queue)
202210

203211
async def _run_workflow_task(self, task: dict[str, Any]) -> list[dict[str, Any]] | None:
    """Run one workflow task through the configured interceptors.

    Interceptors are entered in list order (index 0 is outermost). Once every
    interceptor has called ``next``, the task on the context that reached the
    end of the chain is handed to ``_run_workflow_task_core``.
    """
    layers = self.interceptors
    initial = WorkflowTaskInterceptorContext(
        worker_id=self.worker_id,
        task_queue=self.task_queue,
        task=task,
    )

    async def dispatch(depth: int, ctx: WorkflowTaskInterceptorContext) -> list[dict[str, Any]] | None:
        # Past the last interceptor: run the real workflow-task logic.
        if depth >= len(layers):
            return await self._run_workflow_task_core(ctx.task)

        async def proceed(inner_ctx: WorkflowTaskInterceptorContext) -> list[dict[str, Any]] | None:
            # ``next`` for this layer simply moves one step deeper into the chain.
            return await dispatch(depth + 1, inner_ctx)

        return await layers[depth].execute_workflow_task(ctx, proceed)

    return await dispatch(0, initial)
239+
240+
async def _run_workflow_task_core(self, task: dict[str, Any]) -> list[dict[str, Any]] | None:
204241
import json as _json
205242

206243
log.debug("workflow task payload: %s", _json.dumps(task, default=str)[:2000])
@@ -446,7 +483,7 @@ async def _run_activity_task(self, task: dict[str, Any]) -> str:
446483
)
447484
_set_context(act_ctx)
448485
try:
449-
result = fn(*args) if not asyncio.iscoroutinefunction(fn) else await fn(*args)
486+
result = await self._execute_activity_callable(task, activity_type, tuple(args), fn)
450487
except ActivityCancelled:
451488
log.info("activity %s cancelled via heartbeat", task_id)
452489
try:
@@ -508,7 +545,70 @@ async def _run_activity_task(self, task: dict[str, Any]) -> str:
508545
return "complete_error"
509546
return "completed"
510547

548+
async def _execute_activity_callable(
    self,
    task: dict[str, Any],
    activity_type: str,
    args: tuple[Any, ...],
    fn: Callable[..., Any],
) -> Any:
    """Invoke an activity callable through the configured interceptor chain.

    Args:
        task: Activity task dict, exposed to interceptors as ``context.task``.
        activity_type: Activity name, exposed as ``context.activity_type``.
        args: Positional arguments for the activity. The callable is invoked
            with the ``args`` of whichever context reaches the end of the
            chain, so an interceptor may forward a context with other args.
        fn: The activity callable; it may be synchronous or return an
            awaitable.

    Returns:
        The value returned by the outermost interceptor, which is the
        activity's result when every interceptor forwards it unchanged.

    Raises:
        Exception: Anything raised by the activity or by an interceptor
            propagates to the caller.
    """
    # Function-scope import so this block does not depend on the module's import list.
    import inspect

    context = ActivityInterceptorContext(
        worker_id=self.worker_id,
        task_queue=self.task_queue,
        task=task,
        activity_type=activity_type,
        args=args,
    )

    async def call_activity(ctx: ActivityInterceptorContext) -> Any:
        result = fn(*ctx.args)
        # ``inspect.isawaitable`` covers coroutines, Futures/Tasks and objects
        # with ``__await__``. Unlike ``asyncio.iscoroutine`` on Python <= 3.11,
        # it does not treat a plain generator as awaitable.
        if inspect.isawaitable(result):
            return await result
        return result

    handler = call_activity
    for interceptor in reversed(self.interceptors):
        next_handler = handler

        # Keyword defaults pin this iteration's interceptor and downstream
        # handler; a plain closure would see only the final loop values.
        async def call_interceptor(
            ctx: ActivityInterceptorContext,
            *,
            interceptor: WorkerInterceptor = interceptor,
            next_handler: Callable[[ActivityInterceptorContext], Awaitable[Any]] = next_handler,
        ) -> Any:
            return await interceptor.execute_activity(ctx, next_handler)

        handler = call_interceptor

    return await handler(context)
584+
511585
async def _run_query_task(self, task: dict[str, Any]) -> str:
    """Run one query task through the configured interceptors.

    Interceptors are entered in list order (index 0 is outermost). Once every
    interceptor has called ``next``, the task on the context that reached the
    end of the chain is handed to ``_run_query_task_core``.
    """
    layers = self.interceptors
    initial = QueryTaskInterceptorContext(
        worker_id=self.worker_id,
        task_queue=self.task_queue,
        task=task,
    )

    async def dispatch(depth: int, ctx: QueryTaskInterceptorContext) -> str:
        # Past the last interceptor: run the real query-task logic.
        if depth >= len(layers):
            return await self._run_query_task_core(ctx.task)

        async def proceed(inner_ctx: QueryTaskInterceptorContext) -> str:
            # ``next`` for this layer simply moves one step deeper into the chain.
            return await dispatch(depth + 1, inner_ctx)

        return await layers[depth].execute_query_task(ctx, proceed)

    return await dispatch(0, initial)
610+
611+
async def _run_query_task_core(self, task: dict[str, Any]) -> str:
512612
query_task_id: str = task["query_task_id"]
513613
attempt: int = task.get("query_task_attempt", 1)
514614
wf_type: str = task.get("workflow_type", "")

0 commit comments

Comments
 (0)