Skip to content

Commit 9c539e8

Browse files
Add explicit Avro payload adapters
Issue: zorporation/durable-workflow#538 Loop-ID: build-02
1 parent 4a8bf86 commit 9c539e8

4 files changed

Lines changed: 242 additions & 2 deletions

File tree

README.md

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,53 @@ scalars before passing them to the SDK. `IntEnum` and `StrEnum` encode because
198198
they are JSON scalar subclasses, but they decode as `int` and `str`.
199199
`OrderedDict` decodes as a plain `dict`.
200200

201+
Use `to_avro_payload_value(...)` when a rich value should enter durable
202+
history through the default Avro envelope:
203+
204+
```python
205+
from dataclasses import dataclass
206+
from datetime import datetime, timezone
207+
from decimal import Decimal
208+
from enum import Enum
209+
from uuid import UUID
210+
211+
from durable_workflow import Client, to_avro_payload_value
212+
213+
214+
class OrderStatus(Enum):
215+
PENDING = "pending"
216+
217+
218+
@dataclass
219+
class OrderInput:
220+
order_id: UUID
221+
placed_at: datetime
222+
amount: Decimal
223+
status: OrderStatus
224+
225+
226+
order = OrderInput(
227+
order_id=UUID("12345678-1234-5678-1234-567812345678"),
228+
placed_at=datetime.now(timezone.utc),
229+
amount=Decimal("10.25"),
230+
status=OrderStatus.PENDING,
231+
)
232+
233+
client = Client("http://server:8080", token="dev-token-123")
234+
await client.start_workflow(
235+
"order-workflow",
236+
task_queue="orders",
237+
workflow_id="order-123",
238+
input=[to_avro_payload_value(order)],
239+
)
240+
```
241+
242+
The helper also accepts pydantic-style models with `model_dump(mode="json")`
243+
and attrs-style classes. Rebuild domain objects explicitly inside workflows or
244+
activities, for example `OrderInput(order_id=UUID(data["order_id"]), ...)`.
245+
Adapter output is part of the durable history contract, so changing that shape
246+
is a workflow compatibility change.
247+
201248
## Authentication
202249

203250
For local servers that use one shared bearer token, pass `token=`:

src/durable_workflow/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@
5858
PrometheusMetrics,
5959
)
6060
from .retry_policy import RetryPolicy, TransportRetryPolicy
61-
from .serializer import PayloadSizeWarningConfig, PayloadSizeWarningContext
61+
from .serializer import (
62+
PayloadSizeWarningConfig,
63+
PayloadSizeWarningContext,
64+
to_avro_payload_value,
65+
to_avro_payload_values,
66+
)
6267
from .worker import Worker
6368
from .workflow import ActivityRetryPolicy, ChildWorkflowRetryPolicy, ContinueAsNew, StartChildWorkflow
6469

@@ -120,4 +125,6 @@
120125
"WorkflowFailed",
121126
"WorkflowNotFound",
122127
"WorkflowTerminated",
128+
"to_avro_payload_value",
129+
"to_avro_payload_values",
123130
]

src/durable_workflow/serializer.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121
import json
2222
import logging
2323
from collections.abc import Mapping, Sequence
24-
from dataclasses import dataclass
24+
from dataclasses import dataclass, fields, is_dataclass
25+
from datetime import date, datetime, time
26+
from decimal import Decimal
27+
from enum import Enum
2528
from typing import Any, TypeGuard, cast
29+
from uuid import UUID
2630

2731
from . import _avro
2832

@@ -91,6 +95,94 @@ def to_log_context(self) -> dict[str, str]:
9195
PayloadWarningContexts = PayloadWarningContext | Sequence[PayloadWarningContext]
9296

9397

98+
def to_avro_payload_value(value: Any) -> Any:
99+
"""Convert common rich Python values to JSON-native Avro wrapper values.
100+
101+
The SDK's default Avro codec is a language-neutral envelope around a JSON
102+
document. This helper is the explicit boundary for class-carrying values:
103+
callers opt in before encode and the returned value becomes part of durable
104+
history.
105+
"""
106+
if isinstance(value, Enum):
107+
return to_avro_payload_value(value.value)
108+
109+
if value is None or isinstance(value, str | int | float | bool):
110+
return value
111+
112+
if isinstance(value, datetime):
113+
return value.isoformat()
114+
115+
if isinstance(value, date | time):
116+
return value.isoformat()
117+
118+
if isinstance(value, UUID | Decimal):
119+
return str(value)
120+
121+
if isinstance(value, Mapping):
122+
converted: dict[str, Any] = {}
123+
for key, item in value.items():
124+
if not isinstance(key, str):
125+
raise TypeError("Avro JSON payload dictionaries must use string keys after adaptation")
126+
converted[key] = to_avro_payload_value(item)
127+
return converted
128+
129+
if is_dataclass(value) and not isinstance(value, type):
130+
return {
131+
field.name: to_avro_payload_value(getattr(value, field.name))
132+
for field in fields(value)
133+
}
134+
135+
pydantic_dump = _pydantic_model_dump(value)
136+
if pydantic_dump is not None:
137+
return to_avro_payload_value(pydantic_dump)
138+
139+
attrs_dump = _attrs_payload_dict(value)
140+
if attrs_dump is not None:
141+
return attrs_dump
142+
143+
if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray):
144+
return [to_avro_payload_value(item) for item in value]
145+
146+
raise TypeError(
147+
f"Object of type {type(value).__name__} is not Avro JSON payload safe; "
148+
"adapt it to None, bool, int, float, str, list, or dict[str, value] first."
149+
)
150+
151+
152+
def to_avro_payload_values(values: Sequence[Any]) -> list[Any]:
153+
"""Convert several values with :func:`to_avro_payload_value`."""
154+
return [to_avro_payload_value(value) for value in values]
155+
156+
157+
def _pydantic_model_dump(value: Any) -> Any | None:
158+
if not (
159+
hasattr(value, "__pydantic_fields__")
160+
or hasattr(value, "__fields__")
161+
or value.__class__.__module__.startswith("pydantic.")
162+
):
163+
return None
164+
165+
model_dump = getattr(value, "model_dump", None)
166+
if callable(model_dump):
167+
return model_dump(mode="json")
168+
169+
dict_dump = getattr(value, "dict", None)
170+
if callable(dict_dump):
171+
return dict_dump()
172+
173+
return None
174+
175+
176+
def _attrs_payload_dict(value: Any) -> dict[str, Any] | None:
177+
attrs_fields = getattr(value.__class__, "__attrs_attrs__", None)
178+
if attrs_fields is None:
179+
return None
180+
return {
181+
field.name: to_avro_payload_value(getattr(value, field.name))
182+
for field in attrs_fields
183+
}
184+
185+
94186
def encode(
95187
value: Any,
96188
codec: str = AVRO_CODEC,

tests/test_serializer.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ class SerializerDataclass:
3434
count: int
3535

3636

37+
@dataclass
38+
class SerializerOrder:
39+
order_id: UUID
40+
placed_at: datetime
41+
amount: Decimal
42+
status: "SerializerEnum"
43+
44+
3745
class SerializerEnum(Enum):
3846
PENDING = "pending"
3947

@@ -50,6 +58,32 @@ class SerializerStrEnum(StrEnum):
5058
SerializerStrEnum = None
5159

5260

61+
class _AttrsField:
62+
def __init__(self, name: str) -> None:
63+
self.name = name
64+
65+
66+
class SerializerAttrsStyle:
67+
__attrs_attrs__ = (_AttrsField("sku"), _AttrsField("quantity"))
68+
69+
def __init__(self, sku: str, quantity: int) -> None:
70+
self.sku = sku
71+
self.quantity = quantity
72+
73+
74+
class SerializerPydanticStyle:
75+
__pydantic_fields__ = {"order_id": object()}
76+
77+
def __init__(self, order_id: UUID, due_on: date) -> None:
78+
self.order_id = order_id
79+
self.due_on = due_on
80+
81+
def model_dump(self, *, mode: str = "python") -> dict[str, object]:
82+
if mode == "json":
83+
return {"order_id": str(self.order_id), "due_on": self.due_on.isoformat()}
84+
return {"order_id": self.order_id, "due_on": self.due_on}
85+
86+
5387
class TestEncode:
5488
def test_list(self) -> None:
5589
assert serializer.encode(["a", 1, True], codec="json") == '["a",1,true]'
@@ -166,6 +200,66 @@ def test_encode_many_rejects_context_count_mismatch(self) -> None:
166200
)
167201

168202

203+
class TestAvroPayloadAdapter:
204+
def test_adapts_dataclass_datetime_uuid_decimal_and_enum(self) -> None:
205+
order = SerializerOrder(
206+
order_id=UUID("12345678-1234-5678-1234-567812345678"),
207+
placed_at=datetime(2026, 4, 21, 10, 30, tzinfo=timezone.utc),
208+
amount=Decimal("10.25"),
209+
status=SerializerEnum.PENDING,
210+
)
211+
212+
assert serializer.to_avro_payload_value(order) == {
213+
"order_id": "12345678-1234-5678-1234-567812345678",
214+
"placed_at": "2026-04-21T10:30:00+00:00",
215+
"amount": "10.25",
216+
"status": "pending",
217+
}
218+
219+
def test_adapts_pydantic_style_models_through_json_mode_dump(self) -> None:
220+
model = SerializerPydanticStyle(
221+
UUID("12345678-1234-5678-1234-567812345678"),
222+
date(2026, 4, 21),
223+
)
224+
225+
assert serializer.to_avro_payload_value(model) == {
226+
"order_id": "12345678-1234-5678-1234-567812345678",
227+
"due_on": "2026-04-21",
228+
}
229+
230+
def test_adapts_attrs_style_objects_and_sequences(self) -> None:
231+
value = (SerializerAttrsStyle("ABC", 2), time(10, 30, tzinfo=timezone.utc))
232+
233+
assert serializer.to_avro_payload_value(value) == [
234+
{"sku": "ABC", "quantity": 2},
235+
"10:30:00+00:00",
236+
]
237+
238+
@requires_avro
239+
def test_adapter_output_round_trips_through_default_avro_codec(self) -> None:
240+
value = serializer.to_avro_payload_value(
241+
{
242+
"model": SerializerDataclass(name="Ada", count=2),
243+
"ids": [UUID("12345678-1234-5678-1234-567812345678")],
244+
}
245+
)
246+
247+
blob = serializer.encode(value, codec="avro")
248+
249+
assert serializer.decode(blob, codec="avro") == {
250+
"model": {"name": "Ada", "count": 2},
251+
"ids": ["12345678-1234-5678-1234-567812345678"],
252+
}
253+
254+
def test_adapter_rejects_non_string_mapping_keys(self) -> None:
255+
with pytest.raises(TypeError, match="string keys"):
256+
serializer.to_avro_payload_value({1: "one"})
257+
258+
def test_adapter_rejects_unadapted_objects(self) -> None:
259+
with pytest.raises(TypeError, match="not Avro JSON payload safe"):
260+
serializer.to_avro_payload_value(object())
261+
262+
169263
class TestPayloadSizeWarning:
170264
def test_encode_warns_with_structured_context(self, caplog: pytest.LogCaptureFixture) -> None:
171265
config = serializer.PayloadSizeWarningConfig(limit_bytes=10, threshold_percent=50)

0 commit comments

Comments
 (0)