1111
1212from __future__ import annotations
1313
14+ import collections
1415import functools
1516import inspect
1617import logging
1718from contextvars import ContextVar
1819
1920from .internal .models import StatusEnum
2021from .mug import Badger
21- from .safe_sdk import update_task_safe
22+ from .safe_sdk import create_task_safe , update_task_safe
2223from .sdk import DefaultMergeStrategy , get_task
2324
2425log = logging .getLogger ("taskbadger" )
@@ -105,6 +106,7 @@ def wrapped(*args, **kwargs):
105106 finally :
106107 _current_tb_task_id .reset (token )
107108
109+ _wrap_defer (task )
108110 task .func = wrapped
109111 setattr (task , _INSTRUMENTED_ATTR , True )
110112 setattr (task , "_taskbadger_system" , system )
@@ -116,19 +118,103 @@ def _update_status(tb_id, status, exception=None):
116118 return
117119
118120 if exception is not None or status in TERMINAL_STATES :
119- try :
120- current = get_task (tb_id )
121- except Exception as e :
122- log .warning ("Error fetching task '%s': %s" , tb_id , e )
123- current = None
121+ current = _safe_get_task (tb_id )
124122 if current is not None and current .status in TERMINAL_STATES :
125123 return
126124 data = None
127125 if exception is not None and current is not None :
128- data = DefaultMergeStrategy ().merge (current .data , {"exception" : str (exception )})
129- if data :
126+ base = dict (current .data ) if current .data else None
127+ data = DefaultMergeStrategy ().merge (base , {"exception" : str (exception )})
128+ if data is not None :
130129 update_task_safe (tb_id , status = status , data = data )
131130 else :
132131 update_task_safe (tb_id , status = status )
133132 else :
134133 update_task_safe (tb_id , status = status )
134+
135+
136+ class _Cache :
137+ def __init__ (self , maxsize = 128 ):
138+ self .cache = collections .OrderedDict ()
139+ self .maxsize = maxsize
140+
141+ def set (self , key , value ):
142+ self .cache [key ] = value
143+ if len (self .cache ) > self .maxsize :
144+ self .cache .popitem (last = False )
145+
146+ def get (self , key ):
147+ return self .cache .get (key )
148+
149+ def unset (self , key ):
150+ self .cache .pop (key , None )
151+
152+
153+ _task_cache = _Cache ()
154+
155+
156+ def _safe_get_task (task_id ):
157+ cached = _task_cache .get (task_id )
158+ if cached is not None :
159+ return cached
160+ try :
161+ task = get_task (task_id )
162+ except Exception as e :
163+ log .warning ("Error fetching task '%s': %s" , task_id , e )
164+ return None
165+ _task_cache .set (task_id , task )
166+ return task
167+
168+
169+ def _wrap_defer (task ):
170+ """Wrap ``task.defer`` and ``task.defer_async`` so they create a TaskBadger
171+ task in PENDING state and inject its id into the job's task_kwargs.
172+
173+ The original defer methods are stashed on the task to keep the wrap
174+ idempotent (a second call replaces nothing because the marker is set)."""
175+ original_defer = task .defer
176+ original_defer_async = task .defer_async
177+
178+ @functools .wraps (original_defer )
179+ def defer (** kwargs ):
180+ kwargs = _maybe_create_pending (task , kwargs )
181+ return original_defer (** kwargs )
182+
183+ @functools .wraps (original_defer_async )
184+ async def defer_async (** kwargs ):
185+ kwargs = _maybe_create_pending (task , kwargs )
186+ return await original_defer_async (** kwargs )
187+
188+ task .defer = defer
189+ task .defer_async = defer_async
190+
191+
192+ def _maybe_create_pending (task , kwargs ):
193+ """Decide whether to track this defer, and if so create the TaskBadger
194+ task and inject its id into ``kwargs``. Always returns the kwargs dict."""
195+ if not Badger .is_configured ():
196+ return kwargs
197+
198+ system = getattr (task , "_taskbadger_system" , None )
199+ manual = getattr (task , _MANUAL_ATTR , False )
200+ auto = bool (system ) and system .track_task (task .name )
201+ if not manual and not auto :
202+ return kwargs
203+
204+ opts = dict (getattr (task , _OPTS_ATTR , {}) or {})
205+ name = opts .pop ("name" , None ) or task .name
206+ create_kwargs = {"status" : StatusEnum .PENDING }
207+ for key in ("value_max" , "tags" ):
208+ if key in opts and opts [key ] is not None :
209+ create_kwargs [key ] = opts [key ]
210+ user_data = opts .get ("data" )
211+ if user_data :
212+ create_kwargs ["data" ] = dict (user_data )
213+
214+ tb_task = create_task_safe (name , ** create_kwargs )
215+ if tb_task is None :
216+ return kwargs
217+
218+ new_kwargs = dict (kwargs )
219+ new_kwargs [TB_TASK_ID_KWARG ] = tb_task .id
220+ return new_kwargs
0 commit comments