Skip to content

Commit 551db68

Browse files
Apply payload warnings to workflow commands
Issue: zorporation/durable-workflow#444 Loop-ID: build-01
1 parent 9d1dbe8 commit 551db68

4 files changed

Lines changed: 310 additions & 19 deletions

File tree

src/durable_workflow/worker.py

Lines changed: 55 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,10 @@ def _workflow_name(cls: type) -> str:
6868
return getattr(cls, "__workflow_name__", cls.__name__)
6969

7070

71+
def _string_or_none(value: Any) -> str | None:
72+
return value if isinstance(value, str) and value else None
73+
74+
7175
def _manifest_version(manifest: Any) -> str:
7276
if isinstance(manifest, dict):
7377
value = manifest.get("version")
@@ -185,6 +189,42 @@ def _record_task_metrics(self, task_kind: str, outcome: str, duration: float) ->
185189
self.metrics.increment(WORKER_TASKS, tags=tags)
186190
self.metrics.record(WORKER_TASK_DURATION_SECONDS, duration, tags=tags)
187191

192+
def _payload_size_warning_config(self) -> serializer.PayloadSizeWarningConfig | None:
    """Resolve the payload-size warning configuration to pass to command encoding.

    Reads ``payload_size_warning_config`` from ``self.client``; when the
    client has no such attribute, ``serializer.DEFAULT_PAYLOAD_SIZE_WARNING``
    is used.

    Returns:
        The client's value when it is ``None`` or a
        ``serializer.PayloadSizeWarningConfig`` instance; otherwise
        ``serializer.DEFAULT_PAYLOAD_SIZE_WARNING``.
    """
    config = getattr(
        self.client,
        "payload_size_warning_config",
        serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
    )
    # ``None`` is passed through on purpose rather than replaced by the default.
    # NOTE(review): presumably ``None`` means "warnings disabled" — confirm in serializer.
    if config is None or isinstance(config, serializer.PayloadSizeWarningConfig):
        return config
    # Any other type is treated as a misconfiguration and falls back to the default.
    return serializer.DEFAULT_PAYLOAD_SIZE_WARNING
197+
198+
def _workflow_payload_warning_context(
    self,
    task: dict[str, Any],
    *,
    kind: str,
    update_name: str | None = None,
) -> serializer.PayloadSizeWarningContext:
    """Build the payload-size warning context for commands produced by *task*.

    Args:
        task: The workflow task dict; ``workflow_id`` and ``run_id`` are read
            from it when present.
        kind: Label stored in the context's ``kind`` field.
        update_name: Optional update name stored in the context as given.

    Returns:
        A ``serializer.PayloadSizeWarningContext`` carrying the identifiers
        above plus this worker's ``task_queue`` and the client's ``namespace``.
    """
    return serializer.PayloadSizeWarningContext(
        kind=kind,
        workflow_id=_string_or_none(task.get("workflow_id")),
        run_id=_string_or_none(task.get("run_id")),
        update_name=update_name,
        task_queue=self.task_queue,
        # Normalised with the same helper as workflow_id/run_id so that a
        # missing, non-string, or empty namespace is uniformly reported as None.
        namespace=_string_or_none(getattr(self.client, "namespace", None)),
    )
214+
215+
def _update_name_for_id(self, history: list[dict[str, Any]], update_id: str | None) -> str | None:
216+
if not update_id:
217+
return None
218+
for event in reversed(history):
219+
if event.get("event_type") not in {"UpdateAccepted", "UpdateApplied"}:
220+
continue
221+
payload = event.get("payload")
222+
if not isinstance(payload, dict) or payload.get("update_id") != update_id:
223+
continue
224+
update_name = payload.get("update_name")
225+
return update_name if isinstance(update_name, str) and update_name else None
226+
return None
227+
188228
async def _register(self) -> None:
189229
try:
190230
info = await self.client.get_cluster_info()
@@ -335,6 +375,12 @@ async def _run_workflow_task_core(self, task: dict[str, Any]) -> list[dict[str,
335375
command = update_command.to_server_command(
336376
self.task_queue,
337377
payload_codec=command_codec,
378+
size_warning=self._payload_size_warning_config(),
379+
warning_context=self._workflow_payload_warning_context(
380+
task,
381+
kind="workflow_command",
382+
update_name=self._update_name_for_id(history, update_id),
383+
),
338384
)
339385
log.info(
340386
"completing workflow update task %s for update %s with %s",
@@ -388,7 +434,15 @@ async def _run_workflow_task_core(self, task: dict[str, Any]) -> list[dict[str,
388434
return None
389435

390436
commands = [
391-
c.to_server_command(self.task_queue, payload_codec=command_codec)
437+
c.to_server_command(
438+
self.task_queue,
439+
payload_codec=command_codec,
440+
size_warning=self._payload_size_warning_config(),
441+
warning_context=self._workflow_payload_warning_context(
442+
task,
443+
kind="workflow_command",
444+
),
445+
)
392446
for c in outcome.commands
393447
]
394448
log.info(

src/durable_workflow/workflow.py

Lines changed: 172 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -215,6 +215,36 @@ def _backoff_seconds(self) -> list[int]:
215215

216216

217217
ActivityRetryPolicyInput = ActivityRetryPolicy | Mapping[str, Any]
218+
PayloadWarningContext = serializer.PayloadSizeWarningContext | Mapping[str, Any] | None
219+
220+
221+
def _payload_warning_context(
    base: PayloadWarningContext,
    *,
    kind: str,
    task_queue: str | None = None,
    activity_name: str | None = None,
    update_name: str | None = None,
) -> dict[str, str]:
    """Merge a caller-supplied warning context with command-specific fields.

    Args:
        base: Context inherited from the caller. May be a
            ``serializer.PayloadSizeWarningContext`` (converted through its
            ``to_log_context()``), a mapping (keys and non-``None`` values are
            stringified, ``None`` values are dropped), or ``None``.
        kind: Always written to the result, replacing any ``kind`` in *base*.
        task_queue: Overrides ``task_queue`` from *base* when not ``None``.
        activity_name: Overrides ``activity_name`` from *base* when not ``None``.
        update_name: Overrides ``update_name`` from *base* when not ``None``.

    Returns:
        A new ``dict[str, str]`` suitable for passing as ``warning_context``.
    """
    if isinstance(base, serializer.PayloadSizeWarningContext):
        # Copy before applying overrides so the writes below can never leak
        # into a dict the context object might retain or hand out again.
        context: dict[str, str] = dict(base.to_log_context())
    elif base is None:
        context = {}
    else:
        context = {
            str(key): str(value)
            for key, value in base.items()
            if value is not None
        }

    context["kind"] = kind
    overrides = (
        ("task_queue", task_queue),
        ("activity_name", activity_name),
        ("update_name", update_name),
    )
    for field, value in overrides:
        # ``None`` means "no override": keep whatever the base context supplied.
        if value is not None:
            context[field] = value
    return context
218248

219249

220250
@dataclass
@@ -239,12 +269,27 @@ class ScheduleActivity:
239269
heartbeat_timeout: int | None = None
240270

241271
def to_server_command(
242-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
272+
self,
273+
task_queue: str,
274+
*,
275+
payload_codec: str = serializer.AVRO_CODEC,
276+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
277+
warning_context: PayloadWarningContext = None,
243278
) -> dict[str, Any]:
244279
command: dict[str, Any] = {
245280
"type": "schedule_activity",
246281
"activity_type": self.activity_type,
247-
"arguments": serializer.envelope(self.arguments, codec=payload_codec),
282+
"arguments": serializer.envelope(
283+
self.arguments,
284+
codec=payload_codec,
285+
size_warning=size_warning,
286+
warning_context=_payload_warning_context(
287+
warning_context,
288+
kind="activity_input",
289+
task_queue=self.queue or task_queue,
290+
activity_name=self.activity_type,
291+
),
292+
),
248293
"queue": self.queue or task_queue,
249294
}
250295
if self.retry_policy is not None:
@@ -271,7 +316,12 @@ class StartTimer:
271316
delay_seconds: int
272317

273318
def to_server_command(
274-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
319+
self,
320+
task_queue: str,
321+
*,
322+
payload_codec: str = serializer.AVRO_CODEC,
323+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
324+
warning_context: PayloadWarningContext = None,
275325
) -> dict[str, Any]:
276326
return {
277327
"type": "start_timer",
@@ -286,11 +336,25 @@ class CompleteWorkflow:
286336
result: Any
287337

288338
def to_server_command(
289-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
339+
self,
340+
task_queue: str,
341+
*,
342+
payload_codec: str = serializer.AVRO_CODEC,
343+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
344+
warning_context: PayloadWarningContext = None,
290345
) -> dict[str, Any]:
291346
return {
292347
"type": "complete_workflow",
293-
"result": serializer.envelope(self.result, codec=payload_codec),
348+
"result": serializer.envelope(
349+
self.result,
350+
codec=payload_codec,
351+
size_warning=size_warning,
352+
warning_context=_payload_warning_context(
353+
warning_context,
354+
kind="workflow_result",
355+
task_queue=task_queue,
356+
),
357+
),
294358
}
295359

296360

@@ -303,7 +367,12 @@ class FailWorkflow:
303367
non_retryable: bool = False
304368

305369
def to_server_command(
306-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
370+
self,
371+
task_queue: str,
372+
*,
373+
payload_codec: str = serializer.AVRO_CODEC,
374+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
375+
warning_context: PayloadWarningContext = None,
307376
) -> dict[str, Any]:
308377
cmd: dict[str, Any] = {
309378
"type": "fail_workflow",
@@ -324,12 +393,26 @@ class CompleteUpdate:
324393
result: Any
325394

326395
def to_server_command(
327-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
396+
self,
397+
task_queue: str,
398+
*,
399+
payload_codec: str = serializer.AVRO_CODEC,
400+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
401+
warning_context: PayloadWarningContext = None,
328402
) -> dict[str, Any]:
329403
return {
330404
"type": "complete_update",
331405
"update_id": self.update_id,
332-
"result": serializer.envelope(self.result, codec=payload_codec),
406+
"result": serializer.envelope(
407+
self.result,
408+
codec=payload_codec,
409+
size_warning=size_warning,
410+
warning_context=_payload_warning_context(
411+
warning_context,
412+
kind="update_result",
413+
task_queue=task_queue,
414+
),
415+
),
333416
}
334417

335418

@@ -344,7 +427,12 @@ class FailUpdate:
344427
non_retryable: bool = True
345428

346429
def to_server_command(
347-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
430+
self,
431+
task_queue: str,
432+
*,
433+
payload_codec: str = serializer.AVRO_CODEC,
434+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
435+
warning_context: PayloadWarningContext = None,
348436
) -> dict[str, Any]:
349437
cmd: dict[str, Any] = {
350438
"type": "fail_update",
@@ -369,12 +457,26 @@ class ContinueAsNew:
369457
task_queue: str | None = None
370458

371459
def to_server_command(
372-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
460+
self,
461+
task_queue: str,
462+
*,
463+
payload_codec: str = serializer.AVRO_CODEC,
464+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
465+
warning_context: PayloadWarningContext = None,
373466
) -> dict[str, Any]:
374467
cmd: dict[str, Any] = {"type": "continue_as_new"}
375468
if self.workflow_type is not None:
376469
cmd["workflow_type"] = self.workflow_type
377-
cmd["arguments"] = serializer.envelope(self.arguments, codec=payload_codec)
470+
cmd["arguments"] = serializer.envelope(
471+
self.arguments,
472+
codec=payload_codec,
473+
size_warning=size_warning,
474+
warning_context=_payload_warning_context(
475+
warning_context,
476+
kind="continue_as_new_input",
477+
task_queue=self.task_queue or task_queue,
478+
),
479+
)
378480
cmd["queue"] = self.task_queue or task_queue
379481
return cmd
380482

@@ -386,11 +488,25 @@ class RecordSideEffect:
386488
result: Any
387489

388490
def to_server_command(
389-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
491+
self,
492+
task_queue: str,
493+
*,
494+
payload_codec: str = serializer.AVRO_CODEC,
495+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
496+
warning_context: PayloadWarningContext = None,
390497
) -> dict[str, Any]:
391498
return {
392499
"type": "record_side_effect",
393-
"result": serializer.encode(self.result, codec=payload_codec),
500+
"result": serializer.encode(
501+
self.result,
502+
codec=payload_codec,
503+
size_warning=size_warning,
504+
warning_context=_payload_warning_context(
505+
warning_context,
506+
kind="side_effect_result",
507+
task_queue=task_queue,
508+
),
509+
),
394510
}
395511

396512

@@ -407,12 +523,26 @@ class StartChildWorkflow:
407523
run_timeout_seconds: int | None = None
408524

409525
def to_server_command(
410-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
526+
self,
527+
task_queue: str,
528+
*,
529+
payload_codec: str = serializer.AVRO_CODEC,
530+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
531+
warning_context: PayloadWarningContext = None,
411532
) -> dict[str, Any]:
412533
cmd: dict[str, Any] = {
413534
"type": "start_child_workflow",
414535
"workflow_type": self.workflow_type,
415-
"arguments": serializer.envelope(self.arguments, codec=payload_codec),
536+
"arguments": serializer.envelope(
537+
self.arguments,
538+
codec=payload_codec,
539+
size_warning=size_warning,
540+
warning_context=_payload_warning_context(
541+
warning_context,
542+
kind="child_workflow_input",
543+
task_queue=self.task_queue or task_queue,
544+
),
545+
),
416546
}
417547
if self.task_queue is not None:
418548
cmd["queue"] = self.task_queue
@@ -443,7 +573,12 @@ class RecordVersionMarker:
443573
max_supported: int
444574

445575
def to_server_command(
446-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
576+
self,
577+
task_queue: str,
578+
*,
579+
payload_codec: str = serializer.AVRO_CODEC,
580+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
581+
warning_context: PayloadWarningContext = None,
447582
) -> dict[str, Any]:
448583
return {
449584
"type": "record_version_marker",
@@ -461,8 +596,22 @@ class UpsertSearchAttributes:
461596
attributes: dict[str, Any]
462597

463598
def to_server_command(
464-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
599+
self,
600+
task_queue: str,
601+
*,
602+
payload_codec: str = serializer.AVRO_CODEC,
603+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
604+
warning_context: PayloadWarningContext = None,
465605
) -> dict[str, Any]:
606+
serializer.warn_if_json_payload_near_limit(
607+
self.attributes,
608+
size_warning=size_warning,
609+
warning_context=_payload_warning_context(
610+
warning_context,
611+
kind="search_attributes",
612+
task_queue=task_queue,
613+
),
614+
)
466615
return {
467616
"type": "upsert_search_attributes",
468617
"attributes": self.attributes,
@@ -486,7 +635,12 @@ class WaitCondition:
486635
timeout_seconds: int | None = None
487636

488637
def to_server_command(
489-
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
638+
self,
639+
task_queue: str,
640+
*,
641+
payload_codec: str = serializer.AVRO_CODEC,
642+
size_warning: serializer.PayloadSizeWarningConfig | None = serializer.DEFAULT_PAYLOAD_SIZE_WARNING,
643+
warning_context: PayloadWarningContext = None,
490644
) -> dict[str, Any]:
491645
cmd: dict[str, Any] = {"type": "open_condition_wait"}
492646
if self.condition_key is not None:

0 commit comments

Comments
 (0)