Skip to content

Commit bc60bc1

Browse files
snopokeclaude
andcommitted
feat(procrastinate): current_task() accessor via ContextVar
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 8a1b778 commit bc60bc1

2 files changed

Lines changed: 76 additions & 1 deletion

File tree

taskbadger/procrastinate.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ def _update_status(tb_id, status, exception=None):
118118
return
119119

120120
if exception is not None or status in TERMINAL_STATES:
121+
# Bypass the cache for the terminal-state check: the user may have
122+
# updated the task to a terminal state via the regular SDK during
123+
# the body, which wouldn't be reflected in our local cache.
124+
_task_cache.unset(tb_id)
121125
current = _safe_get_task(tb_id)
122126
if current is not None and current.status in TERMINAL_STATES:
123127
return
@@ -260,3 +264,15 @@ def wrap(task):
260264
if original_task is None:
261265
return wrap
262266
return wrap(original_task)
267+
268+
269+
def current_task():
270+
"""Return the TaskBadger Task for the currently-running Procrastinate job.
271+
272+
Returns ``None`` outside of a tracked task or if the task can't be fetched.
273+
Result is cached for the lifetime of the worker process via an LRU.
274+
"""
275+
tb_id = _current_tb_task_id.get()
276+
if tb_id is None:
277+
return None
278+
return _safe_get_task(tb_id)

tests/test_procrastinate.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from procrastinate import testing
88

99
from taskbadger import StatusEnum
10-
from taskbadger.procrastinate import TB_TASK_ID_KWARG, _instrument_task, track
10+
from taskbadger.procrastinate import TB_TASK_ID_KWARG, _instrument_task, _task_cache, current_task, track
1111
from tests.utils import task_for_test
1212

1313

@@ -19,6 +19,13 @@ def _check_log_errors(caplog):
1919
pytest.fail(f"log errors during tests: {errors}")
2020

2121

22+
@pytest.fixture(autouse=True)
23+
def _clear_task_cache():
24+
_task_cache.cache.clear()
25+
yield
26+
_task_cache.cache.clear()
27+
28+
2229
@pytest.fixture
2330
def app():
2431
in_memory = testing.InMemoryConnector()
@@ -248,3 +255,55 @@ def bad():
248255

249256
with pytest.raises(TypeError, match="unexpected keyword"):
250257
track(name="x", does_not_exist=True)(bad)
258+
259+
260+
@pytest.mark.usefixtures("_bind_settings")
261+
def test_current_task_inside_body(app):
262+
captured = {}
263+
264+
@track
265+
@app.task(name="capture")
266+
def capture():
267+
captured["task"] = current_task()
268+
269+
tb = task_for_test()
270+
with (
271+
mock.patch("taskbadger.procrastinate.create_task_safe", return_value=tb),
272+
mock.patch("taskbadger.procrastinate.update_task_safe", return_value=tb),
273+
mock.patch("taskbadger.procrastinate.get_task", return_value=tb),
274+
):
275+
capture.defer()
276+
app.run_worker(wait=False, install_signal_handlers=False, listen_notify=False)
277+
278+
assert captured["task"] is not None
279+
assert captured["task"].id == tb.id
280+
281+
282+
def test_current_task_outside_returns_none():
283+
assert current_task() is None
284+
285+
286+
@pytest.mark.usefixtures("_bind_settings")
287+
def test_user_set_terminal_state_not_overwritten(app):
288+
@track
289+
@app.task(name="self_complete")
290+
def self_complete():
291+
pass
292+
293+
tb_pending = task_for_test(status=StatusEnum.PENDING)
294+
tb_done = task_for_test(id=tb_pending.id, status=StatusEnum.SUCCESS)
295+
296+
with (
297+
mock.patch("taskbadger.procrastinate.create_task_safe", return_value=tb_pending),
298+
mock.patch("taskbadger.procrastinate.update_task_safe") as update,
299+
mock.patch("taskbadger.procrastinate.get_task", return_value=tb_done),
300+
):
301+
self_complete.defer()
302+
app.run_worker(wait=False, install_signal_handlers=False, listen_notify=False)
303+
304+
# The wrapper's post-call SUCCESS update is skipped because the cached
305+
# task is already SUCCESS. PROCESSING update is still allowed (early path).
306+
statuses = [c.kwargs["status"] for c in update.call_args_list]
307+
assert StatusEnum.PROCESSING in statuses
308+
# Last attempted SUCCESS call should be suppressed
309+
assert statuses.count(StatusEnum.SUCCESS) == 0

0 commit comments

Comments
 (0)