Skip to content
Merged
187 changes: 181 additions & 6 deletions workers/shared/infrastructure/logging/logger.py
Comment thread
muhammad-ali-e marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@
Provides structured logging, performance monitoring, and metrics collection for workers.
"""

import dataclasses
import functools
import inspect
import logging
import os
import sys
import time
import warnings
from collections.abc import Callable
from collections.abc import Callable, Mapping, Sequence
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from datetime import UTC, datetime
from threading import local
from typing import Any
from uuid import UUID

# Thread-local storage for context
_context = local()
Expand All @@ -35,14 +38,18 @@ class LogContext:
class RequestIDFilter(logging.Filter):
    """Filter that guarantees every log record carries a ``request_id``.

    Resolution order:
    1. ``record.request_id`` if set explicitly via ``extra={...}``.
    2. ``LogContext.request_id`` from the thread-local context.
    3. ``"-"`` placeholder, so formatters never see a missing attribute.
    """

    def filter(self, record):
        """Always returns True; only mutates ``record.request_id`` in place."""
        request_id = getattr(record, "request_id", None)
        if not request_id:
            # Fall back to the thread-local log context, if one is bound.
            log_ctx = getattr(_context, "log_context", None)
            if log_ctx:
                request_id = getattr(log_ctx, "request_id", None)
        record.request_id = request_id if request_id else "-"
        return True


Expand Down Expand Up @@ -145,6 +152,8 @@ def configure(
category=UserWarning,
)

_install_celery_request_id_signals()
Comment thread
muhammad-ali-e marked this conversation as resolved.

cls._configured = True

@classmethod
Expand Down Expand Up @@ -575,6 +584,172 @@ def log_execution(func: Callable) -> Callable:
return logged_execution()(func)


# Payload keys probed for a request id, in priority order: an explicit
# ``request_id`` always wins; the remaining keys are worker-payload
# fallbacks scanned by ``_extract_request_id``.
_REQUEST_ID_KEYS: tuple[str, ...] = (
    "request_id",
    "file_execution_id",
    "execution_id",
    "run_id",
)


def _coerce_id(value: Any) -> str | None:
"""Return a safe string id, or None if value is not a usable id type.

Rejects arbitrary objects so a misnamed payload field (e.g. a dataclass
or a list) cannot leak its ``__str__`` into log lines or OTel attributes.
Booleans are excluded explicitly because ``bool`` is a subclass of
``int`` in Python; without the guard, ``True``/``False`` would
serialize to the literal strings ``"True"``/``"False"``.
"""
if isinstance(value, bool):
return None
if isinstance(value, str):
return value or None
if isinstance(value, (int, UUID)):
return str(value)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return None
Comment thread
muhammad-ali-e marked this conversation as resolved.


def _bind_positional_args_to_names(
args: Sequence[Any], kwargs: Mapping[str, Any], task: Any
) -> Mapping[str, Any] | None:
"""Map positional args to parameter names via the task's signature.

Catches the case where ids are passed as positional strings (e.g.
``send_task("async_execute_bin", args=[schema, wf_id, exec_id, ...])``).
Returns ``None`` if the task is not introspectable or has no resolved
arguments.
"""
if task is None:
return None
runner = getattr(task, "run", task)
try:
bound = inspect.signature(runner).bind_partial(*(args or ()), **(kwargs or {}))
except (TypeError, ValueError):
return None
return dict(bound.arguments) if bound.arguments else None


def _arg_as_mapping(arg: Any) -> Mapping[str, Any] | None:
"""Return a mapping view of a positional arg, or None if not coercible.

Accepts ``Mapping`` instances directly and dataclass instances via
``dataclasses.asdict``.
"""
if isinstance(arg, Mapping):
return arg
if dataclasses.is_dataclass(arg) and not isinstance(arg, type):
return dataclasses.asdict(arg)
return None


def _gather_containers(
    args: Sequence[Any], kwargs: Mapping[str, Any], task: Any
) -> list[Mapping[str, Any]]:
    """Collect every mapping that may carry a request id.

    The returned order only fixes iteration; the winning value is chosen
    by KEY priority in ``_extract_request_id``:

    1. Positional args bound to parameter names via the task signature.
    2. Top-level kwargs.
    3. ``Mapping`` values nested one level inside kwargs.
    4. ``Mapping`` / dataclass positional args.
    """
    found: list[Mapping[str, Any]] = []

    named_args = _bind_positional_args_to_names(args, kwargs, task)
    if named_args:
        found.append(named_args)

    if kwargs:
        found.append(kwargs)
        for nested in kwargs.values():
            if isinstance(nested, Mapping):
                found.append(nested)

    for positional in args or ():
        as_mapping = _arg_as_mapping(positional)
        if as_mapping is not None:
            found.append(as_mapping)

    return found


def _extract_request_id(
    args: Sequence[Any], kwargs: Mapping[str, Any], task: Any = None
) -> str | None:
    """Pull a usable request_id from a Celery task payload.

    Workers bind ``file_execution_id`` to the ``request_id`` log field
    (legacy structure-tool convention), giving per-file granularity in
    logs across the multi-tool execution chain. Cross-service correlation
    across the API/HTTP boundary is handled by OpenTelemetry ``trace_id``,
    not by ``request_id``.

    The key priority below is also the migration path: callers may start
    passing a real HTTP ``request_id`` in the payload at any time and it
    will take precedence over ``file_execution_id`` automatically -- no
    worker change required.

    Priority is by KEY (not by container), so a payload with both
    ``request_id`` in one container and ``file_execution_id`` in another
    deterministically picks ``request_id`` regardless of insertion order.
    """
    # Gather once; the outer loop over keys (not containers) enforces
    # key-priority semantics.
    candidates = _gather_containers(args, kwargs, task)
    for key in _REQUEST_ID_KEYS:
        for payload in candidates:
            coerced = _coerce_id(payload.get(key))
            if coerced:
                return coerced
    return None


def _bind_task_context(task_id, task, args, kwargs, **_):
    """Celery ``task_prerun`` handler: bind request_id onto the log context.

    Extraction is wrapped in a broad except so a malformed payload can
    never leave the previous task's id bound on the thread; the Celery
    ``task_id`` is the fallback in every failure mode.
    """
    request_id = None
    try:
        request_id = _extract_request_id(args or (), kwargs or {}, task)
    except Exception:
        logging.getLogger(__name__).debug(
            "request_id extraction failed for task %s; falling back to task_id",
            task_id,
            exc_info=True,
        )
    WorkerLogger.update_context(request_id=request_id or task_id, task_id=task_id)
Comment thread
muhammad-ali-e marked this conversation as resolved.


def _clear_task_context(**_):
    """Celery ``task_postrun`` handler: reset task-scoped fields only.

    Preserves baseline context (``worker_name``, etc.) set at
    ``WorkerLogger.configure()``; only nulls out the per-task fields bound
    in ``_bind_task_context``.

    Accepts and ignores all Celery signal kwargs (sender, task_id, retval,
    state, ...) via ``**_``.
    """
    # Null out, rather than drop, so the fields read as "unset" uniformly.
    WorkerLogger.update_context(request_id=None, task_id=None)


@functools.lru_cache(maxsize=1)
def _install_celery_request_id_signals() -> None:
"""Wire Celery signals once per process.

Idempotent and thread-safe via ``functools.lru_cache``. No-ops with a
debug log if Celery is not importable (e.g. unit tests).
"""
try:
from celery.signals import task_postrun, task_prerun
except ImportError as exc:
logging.getLogger(__name__).debug(
"celery.signals not importable; request_id signal install skipped: %s",
exc,
)
return
task_prerun.connect(_bind_task_context, weak=False)
task_postrun.connect(_clear_task_context, weak=False)
Comment thread
coderabbitai[bot] marked this conversation as resolved.


# Dataclass/Dictionary Access Utilities


Expand Down