77from procrastinate import testing
88
99from taskbadger import StatusEnum
10- from taskbadger .procrastinate import TB_TASK_ID_KWARG , _instrument_task , track
10+ from taskbadger .procrastinate import TB_TASK_ID_KWARG , _instrument_task , _task_cache , current_task , track
1111from tests .utils import task_for_test
1212
1313
@@ -19,6 +19,13 @@ def _check_log_errors(caplog):
1919 pytest .fail (f"log errors during tests: { errors } " )
2020
2121
22+ @pytest .fixture (autouse = True )
23+ def _clear_task_cache ():
24+ _task_cache .cache .clear ()
25+ yield
26+ _task_cache .cache .clear ()
27+
28+
2229@pytest .fixture
2330def app ():
2431 in_memory = testing .InMemoryConnector ()
@@ -248,3 +255,55 @@ def bad():
248255
249256 with pytest .raises (TypeError , match = "unexpected keyword" ):
250257 track (name = "x" , does_not_exist = True )(bad )
258+
259+
260+ @pytest .mark .usefixtures ("_bind_settings" )
261+ def test_current_task_inside_body (app ):
262+ captured = {}
263+
264+ @track
265+ @app .task (name = "capture" )
266+ def capture ():
267+ captured ["task" ] = current_task ()
268+
269+ tb = task_for_test ()
270+ with (
271+ mock .patch ("taskbadger.procrastinate.create_task_safe" , return_value = tb ),
272+ mock .patch ("taskbadger.procrastinate.update_task_safe" , return_value = tb ),
273+ mock .patch ("taskbadger.procrastinate.get_task" , return_value = tb ),
274+ ):
275+ capture .defer ()
276+ app .run_worker (wait = False , install_signal_handlers = False , listen_notify = False )
277+
278+ assert captured ["task" ] is not None
279+ assert captured ["task" ].id == tb .id
280+
281+
282+ def test_current_task_outside_returns_none ():
283+ assert current_task () is None
284+
285+
286+ @pytest .mark .usefixtures ("_bind_settings" )
287+ def test_user_set_terminal_state_not_overwritten (app ):
288+ @track
289+ @app .task (name = "self_complete" )
290+ def self_complete ():
291+ pass
292+
293+ tb_pending = task_for_test (status = StatusEnum .PENDING )
294+ tb_done = task_for_test (id = tb_pending .id , status = StatusEnum .SUCCESS )
295+
296+ with (
297+ mock .patch ("taskbadger.procrastinate.create_task_safe" , return_value = tb_pending ),
298+ mock .patch ("taskbadger.procrastinate.update_task_safe" ) as update ,
299+ mock .patch ("taskbadger.procrastinate.get_task" , return_value = tb_done ),
300+ ):
301+ self_complete .defer ()
302+ app .run_worker (wait = False , install_signal_handlers = False , listen_notify = False )
303+
304+ # The wrapper's post-call SUCCESS update is skipped because the cached
305+ # task is already SUCCESS. PROCESSING update is still allowed (early path).
306+ statuses = [c .kwargs ["status" ] for c in update .call_args_list ]
307+ assert StatusEnum .PROCESSING in statuses
308+ # Last attempted SUCCESS call should be suppressed
309+ assert statuses .count (StatusEnum .SUCCESS ) == 0
0 commit comments