Skip to content

Commit 4a8bf86

Browse files
Batch Python workflow payload encoding
1 parent 17069ff commit 4a8bf86

6 files changed

Lines changed: 337 additions & 16 deletions

File tree

src/durable_workflow/_avro.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import base64
3131
import io
3232
import json
33+
from functools import lru_cache
3334
from typing import Any
3435

3536
from .errors import AvroNotInstalledError
@@ -44,6 +45,7 @@
4445
_PREFIX_TYPED_SCHEMA = b"\x01"
4546

4647

48+
@lru_cache(maxsize=1)
4749
def _load_avro_schema() -> Any:
4850
try:
4951
import avro.schema

src/durable_workflow/serializer.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020

2121
import json
2222
import logging
23-
from collections.abc import Mapping
23+
from collections.abc import Mapping, Sequence
2424
from dataclasses import dataclass
25-
from typing import Any
25+
from typing import Any, TypeGuard, cast
2626

2727
from . import _avro
2828

@@ -87,6 +87,8 @@ def to_log_context(self) -> dict[str, str]:
8787

8888

8989
DEFAULT_PAYLOAD_SIZE_WARNING = PayloadSizeWarningConfig()
90+
PayloadWarningContext = PayloadSizeWarningContext | Mapping[str, Any] | None
91+
PayloadWarningContexts = PayloadWarningContext | Sequence[PayloadWarningContext]
9092

9193

9294
def encode(
@@ -119,6 +121,32 @@ def encode(
119121
return blob
120122

121123

124+
def encode_many(
125+
values: Sequence[Any],
126+
codec: str = AVRO_CODEC,
127+
*,
128+
size_warning: PayloadSizeWarningConfig | None = DEFAULT_PAYLOAD_SIZE_WARNING,
129+
warning_context: PayloadWarningContexts = None,
130+
) -> list[str]:
131+
"""Encode several payload blobs through one codec hook.
132+
133+
The default implementation intentionally preserves the single-value
134+
encoder semantics and warning behavior. Codecs that can safely batch or
135+
parallelize work can specialize behind this boundary without changing
136+
call sites.
137+
"""
138+
contexts = _warning_contexts_for_values(values, warning_context)
139+
return [
140+
encode(
141+
value,
142+
codec=codec,
143+
size_warning=size_warning,
144+
warning_context=contexts[index],
145+
)
146+
for index, value in enumerate(values)
147+
]
148+
149+
122150
def envelope(
123151
value: Any,
124152
codec: str = AVRO_CODEC,
@@ -138,6 +166,25 @@ def envelope(
138166
}
139167

140168

169+
def envelope_many(
170+
values: Sequence[Any],
171+
codec: str = AVRO_CODEC,
172+
*,
173+
size_warning: PayloadSizeWarningConfig | None = DEFAULT_PAYLOAD_SIZE_WARNING,
174+
warning_context: PayloadWarningContexts = None,
175+
) -> list[dict[str, str]]:
176+
"""Wrap several values in ``{codec, blob}`` payload envelopes."""
177+
return [
178+
{"codec": codec, "blob": blob}
179+
for blob in encode_many(
180+
values,
181+
codec=codec,
182+
size_warning=size_warning,
183+
warning_context=warning_context,
184+
)
185+
]
186+
187+
141188
def warn_if_json_payload_near_limit(
142189
value: Any,
143190
*,
@@ -205,6 +252,27 @@ def _normalize_warning_context(
205252
return normalized
206253

207254

255+
def _warning_contexts_for_values(
256+
values: Sequence[Any],
257+
context: PayloadWarningContexts,
258+
) -> list[PayloadWarningContext]:
259+
if not _is_context_sequence(context):
260+
single_context = cast(PayloadWarningContext, context)
261+
return [single_context] * len(values)
262+
if len(context) != len(values):
263+
raise ValueError("payload warning context count must match value count")
264+
return list(context)
265+
266+
267+
def _is_context_sequence(
268+
context: PayloadWarningContexts,
269+
) -> TypeGuard[Sequence[PayloadWarningContext]]:
270+
return isinstance(context, Sequence) and not isinstance(
271+
context,
272+
(str, bytes, bytearray, PayloadSizeWarningContext, Mapping),
273+
)
274+
275+
208276
def decode_envelope(value: Any, codec: str | None = None) -> Any:
209277
"""Decode a value that may be a ``{codec, blob}`` envelope or a raw blob.
210278

src/durable_workflow/worker.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
WORKER_TASKS,
5353
MetricsRecorder,
5454
)
55-
from .workflow import apply_update, query_state, replay
55+
from .workflow import apply_update, commands_to_server_commands, query_state, replay
5656

5757
log = logging.getLogger("durable_workflow.worker")
5858

@@ -515,18 +515,16 @@ async def _run_workflow_task_core(self, task: dict[str, Any]) -> list[dict[str,
515515
log.warning("failed to report replay failure: %s", fe)
516516
return None
517517

518-
commands = [
519-
c.to_server_command(
520-
self.task_queue,
521-
payload_codec=command_codec,
522-
size_warning=self._payload_size_warning_config(),
523-
warning_context=self._workflow_payload_warning_context(
524-
task,
525-
kind="workflow_command",
526-
),
527-
)
528-
for c in outcome.commands
529-
]
518+
commands = commands_to_server_commands(
519+
outcome.commands,
520+
self.task_queue,
521+
payload_codec=command_codec,
522+
size_warning=self._payload_size_warning_config(),
523+
warning_context=self._workflow_payload_warning_context(
524+
task,
525+
kind="workflow_command",
526+
),
527+
)
530528
log.info(
531529
"completing workflow task %s with %d command(s): %s",
532530
task_id,

src/durable_workflow/workflow.py

Lines changed: 172 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import math
2525
import random
2626
import uuid
27-
from collections.abc import Callable, Iterable, Mapping
27+
from collections.abc import Callable, Iterable, Mapping, Sequence
2828
from dataclasses import dataclass, field
2929
from datetime import datetime, timezone
3030
from typing import Any
@@ -685,6 +685,177 @@ def to_server_command(
685685
)
686686

687687

688+
def commands_to_server_commands(
689+
commands: Sequence[Command],
690+
task_queue: str,
691+
*,
692+
payload_codec: str = serializer.AVRO_CODEC,
693+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
694+
warning_context: PayloadWarningContext = None,
695+
) -> list[dict[str, Any]]:
696+
"""Convert workflow commands to the server wire shape with batched payload encoding."""
697+
server_commands: list[dict[str, Any]] = []
698+
envelope_jobs: list[tuple[int, str, Any, dict[str, str]]] = []
699+
encode_jobs: list[tuple[int, str, Any, dict[str, str]]] = []
700+
701+
for command in commands:
702+
if isinstance(command, ScheduleActivity):
703+
queue = command.queue or task_queue
704+
server_command: dict[str, Any] = {
705+
"type": "schedule_activity",
706+
"activity_type": command.activity_type,
707+
"queue": queue,
708+
}
709+
envelope_jobs.append((
710+
len(server_commands),
711+
"arguments",
712+
command.arguments,
713+
_payload_warning_context(
714+
warning_context,
715+
kind="activity_input",
716+
task_queue=queue,
717+
activity_name=command.activity_type,
718+
),
719+
))
720+
if command.retry_policy is not None:
721+
server_command["retry_policy"] = (
722+
command.retry_policy.to_dict()
723+
if isinstance(command.retry_policy, ActivityRetryPolicy)
724+
else dict(command.retry_policy)
725+
)
726+
if command.start_to_close_timeout is not None:
727+
server_command["start_to_close_timeout"] = command.start_to_close_timeout
728+
if command.schedule_to_start_timeout is not None:
729+
server_command["schedule_to_start_timeout"] = command.schedule_to_start_timeout
730+
if command.schedule_to_close_timeout is not None:
731+
server_command["schedule_to_close_timeout"] = command.schedule_to_close_timeout
732+
if command.heartbeat_timeout is not None:
733+
server_command["heartbeat_timeout"] = command.heartbeat_timeout
734+
server_commands.append(server_command)
735+
continue
736+
737+
if isinstance(command, CompleteWorkflow):
738+
server_commands.append({"type": "complete_workflow"})
739+
envelope_jobs.append((
740+
len(server_commands) - 1,
741+
"result",
742+
command.result,
743+
_payload_warning_context(
744+
warning_context,
745+
kind="workflow_result",
746+
task_queue=task_queue,
747+
),
748+
))
749+
continue
750+
751+
if isinstance(command, CompleteUpdate):
752+
server_commands.append({"type": "complete_update", "update_id": command.update_id})
753+
envelope_jobs.append((
754+
len(server_commands) - 1,
755+
"result",
756+
command.result,
757+
_payload_warning_context(
758+
warning_context,
759+
kind="update_result",
760+
task_queue=task_queue,
761+
),
762+
))
763+
continue
764+
765+
if isinstance(command, ContinueAsNew):
766+
queue = command.task_queue or task_queue
767+
server_command = {"type": "continue_as_new", "queue": queue}
768+
if command.workflow_type is not None:
769+
server_command["workflow_type"] = command.workflow_type
770+
server_commands.append(server_command)
771+
envelope_jobs.append((
772+
len(server_commands) - 1,
773+
"arguments",
774+
command.arguments,
775+
_payload_warning_context(
776+
warning_context,
777+
kind="continue_as_new_input",
778+
task_queue=queue,
779+
),
780+
))
781+
continue
782+
783+
if isinstance(command, RecordSideEffect):
784+
server_commands.append({"type": "record_side_effect"})
785+
encode_jobs.append((
786+
len(server_commands) - 1,
787+
"result",
788+
command.result,
789+
_payload_warning_context(
790+
warning_context,
791+
kind="side_effect_result",
792+
task_queue=task_queue,
793+
),
794+
))
795+
continue
796+
797+
if isinstance(command, StartChildWorkflow):
798+
queue = command.task_queue or task_queue
799+
server_command = {
800+
"type": "start_child_workflow",
801+
"workflow_type": command.workflow_type,
802+
"queue": queue,
803+
}
804+
envelope_jobs.append((
805+
len(server_commands),
806+
"arguments",
807+
command.arguments,
808+
_payload_warning_context(
809+
warning_context,
810+
kind="child_workflow_input",
811+
task_queue=queue,
812+
),
813+
))
814+
if command.parent_close_policy is not None:
815+
server_command["parent_close_policy"] = command.parent_close_policy
816+
if command.retry_policy is not None:
817+
server_command["retry_policy"] = (
818+
command.retry_policy.to_dict()
819+
if isinstance(command.retry_policy, ActivityRetryPolicy)
820+
else dict(command.retry_policy)
821+
)
822+
if command.execution_timeout_seconds is not None:
823+
server_command["execution_timeout_seconds"] = command.execution_timeout_seconds
824+
if command.run_timeout_seconds is not None:
825+
server_command["run_timeout_seconds"] = command.run_timeout_seconds
826+
server_commands.append(server_command)
827+
continue
828+
829+
server_commands.append(command.to_server_command(
830+
task_queue,
831+
payload_codec=payload_codec,
832+
size_warning=size_warning,
833+
warning_context=warning_context,
834+
))
835+
836+
if envelope_jobs:
837+
envelopes = serializer.envelope_many(
838+
[value for _, _, value, _ in envelope_jobs],
839+
codec=payload_codec,
840+
size_warning=size_warning,
841+
warning_context=[context for _, _, _, context in envelope_jobs],
842+
)
843+
for (index, key, _, _), envelope_value in zip(envelope_jobs, envelopes, strict=True):
844+
server_commands[index][key] = envelope_value
845+
846+
if encode_jobs:
847+
blobs = serializer.encode_many(
848+
[value for _, _, value, _ in encode_jobs],
849+
codec=payload_codec,
850+
size_warning=size_warning,
851+
warning_context=[context for _, _, _, context in encode_jobs],
852+
)
853+
for (index, key, _, _), blob in zip(encode_jobs, blobs, strict=True):
854+
server_commands[index][key] = blob
855+
856+
return server_commands
857+
858+
688859
# ── Context passed to the workflow's run() ───────────────────────────
689860

690861
_REPLAY_LOGGER = logging.getLogger("durable_workflow.workflow.replay")

tests/test_serializer.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,49 @@ def test_none_value(self) -> None:
123123
assert serializer.decode(env["blob"], codec="avro") is None
124124

125125

126+
class TestBatchEncoding:
127+
def test_encode_many_preserves_order(self) -> None:
128+
blobs = serializer.encode_many([["a"], ["b"]], codec="json")
129+
assert blobs == ['["a"]', '["b"]']
130+
131+
def test_envelope_many_wraps_each_value(self) -> None:
132+
envelopes = serializer.envelope_many([["a"], ["b"]], codec="json")
133+
assert envelopes == [
134+
{"codec": "json", "blob": '["a"]'},
135+
{"codec": "json", "blob": '["b"]'},
136+
]
137+
138+
def test_encode_many_accepts_per_payload_warning_context(
139+
self, caplog: pytest.LogCaptureFixture
140+
) -> None:
141+
config = serializer.PayloadSizeWarningConfig(limit_bytes=10, threshold_percent=50)
142+
contexts = [
143+
serializer.PayloadSizeWarningContext(kind="signal", signal_name="one"),
144+
serializer.PayloadSizeWarningContext(kind="signal", signal_name="two"),
145+
]
146+
147+
with caplog.at_level(logging.WARNING, logger="durable_workflow.serializer"):
148+
serializer.encode_many(
149+
["abcdef", "ghijkl"],
150+
codec="json",
151+
size_warning=config,
152+
warning_context=contexts,
153+
)
154+
155+
assert [record.durable_workflow_payload["signal_name"] for record in caplog.records] == [
156+
"one",
157+
"two",
158+
]
159+
160+
def test_encode_many_rejects_context_count_mismatch(self) -> None:
161+
with pytest.raises(ValueError, match="context count"):
162+
serializer.encode_many(
163+
["a", "b"],
164+
codec="json",
165+
warning_context=[serializer.PayloadSizeWarningContext(kind="payload")],
166+
)
167+
168+
126169
class TestPayloadSizeWarning:
127170
def test_encode_warns_with_structured_context(self, caplog: pytest.LogCaptureFixture) -> None:
128171
config = serializer.PayloadSizeWarningConfig(limit_bytes=10, threshold_percent=50)

0 commit comments

Comments
 (0)