Skip to content

Commit 024cc9c

Browse files
committed
refactor: extract shared helpers between celery and procrastinate
1 parent aed4969 commit 024cc9c

5 files changed

Lines changed: 143 additions & 144 deletions

File tree

taskbadger/_integrations.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""Shared internals for taskbadger's optional system integrations
2+
(Celery, Procrastinate). Not part of the public API.
3+
4+
Each integration creates its own module-level ``TaskCache`` instance and
5+
defines a thin ``safe_get_task`` wrapper that reads ``get_task`` from the
6+
integration module's own globals (so existing test mocks on
7+
``taskbadger.celery.get_task`` / ``taskbadger.procrastinate.get_task`` keep
8+
working). ``BaseSystemIntegration`` provides the common ctor/include-exclude
9+
shape; subclasses override ``track_task`` if they need to filter additional
10+
task names (e.g. Procrastinate built-ins).
11+
"""
12+
13+
from __future__ import annotations
14+
15+
import collections
16+
import logging
17+
import re
18+
from collections.abc import Callable
19+
20+
from .internal.models import StatusEnum
21+
from .systems import System
22+
23+
# Logger registered under the fixed name "taskbadger" (not this module's
# ``__name__``).
log = logging.getLogger("taskbadger")
24+
25+
# Set of ``StatusEnum`` members grouped under the name TERMINAL_STATES.
# NOTE(review): presumably a task in one of these states receives no further
# status updates -- confirm against the callers in the integration modules.
TERMINAL_STATES = {
26+
StatusEnum.SUCCESS,
27+
StatusEnum.ERROR,
28+
StatusEnum.CANCELLED,
29+
StatusEnum.STALE,
30+
}
31+
32+
33+
class TaskCache:
    """Bounded cache for TaskBadger Task objects.

    Keys are arbitrary hashable values chosen by the caller (typically the
    task id). Entries are ordered by when they were last written with
    ``set``; reads via ``get`` do not change the order. When ``set`` pushes
    the cache past ``maxsize``, the least recently written entries are
    evicted.
    """

    def __init__(self, maxsize: int = 128):
        """Create an empty cache that keeps at most ``maxsize`` entries."""
        self.cache: collections.OrderedDict = collections.OrderedDict()
        self.maxsize = maxsize

    def set(self, key, value) -> None:
        """Store ``value`` under ``key``, then evict the oldest entries if over capacity."""
        self.cache[key] = value
        # Assigning to an existing key keeps its original position in an
        # OrderedDict, so refresh it explicitly; otherwise a task that was
        # just updated could be the very next entry to be evicted.
        self.cache.move_to_end(key)
        # ``while`` rather than ``if`` so the bound is also restored when
        # ``maxsize`` was lowered after entries were added. The emptiness
        # check keeps a zero/negative ``maxsize`` from popping an empty dict.
        while self.cache and len(self.cache) > self.maxsize:
            self.cache.popitem(last=False)

    def get(self, key):
        """Return the value stored under ``key``, or ``None`` if it is absent."""
        return self.cache.get(key)

    def unset(self, key) -> None:
        """Remove ``key`` from the cache; a missing key is ignored."""
        self.cache.pop(key, None)
54+
55+
56+
def safe_get_task(cache: "TaskCache", task_id: str, get_task_fn: Callable):
    """Return the task for ``task_id``, consulting ``cache`` before fetching.

    A cached entry is returned as-is. Otherwise the task is fetched with
    ``get_task_fn(task_id)`` and, if the fetch produced a task, stored in
    ``cache``. Any exception raised by the fetch is logged as a warning and
    swallowed, in which case ``None`` is returned. ``None`` results are never
    cached, so a later call retries the fetch.

    ``get_task_fn`` is passed in (rather than imported here) so callers can
    use their own module-level ``get_task`` reference — this keeps existing
    test patches on ``taskbadger.celery.get_task`` /
    ``taskbadger.procrastinate.get_task`` intercepting the fetch.

    Args:
        cache: Object exposing ``get(key)`` and ``set(key, value)``.
        task_id: Id of the task to look up; also used as the cache key.
        get_task_fn: Callable taking the task id and returning the task.

    Returns:
        The cached or freshly fetched task, or ``None`` if the fetch failed
        or returned nothing.
    """
    cached = cache.get(task_id)
    if cached is not None:
        return cached
    try:
        task = get_task_fn(task_id)
    except Exception as e:
        log.warning("Error fetching task '%s': %s", task_id, e)
        return None
    # Storing ``None`` would never produce a cache hit (the lookup above treats
    # ``None`` as a miss) but would still occupy a slot in the bounded cache
    # and could evict a real task.
    if task is not None:
        cache.set(task_id, task)
    return task
76+
77+
78+
def match_task_name(task_name: str, includes, excludes) -> bool:
    """Return True if ``task_name`` should be tracked under the given rules.

    Excludes win over includes. Each rule set is an iterable of regex
    patterns matched against the whole name with ``re.fullmatch``; a single
    pattern string is also accepted and treated as a one-element list.
    ``None`` (or an empty collection) means "no rule".

    Args:
        task_name: Full name of the task being considered.
        includes: Patterns a name must match to be tracked, or ``None``.
        excludes: Patterns that prevent a name from being tracked, or ``None``.

    Returns:
        ``False`` if any exclude matches, or if includes are given and none
        of them matches; ``True`` otherwise.
    """
    # A bare string would otherwise be iterated character by character, which
    # silently turns the rule into one that matches almost nothing.
    if isinstance(includes, str):
        includes = [includes]
    if isinstance(excludes, str):
        excludes = [excludes]

    if excludes and any(re.fullmatch(pattern, task_name) for pattern in excludes):
        return False

    if includes:
        return any(re.fullmatch(pattern, task_name) for pattern in includes)

    return True
96+
97+
98+
class BaseSystemIntegration(System):
    """Shared constructor and ``track_task`` logic for system integrations.

    Concrete integrations set ``identifier`` and may override ``track_task``
    when they need extra filtering on top of the include/exclude rules
    (e.g. skipping built-in tasks).
    """

    def __init__(self, auto_track_tasks=True, includes=None, excludes=None, record_task_args=False):
        # Plain attribute storage; the values are consulted by ``track_task``
        # and by code outside this class.
        self.record_task_args = record_task_args
        self.excludes = excludes
        self.includes = includes
        self.auto_track_tasks = auto_track_tasks

    def track_task(self, task_name: str) -> bool:
        """Return whether ``task_name`` should be tracked.

        Always ``False`` when ``auto_track_tasks`` is falsy; otherwise the
        decision is delegated to ``match_task_name`` with this instance's
        include/exclude rules.
        """
        if self.auto_track_tasks:
            return match_task_name(task_name, self.includes, self.excludes)
        return False

taskbadger/celery.py

Lines changed: 7 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import collections
21
import functools
32
import json
43
import logging
@@ -13,6 +12,8 @@
1312
)
1413
from kombu import serialization
1514

15+
from ._integrations import TERMINAL_STATES, TaskCache
16+
from ._integrations import safe_get_task as _shared_safe_get_task
1617
from .internal.models import StatusEnum
1718
from .mug import Badger
1819
from .safe_sdk import create_task_safe, update_task_safe
@@ -23,54 +24,9 @@
2324
IGNORE_ARGS = {TB_KWARGS_ARG, f"{KWARG_PREFIX}task", f"{KWARG_PREFIX}task_id", f"{KWARG_PREFIX}record_task_args"}
2425
TB_TASK_ID = f"{KWARG_PREFIX}task_id"
2526

26-
TERMINAL_STATES = {
27-
StatusEnum.SUCCESS,
28-
StatusEnum.ERROR,
29-
StatusEnum.CANCELLED,
30-
StatusEnum.STALE,
31-
}
32-
3327
log = logging.getLogger("taskbadger")
3428

35-
36-
class Cache:
37-
def __init__(self, maxsize=128):
38-
self.cache = collections.OrderedDict()
39-
self.maxsize = maxsize
40-
41-
def set(self, key, value):
42-
self.cache[key] = value
43-
44-
def unset(self, key):
45-
self.cache.pop(key, None)
46-
47-
def get(self, key):
48-
return self.cache.get(key)
49-
50-
def prune(self):
51-
if len(self.cache) > self.maxsize:
52-
self.cache.popitem(last=False)
53-
54-
55-
def cached(cache_none=True, maxsize=128):
56-
cache = Cache(maxsize=maxsize)
57-
58-
def _wrapper(func):
59-
@functools.wraps(func)
60-
def _inner(*args, **kwargs):
61-
key = args + tuple(sorted(kwargs.items()))
62-
if key in cache.cache:
63-
return cache.get(key)
64-
65-
result = func(*args, **kwargs)
66-
if result is not None or cache_none:
67-
cache.set(key, result)
68-
return result
69-
70-
_inner.cache = cache
71-
return _inner
72-
73-
return _wrapper
29+
_task_cache = TaskCache()
7430

7531

7632
class Task(celery.Task):
@@ -292,7 +248,7 @@ def _maybe_create_task(signal_sender):
292248
if task:
293249
# Store the task ID in the request so _update_task can find it
294250
signal_sender.request.update({TB_TASK_ID: task.id})
295-
safe_get_task.cache.set((task.id,), task)
251+
_task_cache.set(task.id, task)
296252

297253

298254
@task_prerun.connect
@@ -344,7 +300,7 @@ def _update_task(signal_sender, status, einfo=None):
344300
data = DefaultMergeStrategy().merge(task.data, {"exception": str(einfo)})
345301
task = update_task_safe(task.id, status=status, data=data)
346302
if task:
347-
safe_get_task.cache.set((task_id,), task)
303+
_task_cache.set(task_id, task)
348304

349305

350306
def enter_session():
@@ -364,20 +320,15 @@ def exit_session(signal_sender):
364320
if not task_id or not Badger.is_configured():
365321
return
366322

367-
safe_get_task.cache.unset((task_id,))
368-
safe_get_task.cache.prune()
323+
_task_cache.unset(task_id)
369324

370325
session = Badger.current.session()
371326
if session.client:
372327
session.__exit__()
373328

374329

375-
@cached(cache_none=False)
376330
def safe_get_task(task_id: str):
377-
try:
378-
return get_task(task_id)
379-
except Exception as e:
380-
log.warning("Error fetching task '%s': %s", task_id, e)
331+
return _shared_safe_get_task(_task_cache, task_id, get_task)
381332

382333

383334
def _get_taskbadger_task_id(request):

taskbadger/procrastinate.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111

1212
from __future__ import annotations
1313

14-
import collections
1514
import functools
1615
import inspect
1716
import json
1817
import logging
1918
from contextvars import ContextVar
2019

20+
from ._integrations import TERMINAL_STATES, TaskCache
21+
from ._integrations import safe_get_task as _shared_safe_get_task
2122
from .internal.models import StatusEnum
2223
from .mug import Badger
2324
from .safe_sdk import create_task_safe, update_task_safe
@@ -30,13 +31,6 @@
3031
# user function is called.
3132
TB_TASK_ID_KWARG = "__taskbadger_task_id__"
3233

33-
TERMINAL_STATES = {
34-
StatusEnum.SUCCESS,
35-
StatusEnum.ERROR,
36-
StatusEnum.CANCELLED,
37-
StatusEnum.STALE,
38-
}
39-
4034
# Sentinel attribute names set on a Procrastinate Task object once it has been
4135
# instrumented. Used to make instrumentation idempotent.
4236
_INSTRUMENTED_ATTR = "_taskbadger_instrumented"
@@ -143,37 +137,11 @@ def _update_status(tb_id, status, exception=None):
143137
_task_cache.set(tb_id, updated)
144138

145139

146-
class _Cache:
147-
def __init__(self, maxsize=128):
148-
self.cache = collections.OrderedDict()
149-
self.maxsize = maxsize
150-
151-
def set(self, key, value):
152-
self.cache[key] = value
153-
if len(self.cache) > self.maxsize:
154-
self.cache.popitem(last=False)
155-
156-
def get(self, key):
157-
return self.cache.get(key)
158-
159-
def unset(self, key):
160-
self.cache.pop(key, None)
161-
162-
163-
_task_cache = _Cache()
140+
_task_cache = TaskCache()
164141

165142

166143
def _safe_get_task(task_id):
167-
cached = _task_cache.get(task_id)
168-
if cached is not None:
169-
return cached
170-
try:
171-
task = get_task(task_id)
172-
except Exception as e:
173-
log.warning("Error fetching task '%s': %s", task_id, e)
174-
return None
175-
_task_cache.set(task_id, task)
176-
return task
144+
return _shared_safe_get_task(_task_cache, task_id, get_task)
177145

178146

179147
def _wrap_defer(task):

taskbadger/systems/celery.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import re
1+
from taskbadger._integrations import BaseSystemIntegration
22

3-
from taskbadger.systems import System
43

5-
6-
class CelerySystemIntegration(System):
4+
class CelerySystemIntegration(BaseSystemIntegration):
75
identifier = "celery"
86

97
def __init__(self, auto_track_tasks=True, includes=None, excludes=None, record_task_args=False):
@@ -18,29 +16,13 @@ def __init__(self, auto_track_tasks=True, includes=None, excludes=None, record_t
1816
the full task name or a regular expression. Exclusions take precedence over inclusions.
1917
record_task_args: Record the arguments passed to each task.
2018
"""
21-
self.auto_track_tasks = auto_track_tasks
22-
self.includes = includes
23-
self.excludes = excludes
24-
self.record_task_args = record_task_args
19+
super().__init__(
20+
auto_track_tasks=auto_track_tasks,
21+
includes=includes,
22+
excludes=excludes,
23+
record_task_args=record_task_args,
24+
)
2525

2626
if auto_track_tasks:
2727
# Importing this here ensures that the Celery signal handlers are registered
28-
import taskbadger.celery # noqa
29-
30-
def track_task(self, task_name):
31-
if not self.auto_track_tasks:
32-
return False
33-
34-
if self.excludes:
35-
for exclude in self.excludes:
36-
if re.fullmatch(exclude, task_name):
37-
return False
38-
39-
if self.includes:
40-
for include in self.includes:
41-
if re.fullmatch(include, task_name):
42-
break
43-
else:
44-
return False
45-
46-
return True
28+
import taskbadger.celery # noqa: F401

0 commit comments

Comments
 (0)