Skip to content

Commit ad65f65

Browse files
committed
feat(procrastinate): defer-time task creation with id injection
1 parent 9321ed4 commit ad65f65

2 files changed

Lines changed: 171 additions & 8 deletions

File tree

taskbadger/procrastinate.py

Lines changed: 94 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111

1212
from __future__ import annotations
1313

14+
import collections
1415
import functools
1516
import inspect
1617
import logging
1718
from contextvars import ContextVar
1819

1920
from .internal.models import StatusEnum
2021
from .mug import Badger
21-
from .safe_sdk import update_task_safe
22+
from .safe_sdk import create_task_safe, update_task_safe
2223
from .sdk import DefaultMergeStrategy, get_task
2324

2425
log = logging.getLogger("taskbadger")
@@ -105,6 +106,7 @@ def wrapped(*args, **kwargs):
105106
finally:
106107
_current_tb_task_id.reset(token)
107108

109+
_wrap_defer(task)
108110
task.func = wrapped
109111
setattr(task, _INSTRUMENTED_ATTR, True)
110112
setattr(task, "_taskbadger_system", system)
@@ -116,19 +118,103 @@ def _update_status(tb_id, status, exception=None):
116118
return
117119

118120
if exception is not None or status in TERMINAL_STATES:
119-
try:
120-
current = get_task(tb_id)
121-
except Exception as e:
122-
log.warning("Error fetching task '%s': %s", tb_id, e)
123-
current = None
121+
current = _safe_get_task(tb_id)
124122
if current is not None and current.status in TERMINAL_STATES:
125123
return
126124
data = None
127125
if exception is not None and current is not None:
128-
data = DefaultMergeStrategy().merge(current.data, {"exception": str(exception)})
129-
if data:
126+
base = dict(current.data) if current.data else None
127+
data = DefaultMergeStrategy().merge(base, {"exception": str(exception)})
128+
if data is not None:
130129
update_task_safe(tb_id, status=status, data=data)
131130
else:
132131
update_task_safe(tb_id, status=status)
133132
else:
134133
update_task_safe(tb_id, status=status)
134+
135+
136+
class _Cache:
137+
def __init__(self, maxsize=128):
138+
self.cache = collections.OrderedDict()
139+
self.maxsize = maxsize
140+
141+
def set(self, key, value):
142+
self.cache[key] = value
143+
if len(self.cache) > self.maxsize:
144+
self.cache.popitem(last=False)
145+
146+
def get(self, key):
147+
return self.cache.get(key)
148+
149+
def unset(self, key):
150+
self.cache.pop(key, None)
151+
152+
153+
_task_cache = _Cache()
154+
155+
156+
def _safe_get_task(task_id):
157+
cached = _task_cache.get(task_id)
158+
if cached is not None:
159+
return cached
160+
try:
161+
task = get_task(task_id)
162+
except Exception as e:
163+
log.warning("Error fetching task '%s': %s", task_id, e)
164+
return None
165+
_task_cache.set(task_id, task)
166+
return task
167+
168+
169+
def _wrap_defer(task):
170+
"""Wrap ``task.defer`` and ``task.defer_async`` so they create a TaskBadger
171+
task in PENDING state and inject its id into the job's task_kwargs.
172+
173+
The original defer methods are stashed on the task to keep the wrap
174+
idempotent (a second call replaces nothing because the marker is set)."""
175+
original_defer = task.defer
176+
original_defer_async = task.defer_async
177+
178+
@functools.wraps(original_defer)
179+
def defer(**kwargs):
180+
kwargs = _maybe_create_pending(task, kwargs)
181+
return original_defer(**kwargs)
182+
183+
@functools.wraps(original_defer_async)
184+
async def defer_async(**kwargs):
185+
kwargs = _maybe_create_pending(task, kwargs)
186+
return await original_defer_async(**kwargs)
187+
188+
task.defer = defer
189+
task.defer_async = defer_async
190+
191+
192+
def _maybe_create_pending(task, kwargs):
193+
"""Decide whether to track this defer, and if so create the TaskBadger
194+
task and inject its id into ``kwargs``. Always returns the kwargs dict."""
195+
if not Badger.is_configured():
196+
return kwargs
197+
198+
system = getattr(task, "_taskbadger_system", None)
199+
manual = getattr(task, _MANUAL_ATTR, False)
200+
auto = bool(system) and system.track_task(task.name)
201+
if not manual and not auto:
202+
return kwargs
203+
204+
opts = dict(getattr(task, _OPTS_ATTR, {}) or {})
205+
name = opts.pop("name", None) or task.name
206+
create_kwargs = {"status": StatusEnum.PENDING}
207+
for key in ("value_max", "tags"):
208+
if key in opts and opts[key] is not None:
209+
create_kwargs[key] = opts[key]
210+
user_data = opts.get("data")
211+
if user_data:
212+
create_kwargs["data"] = dict(user_data)
213+
214+
tb_task = create_task_safe(name, **create_kwargs)
215+
if tb_task is None:
216+
return kwargs
217+
218+
new_kwargs = dict(kwargs)
219+
new_kwargs[TB_TASK_ID_KWARG] = tb_task.id
220+
return new_kwargs

tests/test_procrastinate.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,80 @@ def add2(a, b):
103103

104104
assert result == 3
105105
update.assert_not_called()
106+
107+
108+
@pytest.mark.usefixtures("_bind_settings")
109+
def test_defer_creates_pending_task_and_injects_id(app):
110+
@app.task(name="add3")
111+
def add3(a, b):
112+
return a + b
113+
114+
_instrument_task(add3, system=None, manual=True)
115+
116+
tb = task_for_test()
117+
with mock.patch("taskbadger.procrastinate.create_task_safe", return_value=tb) as create:
118+
add3.defer(a=1, b=2)
119+
120+
create.assert_called_once()
121+
assert create.call_args.args == ("add3",)
122+
assert create.call_args.kwargs == {"status": StatusEnum.PENDING}
123+
124+
# The injected id should appear in the Procrastinate job's task kwargs.
125+
jobs = list(app.connector.jobs.values())
126+
assert len(jobs) == 1
127+
assert jobs[0]["args"][TB_TASK_ID_KWARG] == tb.id
128+
129+
130+
def test_defer_no_taskbadger_when_unconfigured(app):
131+
@app.task(name="add4")
132+
def add4(a, b):
133+
return a + b
134+
135+
_instrument_task(add4, system=None, manual=True)
136+
137+
# Badger is not configured (no _bind_settings fixture).
138+
with mock.patch("taskbadger.procrastinate.create_task_safe") as create:
139+
add4.defer(a=1, b=2)
140+
141+
create.assert_not_called()
142+
jobs = list(app.connector.jobs.values())
143+
assert TB_TASK_ID_KWARG not in jobs[0]["args"]
144+
145+
146+
@pytest.mark.usefixtures("_bind_settings")
147+
def test_defer_async_injects_id(app):
148+
@app.task(name="add5")
149+
async def add5(a, b):
150+
return a + b
151+
152+
_instrument_task(add5, system=None, manual=True)
153+
154+
tb = task_for_test()
155+
with mock.patch("taskbadger.procrastinate.create_task_safe", return_value=tb):
156+
asyncio.run(add5.defer_async(a=1, b=2))
157+
158+
jobs = list(app.connector.jobs.values())
159+
assert jobs[0]["args"][TB_TASK_ID_KWARG] == tb.id
160+
161+
162+
@pytest.mark.usefixtures("_bind_settings")
163+
def test_end_to_end_via_worker(app):
164+
@app.task(name="add6")
165+
def add6(a, b):
166+
return a + b
167+
168+
_instrument_task(add6, system=None, manual=True)
169+
170+
tb = task_for_test()
171+
with (
172+
mock.patch("taskbadger.procrastinate.create_task_safe", return_value=tb) as create,
173+
mock.patch("taskbadger.procrastinate.update_task_safe") as update,
174+
mock.patch("taskbadger.procrastinate.get_task") as get,
175+
):
176+
get.return_value = task_for_test(id=tb.id, status=StatusEnum.PROCESSING)
177+
add6.defer(a=2, b=3)
178+
app.run_worker(wait=False, install_signal_handlers=False, listen_notify=False)
179+
180+
create.assert_called_once()
181+
statuses = [c.kwargs["status"] for c in update.call_args_list]
182+
assert statuses == [StatusEnum.PROCESSING, StatusEnum.SUCCESS]

0 commit comments

Comments
 (0)