|
| 1 | +"""In-process test harness for workflow authoring. |
| 2 | +
|
| 3 | +:class:`WorkflowEnvironment` drives a workflow to completion in a single |
| 4 | +Python process, without a running server or worker. It reuses the same |
| 5 | +:func:`durable_workflow.workflow.replay` machinery the worker uses, but |
| 6 | +resolves yielded commands against user-registered activity mocks and |
| 7 | +auto-fires timers / side-effects / search-attribute upserts so tests do |
| 8 | +not need a real clock or Redis. |
| 9 | +
|
| 10 | +Typical use:: |
| 11 | +
|
| 12 | + def test_my_workflow(): |
| 13 | + env = WorkflowEnvironment() |
| 14 | + env.register_activity_result("charge_card", {"id": "ch_1"}) |
| 15 | + env.register_activity_result("send_receipt", None) |
| 16 | + result = env.execute_workflow(OrderWorkflow, "order-1", {"amount": 42}) |
| 17 | + assert result == {"status": "complete", "charge_id": "ch_1"} |
| 18 | +
|
| 19 | +For regression-testing workflow code against production histories, use |
| 20 | +:func:`replay_history` — it hands the real durable history straight to |
| 21 | +the worker's replayer and surfaces any non-determinism as a raised |
| 22 | +exception. |
| 23 | +""" |
| 24 | + |
| 25 | +from __future__ import annotations |
| 26 | + |
| 27 | +import json |
| 28 | +from collections.abc import Callable, Iterable |
| 29 | +from pathlib import Path |
| 30 | +from typing import Any |
| 31 | + |
| 32 | +from . import serializer |
| 33 | +from .errors import WorkflowCancelled, WorkflowFailed, WorkflowTerminated |
| 34 | +from .workflow import ( |
| 35 | + CompleteWorkflow, |
| 36 | + ContinueAsNew, |
| 37 | + FailWorkflow, |
| 38 | + RecordSideEffect, |
| 39 | + RecordVersionMarker, |
| 40 | + ReplayOutcome, |
| 41 | + ScheduleActivity, |
| 42 | + StartChildWorkflow, |
| 43 | + StartTimer, |
| 44 | + UpsertSearchAttributes, |
| 45 | + replay, |
| 46 | +) |
| 47 | + |
| 48 | + |
class WorkflowEnvironment:
    """Drives a workflow to completion against user-registered activity mocks.

    Each pass replays the workflow over an in-memory history list, then
    resolves every command the replay produced by appending the matching
    completion event to that history. The loop ends when a terminal command
    (complete / fail / continue-as-new) is seen or the iteration limit is hit.
    """

    def __init__(self, *, iteration_limit: int = 1000) -> None:
        """Create an environment with no mocks and no queued signals.

        Args:
            iteration_limit: Maximum number of replay passes
                :meth:`execute_workflow` performs before raising
                :class:`RuntimeError`.
        """
        self._activity_results: dict[str, Any] = {}
        self._activity_fns: dict[str, Callable[..., Any]] = {}
        self._child_workflow_results: dict[str, Any] = {}
        self._pending_signals: list[tuple[str, list[Any]]] = []
        self._iteration_limit = iteration_limit

    def register_activity_result(self, name: str, result: Any) -> None:
        """Canned response: every call to ``name`` returns ``result``.

        The most recent registration for ``name`` wins: a callable mock
        previously installed with :meth:`register_activity` is discarded.
        """
        # Callable mocks are consulted first when resolving an activity, so a
        # stale one must be evicted or this canned result would never be used.
        self._activity_fns.pop(name, None)
        self._activity_results[name] = result

    def register_activity(self, name: str, fn: Callable[..., Any]) -> None:
        """Callable mock: ``fn(*arguments)`` is invoked for each scheduled call.

        Use this when the test needs the mock to vary with arguments (e.g.
        look up by order id) or to capture invocations. The most recent
        registration for ``name`` wins: a canned result previously installed
        with :meth:`register_activity_result` is discarded.
        """
        # Keep at most one mock per activity name so the registries never
        # disagree about which one is current.
        self._activity_results.pop(name, None)
        self._activity_fns[name] = fn

    def register_child_workflow_result(self, workflow_type: str, result: Any) -> None:
        """Canned response for child workflow completions."""
        self._child_workflow_results[workflow_type] = result

    def signal(self, name: str, args: list[Any] | None = None) -> None:
        """Queue a signal to be delivered before the next iteration.

        Signals are drained in the order they were queued and injected into
        the workflow history as ``SignalReceived`` events; the replayer then
        dispatches each to its registered ``@workflow.signal`` handler.

        ``args`` is copied into a fresh list; ``None`` means no arguments.
        """
        self._pending_signals.append((name, list(args) if args is not None else []))

    def execute_workflow(
        self,
        workflow_cls: type,
        *args: Any,
        run_id: str = "test-run",
    ) -> Any:
        """Drive ``workflow_cls`` to a terminal state and return its result.

        Raises :class:`~durable_workflow.errors.WorkflowFailed` if the workflow
        ended in the ``failed`` state. Activities that do not have a
        registered mock raise :class:`KeyError` so tests fail loudly on
        missing fixtures. Raises :class:`RuntimeError` when no terminal
        command is produced within ``iteration_limit`` passes.
        """
        history: list[dict[str, Any]] = []

        for _ in range(self._iteration_limit):
            # Signals queued so far are appended before this pass replays.
            self._drain_pending_signals_into(history)
            outcome = replay(workflow_cls, history, list(args), run_id=run_id)
            terminal = self._apply_commands(outcome, history)
            if terminal is not _SENTINEL:
                return terminal

        raise RuntimeError(
            f"workflow did not terminate within {self._iteration_limit} iterations; "
            "check for missing activity mocks or signals that never satisfy a wait."
        )

    def _drain_pending_signals_into(self, history: list[dict[str, Any]]) -> None:
        """Append one ``SignalReceived`` event per queued signal, in FIFO order."""
        # Iterate then clear rather than pop(0) per item, which shifts the
        # whole list on every removal.
        for name, sig_args in self._pending_signals:
            history.append(
                {
                    "event_type": "SignalReceived",
                    "payload": {
                        "signal_name": name,
                        "value": serializer.envelope(sig_args),
                        "payload_codec": serializer.AVRO_CODEC,
                    },
                }
            )
        self._pending_signals.clear()

    def _apply_commands(
        self, outcome: ReplayOutcome, history: list[dict[str, Any]]
    ) -> Any:
        """Resolve the commands of one replay pass against ``history``.

        Returns the workflow result when a ``CompleteWorkflow`` command is
        seen, otherwise the module-level ``_SENTINEL`` marker. Raises
        ``WorkflowFailed`` for ``FailWorkflow``, ``NotImplementedError`` for
        ``ContinueAsNew`` and ``TypeError`` for any unrecognised command.
        """
        for cmd in outcome.commands:
            # Terminal commands end the pass immediately.
            if isinstance(cmd, CompleteWorkflow):
                return cmd.result
            if isinstance(cmd, FailWorkflow):
                raise WorkflowFailed(cmd.message, cmd.exception_type)
            if isinstance(cmd, ContinueAsNew):
                raise NotImplementedError(
                    "continue_as_new is not yet supported by the test harness; "
                    "drive each run explicitly with a separate execute_workflow call."
                )

            # Non-terminal commands are answered by appending the event that
            # records their completion, so the next replay can move past them.
            if isinstance(cmd, ScheduleActivity):
                history.append(self._resolve_activity(cmd))
            elif isinstance(cmd, StartTimer):
                # Timers fire immediately; no clock is involved.
                history.append({"event_type": "TimerFired", "payload": {}})
            elif isinstance(cmd, StartChildWorkflow):
                history.append(self._resolve_child_workflow(cmd))
            elif isinstance(cmd, RecordSideEffect):
                history.append(
                    {
                        "event_type": "SideEffectRecorded",
                        "payload": {
                            "result": serializer.envelope(cmd.result),
                            "payload_codec": serializer.AVRO_CODEC,
                        },
                    }
                )
            elif isinstance(cmd, UpsertSearchAttributes):
                history.append(
                    {"event_type": "SearchAttributesUpserted", "payload": {}}
                )
            elif isinstance(cmd, RecordVersionMarker):
                history.append(
                    {
                        "event_type": "VersionMarkerRecorded",
                        "payload": {"version": cmd.version},
                    }
                )
            else:
                raise TypeError(f"unsupported command in test harness: {cmd!r}")
        return _SENTINEL

    def _resolve_activity(self, cmd: ScheduleActivity) -> dict[str, Any]:
        """Build the ``ActivityCompleted`` event for a scheduled activity.

        A callable mock is invoked with ``*cmd.arguments``; otherwise the
        canned result is used. Raises :class:`KeyError` when neither exists.
        """
        if cmd.activity_type in self._activity_fns:
            result = self._activity_fns[cmd.activity_type](*cmd.arguments)
        elif cmd.activity_type in self._activity_results:
            result = self._activity_results[cmd.activity_type]
        else:
            raise KeyError(
                f"no mock registered for activity {cmd.activity_type!r}; "
                "call env.register_activity_result() or env.register_activity()."
            )
        return {
            "event_type": "ActivityCompleted",
            "payload": {
                "result": serializer.envelope(result),
                "payload_codec": serializer.AVRO_CODEC,
            },
        }

    def _resolve_child_workflow(self, cmd: StartChildWorkflow) -> dict[str, Any]:
        """Build the ``ChildRunCompleted`` event for a started child workflow.

        Raises :class:`KeyError` when no canned result is registered for
        ``cmd.workflow_type``.
        """
        if cmd.workflow_type not in self._child_workflow_results:
            raise KeyError(
                f"no mock registered for child workflow {cmd.workflow_type!r}; "
                "call env.register_child_workflow_result()."
            )
        return {
            "event_type": "ChildRunCompleted",
            "payload": {
                "result": serializer.envelope(self._child_workflow_results[cmd.workflow_type]),
                "payload_codec": serializer.AVRO_CODEC,
            },
        }
| 200 | + |
| 201 | + |
# Sentinel marking "no terminal command seen this iteration".
# A unique object is used instead of None because None is a legitimate
# workflow result; callers compare against it with ``is``.
_SENTINEL = object()
| 204 | + |
| 205 | + |
def replay_history(
    workflow_cls: type,
    history_events: Iterable[dict[str, Any]],
    start_input: list[Any] | None = None,
    *,
    run_id: str = "",
    payload_codec: str | None = None,
) -> ReplayOutcome:
    """Replay a production history against current workflow code.

    Hands the durable history directly to the worker's replayer. Raises any
    exception the workflow would raise during replay — for example a
    non-determinism failure when ``run`` yields a different command sequence
    from the one recorded in history.

    This is the supported way to regression-test a workflow change against
    real production traffic: dump the history from ``Client.get_history``,
    save the JSON, and replay it on every PR.

    Args:
        workflow_cls: Workflow class to replay.
        history_events: Recorded events; any iterable is accepted and is
            copied into a list before being replayed.
        start_input: Workflow start arguments; ``None`` means no arguments.
        run_id: Forwarded to the replayer unchanged.
        payload_codec: Forwarded to the replayer unchanged.

    Returns:
        The :class:`ReplayOutcome` produced by the replayer.
    """
    # Materialise the events so one-shot iterables (generators, streamed
    # readers) behave exactly like the lists the harness itself passes to
    # ``replay``, and so the caller's sequence is never shared with it.
    events = list(history_events)
    return replay(
        workflow_cls,
        events,
        list(start_input or []),
        run_id=run_id,
        payload_codec=payload_codec,
    )
| 232 | + |
| 233 | + |
def replay_history_file(
    workflow_cls: type,
    path: str | Path,
    start_input: list[Any] | None = None,
    *,
    run_id: str = "",
    payload_codec: str | None = None,
) -> ReplayOutcome:
    """Convenience wrapper: load a JSON history file and replay it.

    Accepts either a list of events at the top level or a dict with an
    ``events`` key (matching the shape of ``Client.get_history``).

    Args:
        workflow_cls: Workflow class to replay.
        path: Location of the JSON history file.
        start_input: Workflow start arguments; ``None`` means no arguments.
        run_id: Forwarded to :func:`replay_history` unchanged.
        payload_codec: Forwarded to :func:`replay_history` unchanged.

    Raises:
        KeyError: If the top-level JSON value is a dict without ``events``.
    """
    # Read raw bytes and let ``json.loads`` detect the encoding. Decoding with
    # the platform's locale codec (the ``read_text()`` default) corrupts or
    # rejects UTF-8 histories on machines whose locale is not UTF-8.
    data = json.loads(Path(path).read_bytes())
    events = data["events"] if isinstance(data, dict) else data
    return replay_history(
        workflow_cls,
        events,
        start_input,
        run_id=run_id,
        payload_codec=payload_codec,
    )
| 256 | + |
| 257 | + |
# Public API of the harness. The terminal exceptions are listed here so the
# re-export described below is an explicit one: star-imports pick them up and
# type checkers treat them as public names of this module.
__all__ = [
    "WorkflowEnvironment",
    "replay_history",
    "replay_history_file",
    "WorkflowCancelled",
    "WorkflowFailed",
    "WorkflowTerminated",
]


# Re-export terminal exceptions the harness may raise so tests can catch
# them without hunting for the right import path.
_TERMINAL_EXCEPTIONS = (WorkflowFailed, WorkflowCancelled, WorkflowTerminated)
0 commit comments