Skip to content

Commit d9d7ff0

Browse files
Guard Python wait_condition replay fingerprints
Issue: zorporation/durable-workflow#397 Loop-ID: build-02
1 parent 5df638d commit d9d7ff0

3 files changed

Lines changed: 135 additions & 7 deletions

File tree

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,13 @@ server. Query routing and synchronous pre-accept update validator execution are
144144
still server-side follow-ups; use those paths only with deployments that
145145
advertise support for the target workflow type.
146146

147+
Use `yield ctx.wait_condition(lambda: self.approved, key="approved",
148+
timeout=30)` to wait for signal- or update-mutated workflow state without
149+
polling timers by hand. The SDK sends a stable predicate fingerprint with the
150+
durable wait command and rejects replay if history records a different wait
151+
key or predicate fingerprint, so condition changes fail visibly instead of
152+
silently resolving a different wait.
153+
147154
Workers fingerprint registered workflow class definitions and advertise those
148155
fingerprints during registration. Re-registering the same `worker_id` with a
149156
changed class body for an already advertised workflow type raises immediately;

src/durable_workflow/workflow.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,30 @@ def to_server_command(
678678
return cmd
679679

680680

681+
def _condition_predicate_fingerprint(predicate: Callable[[], bool]) -> str:
    """Return a ``sha256:<hex>`` fingerprint of *predicate*'s definition.

    The digest covers a fixed version tag, the predicate's ``__module__`` and
    ``__qualname__`` (empty string when absent), and a canonical rendering of
    its code object: argument counts, bytecode, constants, names, local
    variable names and free variable names.

    The rendering is kept independent of the running process so that the same
    source hashes identically on every worker:

    * nested code objects in ``co_consts`` (inner lambdas, generator
      expressions) are rendered from their own fields instead of their
      default ``repr``, which embeds a memory address, file path and line;
    * ``frozenset`` constants are rendered in sorted order because their
      iteration order depends on string hash randomisation;
    * callables without ``__code__`` are identified by their type (plus the
      code of a Python-level ``__call__`` when the type defines one) rather
      than by ``repr(predicate)``, which usually embeds a memory address.

    For predicates that contain none of the above, the digest is byte-for-byte
    the same as hashing ``repr`` of the field tuple directly.

    NOTE(review): ``co_code`` is raw bytecode, so the fingerprint is presumably
    only comparable between workers running the same Python minor version —
    confirm against the deployment story.

    Args:
        predicate: The zero-argument callable handed to ``wait_condition``.

    Returns:
        The fingerprint string, always prefixed with ``sha256:``.
    """

    def _code_fields(code: Any) -> tuple[Any, ...]:
        # Field order is part of the v1 hash input; do not reorder.
        return (
            code.co_argcount,
            code.co_posonlyargcount,
            code.co_kwonlyargcount,
            code.co_code,
            code.co_consts,
            code.co_names,
            code.co_varnames,
            code.co_freevars,
        )

    def _stable_repr(value: Any) -> str:
        # Mirrors ``repr`` for plain tuples so already-stable inputs keep
        # their existing digest, while canonicalising the unstable cases.
        if type(value) is tuple:
            parts = [_stable_repr(item) for item in value]
            if len(parts) == 1:
                return f"({parts[0]},)"
            return f"({', '.join(parts)})"
        if isinstance(value, frozenset):
            members = sorted(_stable_repr(item) for item in value)
            return "frozenset({" + ", ".join(members) + "})"
        if hasattr(value, "co_code"):
            return f"<code {value.co_name} {_stable_repr(_code_fields(value))}>"
        return repr(value)

    h = hashlib.sha256()
    h.update(b"durable-workflow-python.wait-condition.v1\0")
    h.update(f"{getattr(predicate, '__module__', '')}\0".encode())
    h.update(f"{getattr(predicate, '__qualname__', '')}\0".encode())

    code = getattr(predicate, "__code__", None)
    if code is None:
        # Callable object / C callable: hash what defines it, not its address.
        kind = type(predicate)
        h.update(f"{kind.__module__}.{kind.__qualname__}\0".encode())
        call_code = getattr(getattr(kind, "__call__", None), "__code__", None)
        if call_code is not None:
            h.update(_stable_repr(_code_fields(call_code)).encode())
    else:
        h.update(_stable_repr(_code_fields(code)).encode())

    return f"sha256:{h.hexdigest()}"
703+
704+
681705
Command = (
682706
ScheduleActivity | StartTimer | CompleteWorkflow | FailWorkflow
683707
| CompleteUpdate | FailUpdate | ContinueAsNew | RecordSideEffect | StartChildWorkflow
@@ -960,6 +984,7 @@ def wait_condition(
960984
return WaitCondition(
961985
predicate=predicate,
962986
condition_key=key,
987+
condition_definition_fingerprint=_condition_predicate_fingerprint(predicate),
963988
timeout_seconds=timeout_seconds,
964989
)
965990

@@ -1551,10 +1576,10 @@ def _state(commands: list[Command]) -> _ReplayState:
15511576
# external receivers apply before the generator consumes the resolved_result
15521577
# at the stored index, preserving history interleaving with activities.
15531578
pending_receivers: list[tuple[int, str, str, list[Any]]] = []
1554-
# Ordered list of condition_wait_id strings from ConditionWaitOpened events,
1555-
# used by ``WaitCondition`` yields to match against their corresponding
1556-
# opened wait in history (Nth yield ↔ Nth opened).
1557-
wait_opened_ids: list[str] = []
1579+
# Ordered ``ConditionWaitOpened`` payloads, used by ``WaitCondition`` yields
1580+
# to match against their corresponding opened wait in history
1581+
# (Nth yield ↔ Nth opened).
1582+
wait_opened: list[dict[str, Any]] = []
15581583
# Map condition_wait_id → resolution: 'satisfied' (from ConditionWaitSatisfied
15591584
# in history, future server-recorded) or 'timed_out' (from a matching
15601585
# condition_timeout TimerFired event).
@@ -1577,7 +1602,7 @@ def _state(commands: list[Command]) -> _ReplayState:
15771602
elif etype == "ConditionWaitOpened":
15781603
wait_id = payload.get("condition_wait_id")
15791604
if isinstance(wait_id, str) and wait_id:
1580-
wait_opened_ids.append(wait_id)
1605+
wait_opened.append(dict(payload))
15811606
elif etype == "ConditionWaitSatisfied":
15821607
wait_id = payload.get("condition_wait_id")
15831608
if isinstance(wait_id, str) and wait_id:
@@ -1734,8 +1759,35 @@ def _apply_due_receivers() -> None:
17341759
continue
17351760
if isinstance(cmd, WaitCondition):
17361761
resolution: str | None = None
1737-
if wait_yield_count < len(wait_opened_ids):
1738-
resolution = wait_resolutions.get(wait_opened_ids[wait_yield_count])
1762+
opened: dict[str, Any] | None = None
1763+
if wait_yield_count < len(wait_opened):
1764+
opened = wait_opened[wait_yield_count]
1765+
opened_id = opened.get("condition_wait_id")
1766+
if isinstance(opened_id, str):
1767+
resolution = wait_resolutions.get(opened_id)
1768+
opened_key = opened.get("condition_key")
1769+
if isinstance(opened_key, str) and opened_key != (cmd.condition_key or ""):
1770+
return _state([FailWorkflow(
1771+
message=(
1772+
"wait_condition key changed during replay: "
1773+
f"history has {opened_key!r}, workflow yielded "
1774+
f"{cmd.condition_key!r}"
1775+
),
1776+
exception_type="NonDeterministicWaitCondition",
1777+
)])
1778+
opened_fingerprint = opened.get("condition_definition_fingerprint")
1779+
if (
1780+
isinstance(opened_fingerprint, str)
1781+
and cmd.condition_definition_fingerprint != opened_fingerprint
1782+
):
1783+
return _state([FailWorkflow(
1784+
message=(
1785+
"wait_condition predicate fingerprint changed during replay: "
1786+
f"history has {opened_fingerprint!r}, workflow yielded "
1787+
f"{cmd.condition_definition_fingerprint!r}"
1788+
),
1789+
exception_type="NonDeterministicWaitCondition",
1790+
)])
17391791
if resolution == "timed_out":
17401792
next_value = False
17411793
wait_yield_count += 1

tests/test_wait_condition.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from durable_workflow import serializer, workflow
44
from durable_workflow.workflow import (
55
CompleteWorkflow,
6+
FailWorkflow,
67
WaitCondition,
78
WorkflowContext,
89
replay,
@@ -67,6 +68,8 @@ def test_wait_condition_returns_dataclass_with_predicate_and_key(self) -> None:
6768
assert isinstance(cmd, WaitCondition)
6869
assert cmd.condition_key == "k"
6970
assert cmd.timeout_seconds == 5
71+
assert cmd.condition_definition_fingerprint is not None
72+
assert cmd.condition_definition_fingerprint.startswith("sha256:")
7073
assert callable(cmd.predicate)
7174
assert cmd.predicate() is False
7275

@@ -95,6 +98,7 @@ def test_emits_open_condition_wait_with_optional_fields(self) -> None:
9598
cmd = WaitCondition(
9699
predicate=lambda: True,
97100
condition_key="order",
101+
condition_definition_fingerprint="sha256:condition",
98102
timeout_seconds=60,
99103
)
100104

@@ -103,6 +107,7 @@ def test_emits_open_condition_wait_with_optional_fields(self) -> None:
103107
assert server_cmd == {
104108
"type": "open_condition_wait",
105109
"condition_key": "order",
110+
"condition_definition_fingerprint": "sha256:condition",
106111
"timeout_seconds": 60,
107112
}
108113

@@ -196,6 +201,70 @@ def test_open_with_no_resolution_and_predicate_false_re_emits_wait_condition(sel
196201
assert len(outcome.commands) == 1
197202
assert isinstance(outcome.commands[0], WaitCondition)
198203

204+
def test_replay_rejects_changed_condition_key(self) -> None:
    """Replay fails the workflow when history recorded a different wait key."""
    # History opened the wait under "old-key" and carries no fingerprint, so
    # only the key comparison can trigger the failure.
    opened_payload = {"condition_wait_id": "wait-1", "condition_key": "old-key"}
    opened_event = {"event_type": "ConditionWaitOpened", "payload": opened_payload}

    result = replay(WaitUntilApproved, [opened_event], [])

    emitted = result.commands
    assert len(emitted) == 1
    assert not isinstance(emitted[0], CompleteWorkflow)
    failure = emitted[0]
    assert isinstance(failure, FailWorkflow)
    assert failure.exception_type == "NonDeterministicWaitCondition"
    assert "key changed" in failure.message
220+
221+
def test_replay_rejects_changed_condition_fingerprint(self) -> None:
    """Replay fails the workflow when history recorded a different fingerprint."""
    # A first, history-less replay yields the fingerprint the SDK computes
    # for the workflow's current predicate.
    first_pass = replay(WaitUntilApproved, [], [])
    first_cmd = first_pass.commands[0]
    assert isinstance(first_cmd, WaitCondition)
    current_fingerprint = first_cmd.condition_definition_fingerprint
    assert current_fingerprint is not None

    # Same key as the workflow yields, but a fingerprint that cannot match.
    opened_payload = {
        "condition_wait_id": "wait-1",
        "condition_key": "approved",
        "condition_definition_fingerprint": "sha256:different",
    }
    opened_event = {"event_type": "ConditionWaitOpened", "payload": opened_payload}

    result = replay(WaitUntilApproved, [opened_event], [])

    assert len(result.commands) == 1
    failure = result.commands[0]
    assert isinstance(failure, FailWorkflow)
    assert failure.exception_type == "NonDeterministicWaitCondition"
    assert "predicate fingerprint changed" in failure.message
    assert current_fingerprint in failure.message
245+
246+
def test_replay_accepts_matching_condition_fingerprint(self) -> None:
    """Replay completes normally when history carries the same fingerprint."""
    initial = replay(WaitUntilApproved, [], [])
    assert isinstance(initial.commands[0], WaitCondition)
    fingerprint = initial.commands[0].condition_definition_fingerprint
    # Without this guard a missing fingerprint would put ``None`` in history,
    # replay would skip the comparison, and the test would pass vacuously.
    assert fingerprint is not None
    history = [
        {
            "event_type": "ConditionWaitOpened",
            "payload": {
                "condition_wait_id": "wait-1",
                "condition_key": "approved",
                "condition_definition_fingerprint": fingerprint,
            },
        },
        _signal_received_event("approve", []),
    ]

    outcome = replay(WaitUntilApproved, history, [])

    assert len(outcome.commands) == 1
    assert isinstance(outcome.commands[0], CompleteWorkflow)
    assert outcome.commands[0].result == "approved"
267+
199268
def test_condition_timeout_does_not_pollute_start_timer_resolved_results(self) -> None:
200269
@workflow.defn(name="wait-then-sleep")
201270
class WaitThenSleep:

0 commit comments

Comments
 (0)