Skip to content

Commit 193bf32

Browse files
committed
docs: add docstring to mark_as_cancelled and improve test coverage
- Add comprehensive docstring to mark_as_cancelled() explaining it's reserved for future cancellation feature implementation - Rename test_trigger_ttl.py to test_trigger_cron.py to match module name - Expand parametrized tests: test all 3 mark functions with 3 retention periods (9 combinations) for consistent coverage - Add 9 new tests covering all trigger_cron.py functions: - get_due_triggers (with/without triggers) - call_trigger_graph - create_next_triggers (success, DuplicateKeyError, other exceptions) - handle_trigger (success and failure paths) - trigger_cron orchestration - Total: 18 comprehensive tests with timezone-aware datetime assertions Signed-off-by: Sparsh <sparsh.raj30@gmail.com>
1 parent 16434a6 commit 193bf32

3 files changed

Lines changed: 303 additions & 85 deletions

File tree

state-manager/app/tasks/trigger_cron.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ async def mark_as_failed(trigger: DatabaseTriggers, retention_days: int):
4646
)
4747

4848
async def mark_as_cancelled(trigger: DatabaseTriggers, retention_days: int):
49+
"""
50+
Mark a trigger as CANCELLED and set TTL for cleanup.
51+
52+
Note: This function is reserved for future cancellation feature implementation.
53+
Currently, there is no production code path that cancels triggers, but this
54+
function ensures complete terminal state handling (TRIGGERED, FAILED, CANCELLED)
55+
with consistent TTL behavior.
56+
"""
4957
expires_at = datetime.now(timezone.utc) + timedelta(days=retention_days)
5058

5159
await DatabaseTriggers.get_pymongo_collection().update_one(
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
"""
2+
Tests for trigger TTL (Time To Live) expiration logic.
3+
Verifies that completed/failed triggers are properly marked for cleanup.
4+
"""
5+
import pytest
6+
from unittest.mock import MagicMock, AsyncMock, patch
7+
from datetime import datetime, timedelta, timezone
8+
from pymongo.errors import DuplicateKeyError
9+
10+
from app.tasks.trigger_cron import (
11+
mark_as_triggered,
12+
mark_as_failed,
13+
mark_as_cancelled,
14+
get_due_triggers,
15+
call_trigger_graph,
16+
create_next_triggers,
17+
handle_trigger,
18+
trigger_cron
19+
)
20+
from app.models.db.trigger import DatabaseTriggers
21+
from app.models.trigger_models import TriggerStatusEnum
22+
23+
24+
@pytest.mark.asyncio
25+
@pytest.mark.parametrize("mark_function,expected_status", [
26+
(mark_as_triggered, TriggerStatusEnum.TRIGGERED),
27+
(mark_as_failed, TriggerStatusEnum.FAILED),
28+
(mark_as_cancelled, TriggerStatusEnum.CANCELLED),
29+
])
30+
async def test_mark_trigger_sets_expires_at(mark_function, expected_status):
31+
"""Test that marking a trigger sets the expires_at field correctly"""
32+
# Create a mock trigger
33+
trigger = MagicMock(spec=DatabaseTriggers)
34+
trigger.id = "test_trigger_id"
35+
36+
# Mock the database update
37+
with patch.object(DatabaseTriggers, 'get_pymongo_collection') as mock_collection:
38+
mock_collection.return_value.update_one = AsyncMock()
39+
40+
# Call the function with retention_days parameter
41+
await mark_function(trigger, retention_days=30)
42+
43+
# Verify update_one was called
44+
assert mock_collection.return_value.update_one.called
45+
call_args = mock_collection.return_value.update_one.call_args
46+
47+
# Verify the filter (first argument)
48+
assert call_args[0][0] == {"_id": trigger.id}
49+
50+
# Verify the update includes both status and expires_at
51+
update_dict = call_args[0][1]["$set"]
52+
assert update_dict["trigger_status"] == expected_status
53+
assert "expires_at" in update_dict
54+
55+
# Verify expires_at is approximately 30 days from now (UTC)
56+
expires_at = update_dict["expires_at"]
57+
expected_expiry = datetime.now(timezone.utc) + timedelta(days=30)
58+
time_diff = abs((expires_at - expected_expiry).total_seconds())
59+
assert time_diff < 2 # Within 2 seconds tolerance
60+
61+
# Verify expires_at is timezone-aware UTC
62+
assert expires_at.tzinfo is not None
63+
assert expires_at.tzinfo == timezone.utc
64+
65+
66+
@pytest.mark.asyncio
67+
@pytest.mark.parametrize("mark_function,retention_days", [
68+
(mark_as_triggered, 7),
69+
(mark_as_triggered, 14),
70+
(mark_as_triggered, 21),
71+
(mark_as_failed, 7),
72+
(mark_as_failed, 14),
73+
(mark_as_failed, 21),
74+
(mark_as_cancelled, 7),
75+
(mark_as_cancelled, 14),
76+
(mark_as_cancelled, 21),
77+
])
78+
async def test_mark_trigger_uses_custom_retention_period(mark_function, retention_days):
79+
"""Test that custom retention period is respected across all mark functions"""
80+
# Create a mock trigger
81+
trigger = MagicMock(spec=DatabaseTriggers)
82+
trigger.id = "test_trigger_id"
83+
84+
# Mock the database update
85+
with patch.object(DatabaseTriggers, 'get_pymongo_collection') as mock_collection:
86+
mock_collection.return_value.update_one = AsyncMock()
87+
88+
# Call the function with custom retention period
89+
await mark_function(trigger, retention_days=retention_days)
90+
91+
# Verify expires_at is approximately retention_days from now (UTC)
92+
call_args = mock_collection.return_value.update_one.call_args
93+
update_dict = call_args[0][1]["$set"]
94+
expires_at = update_dict["expires_at"]
95+
expected_expiry = datetime.now(timezone.utc) + timedelta(days=retention_days)
96+
time_diff = abs((expires_at - expected_expiry).total_seconds())
97+
assert time_diff < 2 # Within 2 seconds tolerance
98+
99+
# Verify expires_at is timezone-aware UTC
100+
assert expires_at.tzinfo is not None
101+
assert expires_at.tzinfo == timezone.utc
102+
103+
@pytest.mark.asyncio
104+
async def test_get_due_triggers_returns_trigger():
105+
"""Test get_due_triggers returns a PENDING trigger"""
106+
cron_time = datetime.now(timezone.utc)
107+
108+
with patch.object(DatabaseTriggers, 'get_pymongo_collection') as mock_collection:
109+
with patch.object(DatabaseTriggers, '__init__', return_value=None):
110+
mock_collection.return_value.find_one_and_update = AsyncMock(return_value={"_id": "trigger_id"})
111+
112+
result = await get_due_triggers(cron_time)
113+
114+
# Verify the query
115+
call_args = mock_collection.return_value.find_one_and_update.call_args
116+
assert call_args[0][0] == {
117+
"trigger_time": {"$lte": cron_time},
118+
"trigger_status": TriggerStatusEnum.PENDING
119+
}
120+
assert call_args[0][1] == {"$set": {"trigger_status": TriggerStatusEnum.TRIGGERING}}
121+
122+
# Verify result is not None (data was returned)
123+
assert result is not None
124+
125+
126+
@pytest.mark.asyncio
127+
async def test_get_due_triggers_returns_none_when_no_triggers():
128+
"""Test get_due_triggers returns None when no triggers are due"""
129+
cron_time = datetime.now(timezone.utc)
130+
131+
with patch.object(DatabaseTriggers, 'get_pymongo_collection') as mock_collection:
132+
mock_collection.return_value.find_one_and_update = AsyncMock(return_value=None)
133+
134+
result = await get_due_triggers(cron_time)
135+
136+
assert result is None
137+
138+
139+
@pytest.mark.asyncio
140+
async def test_call_trigger_graph():
141+
"""Test call_trigger_graph calls trigger_graph controller"""
142+
trigger = MagicMock(spec=DatabaseTriggers)
143+
trigger.namespace = "test_ns"
144+
trigger.graph_name = "test_graph"
145+
146+
with patch('app.tasks.trigger_cron.trigger_graph') as mock_trigger_graph:
147+
mock_trigger_graph.return_value = AsyncMock()
148+
149+
await call_trigger_graph(trigger)
150+
151+
# Verify trigger_graph was called with correct parameters
152+
mock_trigger_graph.assert_called_once()
153+
call_kwargs = mock_trigger_graph.call_args.kwargs
154+
assert call_kwargs['namespace_name'] == "test_ns"
155+
assert call_kwargs['graph_name'] == "test_graph"
156+
assert 'body' in call_kwargs
157+
assert 'x_exosphere_request_id' in call_kwargs
158+
159+
160+
@pytest.mark.asyncio
161+
async def test_create_next_triggers_creates_future_trigger():
162+
"""Test create_next_triggers creates next trigger in the future"""
163+
cron_time = datetime.now(timezone.utc)
164+
trigger = MagicMock(spec=DatabaseTriggers)
165+
trigger.expression = "0 9 * * *"
166+
trigger.trigger_time = cron_time - timedelta(days=1)
167+
trigger.graph_name = "test_graph"
168+
trigger.namespace = "test_ns"
169+
170+
with patch('app.tasks.trigger_cron.DatabaseTriggers') as MockDatabaseTriggers:
171+
mock_instance = MagicMock()
172+
mock_instance.insert = AsyncMock()
173+
MockDatabaseTriggers.return_value = mock_instance
174+
175+
await create_next_triggers(trigger, cron_time)
176+
177+
# Verify at least one trigger was created
178+
assert MockDatabaseTriggers.called
179+
assert mock_instance.insert.called
180+
181+
182+
@pytest.mark.asyncio
183+
async def test_create_next_triggers_handles_duplicate_key_error():
184+
"""Test create_next_triggers handles DuplicateKeyError gracefully"""
185+
cron_time = datetime.now(timezone.utc)
186+
trigger = MagicMock(spec=DatabaseTriggers)
187+
trigger.expression = "0 9 * * *"
188+
trigger.trigger_time = cron_time - timedelta(days=1)
189+
trigger.graph_name = "test_graph"
190+
trigger.namespace = "test_ns"
191+
192+
with patch('app.tasks.trigger_cron.DatabaseTriggers') as MockDatabaseTriggers:
193+
mock_instance = MagicMock()
194+
mock_instance.insert = AsyncMock(side_effect=DuplicateKeyError("duplicate"))
195+
MockDatabaseTriggers.return_value = mock_instance
196+
197+
# Should not raise exception
198+
await create_next_triggers(trigger, cron_time)
199+
200+
201+
@pytest.mark.asyncio
202+
async def test_create_next_triggers_raises_on_other_exceptions():
203+
"""Test create_next_triggers raises on non-DuplicateKeyError exceptions"""
204+
cron_time = datetime.now(timezone.utc)
205+
trigger = MagicMock(spec=DatabaseTriggers)
206+
trigger.expression = "0 9 * * *"
207+
trigger.trigger_time = cron_time - timedelta(days=1)
208+
trigger.graph_name = "test_graph"
209+
trigger.namespace = "test_ns"
210+
211+
with patch('app.tasks.trigger_cron.DatabaseTriggers') as MockDatabaseTriggers:
212+
mock_instance = MagicMock()
213+
mock_instance.insert = AsyncMock(side_effect=ValueError("test error"))
214+
MockDatabaseTriggers.return_value = mock_instance
215+
216+
with pytest.raises(ValueError, match="test error"):
217+
await create_next_triggers(trigger, cron_time)
218+
219+
220+
@pytest.mark.asyncio
221+
async def test_handle_trigger_success_path():
222+
"""Test handle_trigger processes trigger successfully"""
223+
cron_time = datetime.now(timezone.utc)
224+
trigger = MagicMock(spec=DatabaseTriggers)
225+
trigger.id = "trigger_id"
226+
trigger.expression = "0 9 * * *"
227+
trigger.trigger_time = cron_time - timedelta(days=1)
228+
trigger.graph_name = "test_graph"
229+
trigger.namespace = "test_ns"
230+
231+
with patch('app.tasks.trigger_cron.get_due_triggers') as mock_get_due:
232+
with patch('app.tasks.trigger_cron.call_trigger_graph') as mock_call:
233+
with patch('app.tasks.trigger_cron.mark_as_triggered') as mock_mark_triggered:
234+
with patch('app.tasks.trigger_cron.create_next_triggers') as mock_create_next:
235+
# Return trigger once, then None to stop loop
236+
mock_get_due.side_effect = [trigger, None]
237+
mock_call.return_value = AsyncMock()
238+
mock_mark_triggered.return_value = AsyncMock()
239+
mock_create_next.return_value = AsyncMock()
240+
241+
await handle_trigger(cron_time, retention_days=30)
242+
243+
# Verify all functions were called
244+
assert mock_call.called
245+
assert mock_mark_triggered.called
246+
assert mock_create_next.called
247+
248+
249+
@pytest.mark.asyncio
250+
async def test_handle_trigger_failure_path():
251+
"""Test handle_trigger marks trigger as failed on exception"""
252+
cron_time = datetime.now(timezone.utc)
253+
trigger = MagicMock(spec=DatabaseTriggers)
254+
trigger.id = "trigger_id"
255+
trigger.expression = "0 9 * * *"
256+
trigger.trigger_time = cron_time - timedelta(days=1)
257+
trigger.graph_name = "test_graph"
258+
trigger.namespace = "test_ns"
259+
260+
with patch('app.tasks.trigger_cron.get_due_triggers') as mock_get_due:
261+
with patch('app.tasks.trigger_cron.call_trigger_graph') as mock_call:
262+
with patch('app.tasks.trigger_cron.mark_as_failed') as mock_mark_failed:
263+
with patch('app.tasks.trigger_cron.create_next_triggers') as mock_create_next:
264+
# Return trigger once, then None
265+
mock_get_due.side_effect = [trigger, None]
266+
mock_call.side_effect = Exception("Trigger failed")
267+
mock_mark_failed.return_value = AsyncMock()
268+
mock_create_next.return_value = AsyncMock()
269+
270+
await handle_trigger(cron_time, retention_days=30)
271+
272+
# Verify mark_as_failed was called
273+
mock_mark_failed.assert_called_once_with(trigger, 30)
274+
# Verify create_next_triggers was still called (finally block)
275+
assert mock_create_next.called
276+
277+
278+
@pytest.mark.asyncio
279+
async def test_trigger_cron():
280+
"""Test trigger_cron orchestrates handle_trigger with settings"""
281+
with patch('app.tasks.trigger_cron.get_settings') as mock_get_settings:
282+
with patch('app.tasks.trigger_cron.handle_trigger') as mock_handle:
283+
mock_settings = MagicMock()
284+
mock_settings.trigger_retention_days = 30
285+
mock_settings.trigger_workers = 2
286+
mock_get_settings.return_value = mock_settings
287+
mock_handle.return_value = AsyncMock()
288+
289+
await trigger_cron()
290+
291+
# Verify handle_trigger was called correct number of times
292+
assert mock_handle.call_count == 2
293+
# Verify retention_days parameter was passed
294+
for call in mock_handle.call_args_list:
295+
assert call[0][1] == 30 # retention_days

state-manager/tests/unit/tasks/test_trigger_ttl.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

0 commit comments

Comments
 (0)