Skip to content

Commit 9321ed4

Browse files
committed
feat(procrastinate): worker-side task wrapper with sync/async support
1 parent 96f857e commit 9321ed4

2 files changed

Lines changed: 200 additions & 0 deletions

File tree

taskbadger/procrastinate.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,15 @@
1111

1212
from __future__ import annotations
1313

14+
import functools
15+
import inspect
1416
import logging
1517
from contextvars import ContextVar
1618

1719
from .internal.models import StatusEnum
20+
from .mug import Badger
21+
from .safe_sdk import update_task_safe
22+
from .sdk import DefaultMergeStrategy, get_task
1823

1924
log = logging.getLogger("taskbadger")
2025

@@ -37,3 +42,93 @@
3742
_OPTS_ATTR = "_taskbadger_opts"
3843

3944
_current_tb_task_id: ContextVar[str | None] = ContextVar("_current_tb_task_id", default=None)
45+
46+
47+
def _instrument_task(task, system=None, manual=False, opts=None):
48+
"""Wrap a Procrastinate Task's ``func`` so the worker side updates TaskBadger.
49+
50+
Idempotent: a second call on the same task is a no-op (but ``manual`` and
51+
``opts`` will be merged onto the existing attributes if provided).
52+
"""
53+
if opts is not None:
54+
existing_opts = getattr(task, _OPTS_ATTR, {}) or {}
55+
merged = {**existing_opts, **opts}
56+
setattr(task, _OPTS_ATTR, merged)
57+
elif not hasattr(task, _OPTS_ATTR):
58+
setattr(task, _OPTS_ATTR, {})
59+
60+
if manual:
61+
setattr(task, _MANUAL_ATTR, True)
62+
63+
if getattr(task, _INSTRUMENTED_ATTR, False):
64+
return
65+
66+
original_func = task.func
67+
is_async = inspect.iscoroutinefunction(original_func)
68+
69+
if is_async:
70+
71+
@functools.wraps(original_func)
72+
async def wrapped(*args, **kwargs):
73+
tb_id = kwargs.pop(TB_TASK_ID_KWARG, None)
74+
if tb_id is None:
75+
return await original_func(*args, **kwargs)
76+
token = _current_tb_task_id.set(tb_id)
77+
try:
78+
_update_status(tb_id, StatusEnum.PROCESSING)
79+
try:
80+
result = await original_func(*args, **kwargs)
81+
except Exception as exc:
82+
_update_status(tb_id, StatusEnum.ERROR, exception=exc)
83+
raise
84+
_update_status(tb_id, StatusEnum.SUCCESS)
85+
return result
86+
finally:
87+
_current_tb_task_id.reset(token)
88+
else:
89+
90+
@functools.wraps(original_func)
91+
def wrapped(*args, **kwargs):
92+
tb_id = kwargs.pop(TB_TASK_ID_KWARG, None)
93+
if tb_id is None:
94+
return original_func(*args, **kwargs)
95+
token = _current_tb_task_id.set(tb_id)
96+
try:
97+
_update_status(tb_id, StatusEnum.PROCESSING)
98+
try:
99+
result = original_func(*args, **kwargs)
100+
except Exception as exc:
101+
_update_status(tb_id, StatusEnum.ERROR, exception=exc)
102+
raise
103+
_update_status(tb_id, StatusEnum.SUCCESS)
104+
return result
105+
finally:
106+
_current_tb_task_id.reset(token)
107+
108+
task.func = wrapped
109+
setattr(task, _INSTRUMENTED_ATTR, True)
110+
setattr(task, "_taskbadger_system", system)
111+
112+
113+
def _update_status(tb_id, status, exception=None):
114+
"""Update the TaskBadger task to ``status``. Skips if already terminal."""
115+
if not Badger.is_configured():
116+
return
117+
118+
if exception is not None or status in TERMINAL_STATES:
119+
try:
120+
current = get_task(tb_id)
121+
except Exception as e:
122+
log.warning("Error fetching task '%s': %s", tb_id, e)
123+
current = None
124+
if current is not None and current.status in TERMINAL_STATES:
125+
return
126+
data = None
127+
if exception is not None and current is not None:
128+
data = DefaultMergeStrategy().merge(current.data, {"exception": str(exception)})
129+
if data:
130+
update_task_safe(tb_id, status=status, data=data)
131+
else:
132+
update_task_safe(tb_id, status=status)
133+
else:
134+
update_task_safe(tb_id, status=status)

tests/test_procrastinate.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import asyncio
2+
import logging
3+
from unittest import mock
4+
5+
import procrastinate
6+
import pytest
7+
from procrastinate import testing
8+
9+
from taskbadger import StatusEnum
10+
from taskbadger.procrastinate import TB_TASK_ID_KWARG, _instrument_task
11+
from tests.utils import task_for_test
12+
13+
14+
@pytest.fixture(autouse=True)
15+
def _check_log_errors(caplog):
16+
yield
17+
errors = [r.getMessage() for r in caplog.get_records("call") if r.levelno == logging.ERROR]
18+
if errors:
19+
pytest.fail(f"log errors during tests: {errors}")
20+
21+
22+
@pytest.fixture
23+
def app():
24+
in_memory = testing.InMemoryConnector()
25+
app = procrastinate.App(connector=in_memory)
26+
with app.open():
27+
yield app
28+
29+
30+
@pytest.mark.usefixtures("_bind_settings")
31+
def test_worker_updates_task_sync(app):
32+
@app.task(name="add")
33+
def add(a, b):
34+
return a + b
35+
36+
_instrument_task(add, system=None, manual=True)
37+
38+
with (
39+
mock.patch("taskbadger.procrastinate.update_task_safe") as update,
40+
mock.patch("taskbadger.procrastinate.get_task") as get,
41+
):
42+
get.return_value = task_for_test(status=StatusEnum.PROCESSING)
43+
add.func(a=2, b=3, **{TB_TASK_ID_KWARG: "tb-123"})
44+
45+
statuses = [call.kwargs["status"] for call in update.call_args_list]
46+
assert statuses == [StatusEnum.PROCESSING, StatusEnum.SUCCESS]
47+
# The reserved key must not appear in the calls (it's stripped before user fn)
48+
assert all(TB_TASK_ID_KWARG not in c.kwargs for c in update.call_args_list)
49+
50+
51+
@pytest.mark.usefixtures("_bind_settings")
52+
def test_worker_updates_task_async(app):
53+
@app.task(name="add_async")
54+
async def add_async(a, b):
55+
return a + b
56+
57+
_instrument_task(add_async, system=None, manual=True)
58+
59+
with (
60+
mock.patch("taskbadger.procrastinate.update_task_safe") as update,
61+
mock.patch("taskbadger.procrastinate.get_task") as get,
62+
):
63+
get.return_value = task_for_test(status=StatusEnum.PROCESSING)
64+
result = asyncio.run(add_async.func(a=2, b=3, **{TB_TASK_ID_KWARG: "tb-456"}))
65+
66+
assert result == 5
67+
statuses = [call.kwargs["status"] for call in update.call_args_list]
68+
assert statuses == [StatusEnum.PROCESSING, StatusEnum.SUCCESS]
69+
70+
71+
@pytest.mark.usefixtures("_bind_settings")
72+
def test_worker_marks_error(app):
73+
@app.task(name="boom")
74+
def boom():
75+
raise ValueError("nope")
76+
77+
_instrument_task(boom, system=None, manual=True)
78+
79+
with (
80+
mock.patch("taskbadger.procrastinate.update_task_safe") as update,
81+
mock.patch("taskbadger.procrastinate.get_task") as get,
82+
):
83+
get.return_value = task_for_test(status=StatusEnum.PROCESSING, data={"x": 1})
84+
with pytest.raises(ValueError, match="nope"):
85+
boom.func(**{TB_TASK_ID_KWARG: "tb-789"})
86+
87+
statuses = [call.kwargs["status"] for call in update.call_args_list]
88+
assert statuses == [StatusEnum.PROCESSING, StatusEnum.ERROR]
89+
err_call = update.call_args_list[-1]
90+
assert err_call.kwargs["data"] == {"x": 1, "exception": "nope"}
91+
92+
93+
@pytest.mark.usefixtures("_bind_settings")
94+
def test_worker_no_id_runs_clean(app):
95+
@app.task(name="add2")
96+
def add2(a, b):
97+
return a + b
98+
99+
_instrument_task(add2, system=None, manual=True)
100+
101+
with mock.patch("taskbadger.procrastinate.update_task_safe") as update:
102+
result = add2.func(a=1, b=2)
103+
104+
assert result == 3
105+
update.assert_not_called()

0 commit comments

Comments
 (0)