|
11 | 11 |
|
12 | 12 | from __future__ import annotations |
13 | 13 |
|
| 14 | +import functools |
| 15 | +import inspect |
14 | 16 | import logging |
15 | 17 | from contextvars import ContextVar |
16 | 18 |
|
17 | 19 | from .internal.models import StatusEnum |
| 20 | +from .mug import Badger |
| 21 | +from .safe_sdk import update_task_safe |
| 22 | +from .sdk import DefaultMergeStrategy, get_task |
18 | 23 |
|
19 | 24 | log = logging.getLogger("taskbadger") |
20 | 25 |
|
|
37 | 42 | _OPTS_ATTR = "_taskbadger_opts" |
38 | 43 |
|
39 | 44 | _current_tb_task_id: ContextVar[str | None] = ContextVar("_current_tb_task_id", default=None) |
| 45 | + |
| 46 | + |
| 47 | +def _instrument_task(task, system=None, manual=False, opts=None): |
| 48 | + """Wrap a Procrastinate Task's ``func`` so the worker side updates TaskBadger. |
| 49 | +
|
| 50 | + Idempotent: a second call on the same task is a no-op (but ``manual`` and |
| 51 | + ``opts`` will be merged onto the existing attributes if provided). |
| 52 | + """ |
| 53 | + if opts is not None: |
| 54 | + existing_opts = getattr(task, _OPTS_ATTR, {}) or {} |
| 55 | + merged = {**existing_opts, **opts} |
| 56 | + setattr(task, _OPTS_ATTR, merged) |
| 57 | + elif not hasattr(task, _OPTS_ATTR): |
| 58 | + setattr(task, _OPTS_ATTR, {}) |
| 59 | + |
| 60 | + if manual: |
| 61 | + setattr(task, _MANUAL_ATTR, True) |
| 62 | + |
| 63 | + if getattr(task, _INSTRUMENTED_ATTR, False): |
| 64 | + return |
| 65 | + |
| 66 | + original_func = task.func |
| 67 | + is_async = inspect.iscoroutinefunction(original_func) |
| 68 | + |
| 69 | + if is_async: |
| 70 | + |
| 71 | + @functools.wraps(original_func) |
| 72 | + async def wrapped(*args, **kwargs): |
| 73 | + tb_id = kwargs.pop(TB_TASK_ID_KWARG, None) |
| 74 | + if tb_id is None: |
| 75 | + return await original_func(*args, **kwargs) |
| 76 | + token = _current_tb_task_id.set(tb_id) |
| 77 | + try: |
| 78 | + _update_status(tb_id, StatusEnum.PROCESSING) |
| 79 | + try: |
| 80 | + result = await original_func(*args, **kwargs) |
| 81 | + except Exception as exc: |
| 82 | + _update_status(tb_id, StatusEnum.ERROR, exception=exc) |
| 83 | + raise |
| 84 | + _update_status(tb_id, StatusEnum.SUCCESS) |
| 85 | + return result |
| 86 | + finally: |
| 87 | + _current_tb_task_id.reset(token) |
| 88 | + else: |
| 89 | + |
| 90 | + @functools.wraps(original_func) |
| 91 | + def wrapped(*args, **kwargs): |
| 92 | + tb_id = kwargs.pop(TB_TASK_ID_KWARG, None) |
| 93 | + if tb_id is None: |
| 94 | + return original_func(*args, **kwargs) |
| 95 | + token = _current_tb_task_id.set(tb_id) |
| 96 | + try: |
| 97 | + _update_status(tb_id, StatusEnum.PROCESSING) |
| 98 | + try: |
| 99 | + result = original_func(*args, **kwargs) |
| 100 | + except Exception as exc: |
| 101 | + _update_status(tb_id, StatusEnum.ERROR, exception=exc) |
| 102 | + raise |
| 103 | + _update_status(tb_id, StatusEnum.SUCCESS) |
| 104 | + return result |
| 105 | + finally: |
| 106 | + _current_tb_task_id.reset(token) |
| 107 | + |
| 108 | + task.func = wrapped |
| 109 | + setattr(task, _INSTRUMENTED_ATTR, True) |
| 110 | + setattr(task, "_taskbadger_system", system) |
| 111 | + |
| 112 | + |
| 113 | +def _update_status(tb_id, status, exception=None): |
| 114 | + """Update the TaskBadger task to ``status``. Skips if already terminal.""" |
| 115 | + if not Badger.is_configured(): |
| 116 | + return |
| 117 | + |
| 118 | + if exception is not None or status in TERMINAL_STATES: |
| 119 | + try: |
| 120 | + current = get_task(tb_id) |
| 121 | + except Exception as e: |
| 122 | + log.warning("Error fetching task '%s': %s", tb_id, e) |
| 123 | + current = None |
| 124 | + if current is not None and current.status in TERMINAL_STATES: |
| 125 | + return |
| 126 | + data = None |
| 127 | + if exception is not None and current is not None: |
| 128 | + data = DefaultMergeStrategy().merge(current.data, {"exception": str(exception)}) |
| 129 | + if data: |
| 130 | + update_task_safe(tb_id, status=status, data=data) |
| 131 | + else: |
| 132 | + update_task_safe(tb_id, status=status) |
| 133 | + else: |
| 134 | + update_task_safe(tb_id, status=status) |
0 commit comments