Skip to content

Commit 5a9893a

Browse files
committed
docs: add docstring to mark_as_cancelled and improve test coverage
- Add comprehensive docstring to mark_as_cancelled() explaining it's reserved for future cancellation feature implementation - Rename test_trigger_ttl.py to test_trigger_cron.py to match module name - Expand parametrized tests: test all 3 mark functions with 3 retention periods (9 combinations) for consistent coverage - Add 9 new tests covering all trigger_cron.py functions: - get_due_triggers (with/without triggers) - call_trigger_graph - create_next_triggers (success, DuplicateKeyError, other exceptions) - handle_trigger (success and failure paths) - trigger_cron orchestration - Total: 18 comprehensive tests with timezone-aware datetime assertions Signed-off-by: Sparsh <sparsh.raj30@gmail.com>
1 parent 16434a6 commit 5a9893a

6 files changed

Lines changed: 311 additions & 108 deletions

File tree

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
version = "0.0.3b2"
1+
version = "0.0.3b1"

state-manager/app/config/settings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class Settings(BaseModel):
1313
state_manager_secret: str = Field(..., description="Secret key for API authentication")
1414
secrets_encryption_key: str = Field(..., description="Key for encrypting secrets")
1515
trigger_workers: int = Field(default=1, description="Number of workers to run the trigger cron")
16-
trigger_retention_days: int = Field(default=30, description="Number of days to retain completed/failed triggers before cleanup")
16+
trigger_retention_hours: int = Field(default=24, description="Number of hours to retain completed/failed triggers before cleanup")
1717

1818
@classmethod
1919
def from_env(cls) -> "Settings":
@@ -23,7 +23,7 @@ def from_env(cls) -> "Settings":
2323
state_manager_secret=os.getenv("STATE_MANAGER_SECRET"), # type: ignore
2424
secrets_encryption_key=os.getenv("SECRETS_ENCRYPTION_KEY"), # type: ignore
2525
trigger_workers=int(os.getenv("TRIGGER_WORKERS", 1)), # type: ignore
26-
trigger_retention_days=int(os.getenv("TRIGGER_RETENTION_DAYS", 30)) # type: ignore
26+
trigger_retention_hours=int(os.getenv("TRIGGER_RETENTION_HOURS", 24)) # type: ignore
2727
)
2828

2929

state-manager/app/models/db/trigger.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ class Settings:
3939
("expires_at", 1),
4040
],
4141
name="ttl_expires_at",
42-
expireAfterSeconds=0 # Delete immediately when expires_at is reached
42+
expireAfterSeconds=0, # Delete immediately when expires_at is reached
43+
partialFilterExpression={
44+
"trigger_status": {
45+
"$in": [
46+
TriggerStatusEnum.TRIGGERED,
47+
TriggerStatusEnum.FAILED,
48+
TriggerStatusEnum.CANCELLED
49+
]
50+
}
51+
}
4352
)
4453
]

state-manager/app/tasks/trigger_cron.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ async def call_trigger_graph(trigger: DatabaseTriggers):
3434
x_exosphere_request_id=str(uuid4())
3535
)
3636

37-
async def mark_as_failed(trigger: DatabaseTriggers, retention_days: int):
38-
expires_at = datetime.now(timezone.utc) + timedelta(days=retention_days)
37+
async def mark_as_failed(trigger: DatabaseTriggers, retention_hours: int):
38+
expires_at = datetime.now(timezone.utc) + timedelta(hours=retention_hours)
3939

4040
await DatabaseTriggers.get_pymongo_collection().update_one(
4141
{"_id": trigger.id},
@@ -45,17 +45,6 @@ async def mark_as_failed(trigger: DatabaseTriggers, retention_days: int):
4545
}}
4646
)
4747

48-
async def mark_as_cancelled(trigger: DatabaseTriggers, retention_days: int):
49-
expires_at = datetime.now(timezone.utc) + timedelta(days=retention_days)
50-
51-
await DatabaseTriggers.get_pymongo_collection().update_one(
52-
{"_id": trigger.id},
53-
{"$set": {
54-
"trigger_status": TriggerStatusEnum.CANCELLED,
55-
"expires_at": expires_at
56-
}}
57-
)
58-
5948
async def create_next_triggers(trigger: DatabaseTriggers, cron_time: datetime):
6049
assert trigger.expression is not None
6150
iter = croniter.croniter(trigger.expression, trigger.trigger_time)
@@ -81,8 +70,8 @@ async def create_next_triggers(trigger: DatabaseTriggers, cron_time: datetime):
8170
if next_trigger_time > cron_time:
8271
break
8372

84-
async def mark_as_triggered(trigger: DatabaseTriggers, retention_days: int):
85-
expires_at = datetime.now(timezone.utc) + timedelta(days=retention_days)
73+
async def mark_as_triggered(trigger: DatabaseTriggers, retention_hours: int):
74+
expires_at = datetime.now(timezone.utc) + timedelta(hours=retention_hours)
8675

8776
await DatabaseTriggers.get_pymongo_collection().update_one(
8877
{"_id": trigger.id},
@@ -92,13 +81,13 @@ async def mark_as_triggered(trigger: DatabaseTriggers, retention_days: int):
9281
}}
9382
)
9483

95-
async def handle_trigger(cron_time: datetime, retention_days: int):
84+
async def handle_trigger(cron_time: datetime, retention_hours: int):
9685
while(trigger:= await get_due_triggers(cron_time)):
9786
try:
9887
await call_trigger_graph(trigger)
99-
await mark_as_triggered(trigger, retention_days)
88+
await mark_as_triggered(trigger, retention_hours)
10089
except Exception as e:
101-
await mark_as_failed(trigger, retention_days)
90+
await mark_as_failed(trigger, retention_hours)
10291
logger.error(f"Error calling trigger graph: {e}")
10392
finally:
10493
await create_next_triggers(trigger, cron_time)
@@ -107,4 +96,4 @@ async def trigger_cron():
10796
cron_time = datetime.now()
10897
settings = get_settings()
10998
logger.info(f"starting trigger_cron: {cron_time}")
110-
await asyncio.gather(*[handle_trigger(cron_time, settings.trigger_retention_days) for _ in range(settings.trigger_workers)])
99+
await asyncio.gather(*[handle_trigger(cron_time, settings.trigger_retention_hours) for _ in range(settings.trigger_workers)])
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
"""
2+
Tests for trigger TTL (Time To Live) expiration logic.
3+
Verifies that completed/failed triggers are properly marked for cleanup.
4+
"""
5+
import pytest
6+
from unittest.mock import MagicMock, AsyncMock, patch
7+
from datetime import datetime, timedelta, timezone
8+
from pymongo.errors import DuplicateKeyError
9+
10+
from app.tasks.trigger_cron import (
11+
mark_as_triggered,
12+
mark_as_failed,
13+
get_due_triggers,
14+
call_trigger_graph,
15+
create_next_triggers,
16+
handle_trigger,
17+
trigger_cron
18+
)
19+
from app.models.db.trigger import DatabaseTriggers
20+
from app.models.trigger_models import TriggerStatusEnum
21+
22+
23+
@pytest.mark.asyncio
24+
@pytest.mark.parametrize("mark_function,expected_status", [
25+
(mark_as_triggered, TriggerStatusEnum.TRIGGERED),
26+
(mark_as_failed, TriggerStatusEnum.FAILED),
27+
])
28+
async def test_mark_trigger_sets_expires_at(mark_function, expected_status):
29+
"""Test that marking a trigger sets the expires_at field correctly"""
30+
# Create a mock trigger
31+
trigger = MagicMock(spec=DatabaseTriggers)
32+
trigger.id = "test_trigger_id"
33+
34+
# Mock the database update
35+
with patch.object(DatabaseTriggers, 'get_pymongo_collection') as mock_collection:
36+
mock_collection.return_value.update_one = AsyncMock()
37+
38+
# Call the function with retention_hours parameter
39+
await mark_function(trigger, retention_hours=24)
40+
41+
# Verify update_one was called
42+
assert mock_collection.return_value.update_one.called
43+
call_args = mock_collection.return_value.update_one.call_args
44+
45+
# Verify the filter (first argument)
46+
assert call_args[0][0] == {"_id": trigger.id}
47+
48+
# Verify the update includes both status and expires_at
49+
update_dict = call_args[0][1]["$set"]
50+
assert update_dict["trigger_status"] == expected_status
51+
assert "expires_at" in update_dict
52+
53+
# Verify expires_at is approximately 24 hours from now (UTC)
54+
expires_at = update_dict["expires_at"]
55+
expected_expiry = datetime.now(timezone.utc) + timedelta(hours=24)
56+
time_diff = abs((expires_at - expected_expiry).total_seconds())
57+
assert time_diff < 2 # Within 2 seconds tolerance
58+
59+
# Verify expires_at is timezone-aware UTC
60+
assert expires_at.tzinfo is not None
61+
assert expires_at.tzinfo == timezone.utc
62+
63+
64+
@pytest.mark.asyncio
65+
@pytest.mark.parametrize("mark_function,retention_hours", [
66+
(mark_as_triggered, 12),
67+
(mark_as_triggered, 24),
68+
(mark_as_triggered, 48),
69+
(mark_as_failed, 12),
70+
(mark_as_failed, 24),
71+
(mark_as_failed, 48),
72+
])
73+
async def test_mark_trigger_uses_custom_retention_period(mark_function, retention_hours):
74+
"""Test that custom retention period is respected across all mark functions"""
75+
# Create a mock trigger
76+
trigger = MagicMock(spec=DatabaseTriggers)
77+
trigger.id = "test_trigger_id"
78+
79+
# Mock the database update
80+
with patch.object(DatabaseTriggers, 'get_pymongo_collection') as mock_collection:
81+
mock_collection.return_value.update_one = AsyncMock()
82+
83+
# Call the function with custom retention period
84+
await mark_function(trigger, retention_hours=retention_hours)
85+
86+
# Verify expires_at is approximately retention_hours from now (UTC)
87+
call_args = mock_collection.return_value.update_one.call_args
88+
update_dict = call_args[0][1]["$set"]
89+
expires_at = update_dict["expires_at"]
90+
expected_expiry = datetime.now(timezone.utc) + timedelta(hours=retention_hours)
91+
time_diff = abs((expires_at - expected_expiry).total_seconds())
92+
assert time_diff < 2 # Within 2 seconds tolerance
93+
94+
# Verify expires_at is timezone-aware UTC
95+
assert expires_at.tzinfo is not None
96+
assert expires_at.tzinfo == timezone.utc
97+
98+
@pytest.mark.asyncio
99+
async def test_get_due_triggers_returns_trigger():
100+
"""Test get_due_triggers returns a PENDING trigger"""
101+
cron_time = datetime.now(timezone.utc)
102+
103+
with patch.object(DatabaseTriggers, 'get_pymongo_collection') as mock_collection:
104+
with patch.object(DatabaseTriggers, '__init__', return_value=None):
105+
mock_collection.return_value.find_one_and_update = AsyncMock(return_value={"_id": "trigger_id"})
106+
107+
result = await get_due_triggers(cron_time)
108+
109+
# Verify the query
110+
call_args = mock_collection.return_value.find_one_and_update.call_args
111+
assert call_args[0][0] == {
112+
"trigger_time": {"$lte": cron_time},
113+
"trigger_status": TriggerStatusEnum.PENDING
114+
}
115+
assert call_args[0][1] == {"$set": {"trigger_status": TriggerStatusEnum.TRIGGERING}}
116+
117+
# Verify result is not None (data was returned)
118+
assert result is not None
119+
120+
121+
@pytest.mark.asyncio
122+
async def test_get_due_triggers_returns_none_when_no_triggers():
123+
"""Test get_due_triggers returns None when no triggers are due"""
124+
cron_time = datetime.now(timezone.utc)
125+
126+
with patch.object(DatabaseTriggers, 'get_pymongo_collection') as mock_collection:
127+
mock_collection.return_value.find_one_and_update = AsyncMock(return_value=None)
128+
129+
result = await get_due_triggers(cron_time)
130+
131+
assert result is None
132+
133+
134+
@pytest.mark.asyncio
135+
async def test_call_trigger_graph():
136+
"""Test call_trigger_graph calls trigger_graph controller"""
137+
trigger = MagicMock(spec=DatabaseTriggers)
138+
trigger.namespace = "test_ns"
139+
trigger.graph_name = "test_graph"
140+
141+
with patch('app.tasks.trigger_cron.trigger_graph') as mock_trigger_graph:
142+
mock_trigger_graph.return_value = AsyncMock()
143+
144+
await call_trigger_graph(trigger)
145+
146+
# Verify trigger_graph was called with correct parameters
147+
mock_trigger_graph.assert_called_once()
148+
call_kwargs = mock_trigger_graph.call_args.kwargs
149+
assert call_kwargs['namespace_name'] == "test_ns"
150+
assert call_kwargs['graph_name'] == "test_graph"
151+
assert 'body' in call_kwargs
152+
assert 'x_exosphere_request_id' in call_kwargs
153+
154+
155+
@pytest.mark.asyncio
156+
async def test_create_next_triggers_creates_future_trigger():
157+
"""Test create_next_triggers creates next trigger in the future"""
158+
cron_time = datetime.now(timezone.utc)
159+
trigger = MagicMock(spec=DatabaseTriggers)
160+
trigger.expression = "0 9 * * *"
161+
trigger.trigger_time = cron_time - timedelta(days=1)
162+
trigger.graph_name = "test_graph"
163+
trigger.namespace = "test_ns"
164+
165+
with patch('app.tasks.trigger_cron.DatabaseTriggers') as MockDatabaseTriggers:
166+
mock_instance = MagicMock()
167+
mock_instance.insert = AsyncMock()
168+
MockDatabaseTriggers.return_value = mock_instance
169+
170+
await create_next_triggers(trigger, cron_time)
171+
172+
# Verify at least one trigger was created
173+
assert MockDatabaseTriggers.called
174+
assert mock_instance.insert.called
175+
176+
177+
@pytest.mark.asyncio
178+
async def test_create_next_triggers_handles_duplicate_key_error():
179+
"""Test create_next_triggers handles DuplicateKeyError gracefully"""
180+
cron_time = datetime.now(timezone.utc)
181+
trigger = MagicMock(spec=DatabaseTriggers)
182+
trigger.expression = "0 9 * * *"
183+
trigger.trigger_time = cron_time - timedelta(days=1)
184+
trigger.graph_name = "test_graph"
185+
trigger.namespace = "test_ns"
186+
187+
with patch('app.tasks.trigger_cron.DatabaseTriggers') as MockDatabaseTriggers:
188+
mock_instance = MagicMock()
189+
mock_instance.insert = AsyncMock(side_effect=DuplicateKeyError("duplicate"))
190+
MockDatabaseTriggers.return_value = mock_instance
191+
192+
# Should not raise exception
193+
await create_next_triggers(trigger, cron_time)
194+
195+
196+
@pytest.mark.asyncio
197+
async def test_create_next_triggers_raises_on_other_exceptions():
198+
"""Test create_next_triggers raises on non-DuplicateKeyError exceptions"""
199+
cron_time = datetime.now(timezone.utc)
200+
trigger = MagicMock(spec=DatabaseTriggers)
201+
trigger.expression = "0 9 * * *"
202+
trigger.trigger_time = cron_time - timedelta(days=1)
203+
trigger.graph_name = "test_graph"
204+
trigger.namespace = "test_ns"
205+
206+
with patch('app.tasks.trigger_cron.DatabaseTriggers') as MockDatabaseTriggers:
207+
mock_instance = MagicMock()
208+
mock_instance.insert = AsyncMock(side_effect=ValueError("test error"))
209+
MockDatabaseTriggers.return_value = mock_instance
210+
211+
with pytest.raises(ValueError, match="test error"):
212+
await create_next_triggers(trigger, cron_time)
213+
214+
215+
@pytest.mark.asyncio
216+
async def test_handle_trigger_success_path():
217+
"""Test handle_trigger processes trigger successfully"""
218+
cron_time = datetime.now(timezone.utc)
219+
trigger = MagicMock(spec=DatabaseTriggers)
220+
trigger.id = "trigger_id"
221+
trigger.expression = "0 9 * * *"
222+
trigger.trigger_time = cron_time - timedelta(days=1)
223+
trigger.graph_name = "test_graph"
224+
trigger.namespace = "test_ns"
225+
226+
with patch('app.tasks.trigger_cron.get_due_triggers') as mock_get_due:
227+
with patch('app.tasks.trigger_cron.call_trigger_graph') as mock_call:
228+
with patch('app.tasks.trigger_cron.mark_as_triggered') as mock_mark_triggered:
229+
with patch('app.tasks.trigger_cron.create_next_triggers') as mock_create_next:
230+
# Return trigger once, then None to stop loop
231+
mock_get_due.side_effect = [trigger, None]
232+
mock_call.return_value = AsyncMock()
233+
mock_mark_triggered.return_value = AsyncMock()
234+
mock_create_next.return_value = AsyncMock()
235+
236+
await handle_trigger(cron_time, retention_hours=24)
237+
238+
# Verify all functions were called
239+
assert mock_call.called
240+
assert mock_mark_triggered.called
241+
assert mock_create_next.called
242+
243+
244+
@pytest.mark.asyncio
245+
async def test_handle_trigger_failure_path():
246+
"""Test handle_trigger marks trigger as failed on exception"""
247+
cron_time = datetime.now(timezone.utc)
248+
trigger = MagicMock(spec=DatabaseTriggers)
249+
trigger.id = "trigger_id"
250+
trigger.expression = "0 9 * * *"
251+
trigger.trigger_time = cron_time - timedelta(days=1)
252+
trigger.graph_name = "test_graph"
253+
trigger.namespace = "test_ns"
254+
255+
with patch('app.tasks.trigger_cron.get_due_triggers') as mock_get_due:
256+
with patch('app.tasks.trigger_cron.call_trigger_graph') as mock_call:
257+
with patch('app.tasks.trigger_cron.mark_as_failed') as mock_mark_failed:
258+
with patch('app.tasks.trigger_cron.create_next_triggers') as mock_create_next:
259+
# Return trigger once, then None
260+
mock_get_due.side_effect = [trigger, None]
261+
mock_call.side_effect = Exception("Trigger failed")
262+
mock_mark_failed.return_value = AsyncMock()
263+
mock_create_next.return_value = AsyncMock()
264+
265+
await handle_trigger(cron_time, retention_hours=24)
266+
267+
# Verify mark_as_failed was called
268+
mock_mark_failed.assert_called_once_with(trigger, 24)
269+
# Verify create_next_triggers was still called (finally block)
270+
assert mock_create_next.called
271+
272+
273+
@pytest.mark.asyncio
274+
async def test_trigger_cron():
275+
"""Test trigger_cron orchestrates handle_trigger with settings"""
276+
with patch('app.tasks.trigger_cron.get_settings') as mock_get_settings:
277+
with patch('app.tasks.trigger_cron.handle_trigger') as mock_handle:
278+
mock_settings = MagicMock()
279+
mock_settings.trigger_retention_hours = 24
280+
mock_settings.trigger_workers = 2
281+
mock_get_settings.return_value = mock_settings
282+
mock_handle.return_value = AsyncMock()
283+
284+
await trigger_cron()
285+
286+
# Verify handle_trigger was called correct number of times
287+
assert mock_handle.call_count == 2
288+
# Verify retention_hours parameter was passed
289+
for call in mock_handle.call_args_list:
290+
assert call[0][1] == 24 # retention_hours

0 commit comments

Comments
 (0)