Skip to content

Commit 2dc4e62

Browse files
committed
Add DTS tests
1 parent 91057ab commit 2dc4e62

File tree

1 file changed

+297
-0
lines changed

1 file changed

+297
-0
lines changed
Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import json
5+
import os
6+
7+
import pytest
8+
9+
from durabletask import client, task
10+
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
11+
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
12+
13+
# NOTE: These tests assume a sidecar process is running. Example command:
# docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest
# Mark every test in this module so the suite can be selected/deselected
# with `pytest -m dts` (they need the live DTS sidecar above).
pytestmark = pytest.mark.dts

# Read the environment variables (fall back to the local emulator defaults).
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
20+
21+
22+
def _get_credential():
23+
"""Returns DefaultAzureCredential if endpoint is https, otherwise None (for emulator)."""
24+
if endpoint.startswith("https://"):
25+
from azure.identity import DefaultAzureCredential
26+
return DefaultAzureCredential()
27+
return None
28+
29+
30+
# ---------------------------------------------------------------------------
31+
# Tests
32+
# ---------------------------------------------------------------------------
33+
34+
35+
def test_rewind_failed_activity():
    """Rewind a failed orchestration whose single activity failed.

    After rewind the activity succeeds and the orchestration completes.
    """
    activity_call_count = 0
    should_fail = True

    def failing_activity(_: task.ActivityContext, input: str) -> str:
        nonlocal activity_call_count
        activity_call_count += 1
        if should_fail:
            raise RuntimeError("Simulated failure")
        return f"Hello, {input}!"

    def orchestrator(ctx: task.OrchestrationContext, input: str):
        result = yield ctx.call_activity(failing_activity, input=input)
        return result

    # Use _get_credential() so the test also runs against an
    # authenticated (https) scheduler, not just the emulator.
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name,
                                    token_credential=_get_credential()) as w:
        w.add_orchestrator(orchestrator)
        w.add_activity(failing_activity)
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name,
                                       token_credential=_get_credential())
        instance_id = c.schedule_new_orchestration(orchestrator, input="World")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        # The orchestration should have failed.
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # Fix the activity so it now succeeds, then rewind.
        should_fail = False
        c.rewind_orchestration(instance_id, reason="retry after fix")

        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("Hello, World!")
        assert state.failure_details is None
        # Activity was called twice (once failed, once succeeded after rewind).
        assert activity_call_count == 2
81+
82+
83+
def test_rewind_preserves_successful_results():
    """When an orchestration has a mix of successful and failed activities,
    rewind should re-execute only the failed activity while the successful
    result is replayed from history."""
    call_tracker: dict[str, int] = {"first": 0, "second": 0}

    def first_activity(_: task.ActivityContext, input: str) -> str:
        call_tracker["first"] += 1
        return f"first:{input}"

    def second_activity(_: task.ActivityContext, input: str) -> str:
        call_tracker["second"] += 1
        # Fail only on the first invocation so the rewound run succeeds.
        if call_tracker["second"] == 1:
            raise RuntimeError("Temporary failure")
        return f"second:{input}"

    def orchestrator(ctx: task.OrchestrationContext, input: str):
        r1 = yield ctx.call_activity(first_activity, input=input)
        r2 = yield ctx.call_activity(second_activity, input=input)
        return [r1, r2]

    # Use _get_credential() so the test also runs against an
    # authenticated (https) scheduler, not just the emulator.
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name,
                                    token_credential=_get_credential()) as w:
        w.add_orchestrator(orchestrator)
        w.add_activity(first_activity)
        w.add_activity(second_activity)
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name,
                                       token_credential=_get_credential())
        instance_id = c.schedule_new_orchestration(orchestrator, input="test")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        # The orchestration should have failed (second_activity fails).
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # Rewind – second_activity will now succeed on retry.
        c.rewind_orchestration(instance_id, reason="retry")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps(["first:test", "second:test"])
        assert state.failure_details is None
        # first_activity should NOT be re-executed – its result is replayed.
        assert call_tracker["first"] == 1
        # second_activity was called twice (once failed, once succeeded).
        assert call_tracker["second"] == 2
132+
133+
134+
def test_rewind_not_found():
    """Rewinding a non-existent instance should raise an RPC error."""
    # Use _get_credential() so the test also runs against an
    # authenticated (https) scheduler, not just the emulator.
    c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                   taskhub=taskhub_name,
                                   token_credential=_get_credential())
    # The exact exception type depends on the gRPC transport, so the
    # assertion stays broad.
    with pytest.raises(Exception):
        c.rewind_orchestration("nonexistent-id")
140+
141+
142+
def test_rewind_non_failed_instance():
    """Rewinding a completed (non-failed) instance should raise an error."""
    def orchestrator(ctx: task.OrchestrationContext, _):
        return "done"

    # Use _get_credential() so the test also runs against an
    # authenticated (https) scheduler, not just the emulator.
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name,
                                    token_credential=_get_credential()) as w:
        w.add_orchestrator(orchestrator)
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name,
                                       token_credential=_get_credential())
        instance_id = c.schedule_new_orchestration(orchestrator)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED

        # Rewind is only valid for FAILED instances; COMPLETED must be
        # rejected by the service.
        with pytest.raises(Exception):
            c.rewind_orchestration(instance_id)
161+
162+
163+
def test_rewind_with_sub_orchestration():
    """Rewind should recursively rewind failed sub-orchestrations."""
    sub_call_count = 0

    def child_activity(_: task.ActivityContext, input: str) -> str:
        nonlocal sub_call_count
        sub_call_count += 1
        # Fail only the first call so the rewound run succeeds.
        if sub_call_count == 1:
            raise RuntimeError("Child failure")
        return f"child:{input}"

    def child_orchestrator(ctx: task.OrchestrationContext, input: str):
        result = yield ctx.call_activity(child_activity, input=input)
        return result

    def parent_orchestrator(ctx: task.OrchestrationContext, input: str):
        result = yield ctx.call_sub_orchestrator(
            child_orchestrator, input=input)
        return f"parent:{result}"

    # Use _get_credential() so the test also runs against an
    # authenticated (https) scheduler, not just the emulator.
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name,
                                    token_credential=_get_credential()) as w:
        w.add_orchestrator(parent_orchestrator)
        w.add_orchestrator(child_orchestrator)
        w.add_activity(child_activity)
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name,
                                       token_credential=_get_credential())
        instance_id = c.schedule_new_orchestration(
            parent_orchestrator, input="data")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        # Parent should fail because child failed.
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # Rewind – child_activity will succeed on retry.
        c.rewind_orchestration(instance_id, reason="sub-orch fix")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("parent:child:data")
        assert sub_call_count == 2
208+
209+
210+
def test_rewind_without_reason():
    """Rewind should work when no reason is provided."""
    call_count = 0

    def flaky_activity(_: task.ActivityContext, _1) -> str:
        nonlocal call_count
        call_count += 1
        # Fail only the first call so the rewound run succeeds.
        if call_count == 1:
            raise RuntimeError("Boom")
        return "ok"

    def orchestrator(ctx: task.OrchestrationContext, _):
        result = yield ctx.call_activity(flaky_activity)
        return result

    # Use _get_credential() so the test also runs against an
    # authenticated (https) scheduler, not just the emulator.
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name,
                                    token_credential=_get_credential()) as w:
        w.add_orchestrator(orchestrator)
        w.add_activity(flaky_activity)
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name,
                                       token_credential=_get_credential())
        instance_id = c.schedule_new_orchestration(orchestrator)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # Rewind without a reason
        c.rewind_orchestration(instance_id)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("ok")
245+
246+
247+
def test_rewind_twice():
    """Rewind the same orchestration twice after it fails a second time.

    The first rewind cleans up the initial failure. The activity then
    fails again. A second rewind should clean up the new failure and
    the orchestration should eventually complete.
    """
    call_count = 0

    def flaky_activity(_: task.ActivityContext, input: str) -> str:
        nonlocal call_count
        call_count += 1
        # Fail on the 1st and 2nd calls; succeed on the 3rd.
        if call_count <= 2:
            raise RuntimeError(f"Failure #{call_count}")
        return f"Hello, {input}!"

    def orchestrator(ctx: task.OrchestrationContext, input: str):
        result = yield ctx.call_activity(flaky_activity, input=input)
        return result

    # Use _get_credential() so the test also runs against an
    # authenticated (https) scheduler, not just the emulator.
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name,
                                    token_credential=_get_credential()) as w:
        w.add_orchestrator(orchestrator)
        w.add_activity(flaky_activity)
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name,
                                       token_credential=_get_credential())
        instance_id = c.schedule_new_orchestration(orchestrator, input="World")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        # First failure.
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # First rewind — activity will fail again (call_count == 2).
        c.rewind_orchestration(instance_id, reason="first rewind")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED

        # Second rewind — activity will succeed (call_count == 3).
        c.rewind_orchestration(instance_id, reason="second rewind")
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("Hello, World!")
        assert call_count == 3

0 commit comments

Comments
 (0)