Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/routers/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
trigger_external_integrations,
)
from utils.conversations.location import async_get_google_maps_location
from utils.byok import set_byok_keys
from utils.byok import set_byok_keys, set_byok_uid
from utils.conversations.process_conversation import process_conversation
from utils.executors import storage_executor
from utils.webhooks import (
Expand Down Expand Up @@ -79,6 +79,7 @@ async def _process_conversation_task(
"""
if byok_keys:
set_byok_keys(byok_keys)
set_byok_uid(uid)
try:
conversation_data = conversations_db.get_conversation(uid, conversation_id)
if not conversation_data:
Expand Down
4 changes: 3 additions & 1 deletion backend/routers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
)

from utils import encryption
from utils.byok import get_byok_keys, set_byok_keys
from utils.byok import get_byok_keys, set_byok_keys, set_byok_uid
from utils.log_sanitizer import sanitize
from utils.stt.pre_recorded import deepgram_prerecorded, get_deepgram_model_for_language, postprocess_words
from utils.stt.vad import vad_is_empty
Expand Down Expand Up @@ -1359,6 +1359,7 @@ def _run_full_pipeline_background(
Moved ALL heavy processing here so the v2 endpoint returns 202 immediately.
"""
set_byok_keys(byok_keys or {})
set_byok_uid(uid if byok_keys else None)
segmented_paths = set()
wav_paths = []
stage_timings = {}
Expand Down Expand Up @@ -1580,6 +1581,7 @@ def _process_one_segment(path):
pass
finally:
set_byok_keys({})
set_byok_uid(None)
_cleanup_files(list(segmented_paths))
_cleanup_files(wav_paths)
try:
Expand Down
92 changes: 92 additions & 0 deletions backend/tests/unit/test_byok_llm_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
import sys
import types
from unittest.mock import MagicMock, patch

# Seed placeholder credentials; setdefault leaves any value that is already
# present in the environment untouched.
for _env_name, _env_placeholder in {
    'OPENAI_API_KEY': 'sk-test-fake-for-unit-tests',
    'ANTHROPIC_API_KEY': 'ant-test-fake-for-unit-tests',
    'ENCRYPTION_SECRET': 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv',
}.items():
    os.environ.setdefault(_env_name, _env_placeholder)

# Register stand-ins for the database modules unless an entry already exists
# in sys.modules under the same name.
sys.modules.setdefault('database._client', MagicMock())

llm_usage_stub = types.ModuleType('database.llm_usage')
setattr(llm_usage_stub, 'record_llm_usage', MagicMock())
sys.modules.setdefault('database.llm_usage', llm_usage_stub)


class _HTTPError(Exception):
    """Minimal exception double that carries an HTTP-style ``status_code``."""

    def __init__(self, message: str, status_code: int):
        Exception.__init__(self, message)
        # Exposed as a plain attribute so code under test can read it off the error.
        self.status_code = status_code


def test_classify_byok_llm_error_authentication():
    """A 401 error is classified as 'invalid'."""
    from utils.llm.byok_errors import classify_byok_llm_error

    unauthorized = _HTTPError("bad api key", 401)
    classification = classify_byok_llm_error(unauthorized)
    assert classification == 'invalid'


def test_classify_byok_llm_error_permission():
    """A 403 error is classified as 'permission'."""
    from utils.llm.byok_errors import classify_byok_llm_error

    forbidden = _HTTPError("project denied", 403)
    classification = classify_byok_llm_error(forbidden)
    assert classification == 'permission'


def test_classify_byok_llm_error_insufficient_quota():
    """A 429 error whose text mentions insufficient_quota is classified as 'quota'."""
    from utils.llm.byok_errors import classify_byok_llm_error

    out_of_quota = _HTTPError("insufficient_quota", 429)
    classification = classify_byok_llm_error(out_of_quota)
    assert classification == 'quota'


def test_classify_byok_llm_error_ignores_transient_rate_limit():
    """A 429 error without any quota wording yields no classification."""
    from utils.llm.byok_errors import classify_byok_llm_error

    throttled = _HTTPError("rate limit reached, retry later", 429)
    classification = classify_byok_llm_error(throttled)
    assert classification is None


@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1')
@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user')
def test_handle_llm_error_logs_byok_source(mock_get_key, mock_get_uid):
    """With a BYOK key present, the log call carries source, provider and reason."""
    from utils.llm.byok_errors import handle_llm_error

    quota_error = _HTTPError("insufficient_quota", 429)
    with patch('utils.llm.byok_errors.logger.error') as error_log:
        handle_llm_error(quota_error, 'openai', feature='memories', model='gpt-test')

    # Index 0 is the format string; the remaining entries are its positional values.
    logged = error_log.call_args.args
    message_format = logged[0]
    source, provider = logged[1:3]
    reason = logged[8]
    assert 'LLM error source=%s' in message_format
    assert source == 'byok'
    assert provider == 'openai'
    assert reason == 'quota'


@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1')
@patch('utils.llm.byok_errors.get_byok_key', return_value=None)
def test_handle_llm_error_logs_platform_source(mock_get_key, mock_get_uid):
    """Without a BYOK key, the log call reports source 'platform' and reason 'unknown'."""
    from utils.llm.byok_errors import handle_llm_error

    server_error = _HTTPError("server error", 500)
    with patch('utils.llm.byok_errors.logger.error') as error_log:
        handle_llm_error(server_error, 'openai', feature='memories', model='gpt-test')

    logged = error_log.call_args.args
    source = logged[1]
    reason = logged[8]
    assert source == 'platform'
    assert reason == 'unknown'


def test_validate_byok_request_records_current_uid():
    """After a validation that reports no error, get_byok_uid returns the validated uid."""
    from utils.byok import get_byok_uid, validate_byok_request

    # Force the validity check to report "no error".
    passing_check = patch('utils.byok._check_byok_validity', return_value=None)
    with passing_check:
        validate_byok_request('user-1')

    recorded_uid = get_byok_uid()
    assert recorded_uid == 'user-1'


def test_llm_error_callback_uses_provider_context():
    """The callback forwards the error and its provider to handle_llm_error exactly once."""
    from utils.llm.clients import _LLMErrorCallback

    provider_error = _HTTPError('bad key', 401)
    llm_callback = _LLMErrorCallback('openai', model='gpt-test', feature='memories')

    with patch('utils.llm.clients.handle_llm_error') as handler_spy:
        llm_callback.on_llm_error(provider_error)

    handler_spy.assert_called_once()
    leading_args = handler_spy.call_args.args[:2]
    assert leading_args == (provider_error, 'openai')
16 changes: 16 additions & 0 deletions backend/utils/byok.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def invalidate_byok_state_cache(uid: str) -> None:
# Keys for the current request, if the client supplied them.
# Default is None (not {}) to avoid sharing a mutable object across contexts.
_byok_ctx: ContextVar[Optional[Dict[str, str]]] = ContextVar('byok_keys', default=None)
_byok_uid_ctx: ContextVar[Optional[str]] = ContextVar('byok_uid', default=None)


def get_byok_keys() -> Dict[str, str]:
Expand All @@ -87,6 +88,16 @@ def get_byok_key(provider: str) -> Optional[str]:
return keys.get(provider)


def get_byok_uid() -> Optional[str]:
    """Read the uid stored in the current context, or None when none was set."""
    current_uid = _byok_uid_ctx.get()
    return current_uid


def set_byok_uid(uid: Optional[str]) -> None:
    """Store *uid* (or None to clear it) in the current context."""
    # The reset token returned by ContextVar.set is intentionally not kept.
    _byok_uid_ctx.set(uid)
    return None


def has_byok_keys() -> bool:
"""True if the current request carries at least one BYOK header."""
keys = _byok_ctx.get()
Expand Down Expand Up @@ -127,10 +138,12 @@ async def dispatch(self, request: Request, call_next):
if value:
keys[provider] = value
token = _byok_ctx.set(keys)
uid_token = _byok_uid_ctx.set(None)
try:
return await call_next(request)
finally:
_byok_ctx.reset(token)
_byok_uid_ctx.reset(uid_token)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -203,6 +216,7 @@ def validate_byok_request(uid: str) -> None:
if error:
logger.warning('BYOK validation failed uid=%s: %s', uid, error)
raise HTTPException(status_code=403, detail=error)
set_byok_uid(uid)


def validate_byok_websocket(uid: str) -> Optional[str]:
Expand All @@ -215,4 +229,6 @@ def validate_byok_websocket(uid: str) -> Optional[str]:
error = _check_byok_validity(uid)
if error:
logger.warning('BYOK WS validation failed uid=%s: %s', uid, error)
else:
set_byok_uid(uid)
return error
90 changes: 90 additions & 0 deletions backend/utils/llm/byok_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import asyncio
import logging
from typing import Optional

from utils.byok import get_byok_key, get_byok_uid
from utils.executors import storage_executor, submit_with_context
from utils.log_sanitizer import sanitize

logger = logging.getLogger(__name__)

_QUOTA_ERROR_NAMES = frozenset({'RateLimitError'})


def get_llm_error_source(provider: Optional[str]) -> str:
Comment on lines +11 to +14
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The _QUOTA_ERROR_NAMES frozenset is named after quota errors but contains RateLimitError, which is a superset of quota failures. More critically, Anthropic's credit-exhaustion 429 response body typically reads "Your credit balance is too low…" — the word "quota" never appears. As a result, an Anthropic BYOK user who runs out of credits will log reason='unknown' instead of reason='quota', defeating the classification for the one provider wired to this path via handle_llm_error_async. Adding a 'credit' / 'balance' keyword check (and optionally renaming the frozenset) would fix the gap.

Suggested change
_QUOTA_ERROR_NAMES = frozenset({'RateLimitError'})
def get_llm_error_source(provider: Optional[str]) -> str:
_RATE_LIMIT_ERROR_NAMES = frozenset({'RateLimitError'})
_QUOTA_KEYWORDS = frozenset({'insufficient_quota', 'quota', 'credit', 'balance'})
def get_llm_error_source(provider: Optional[str]) -> str:

"""Return platform/byok for the current request and provider."""
if provider and get_byok_key(provider):
return 'byok'
return 'platform'


def classify_byok_llm_error(error: Exception) -> Optional[str]:
"""Classify user-actionable BYOK failures for structured logging."""
status_code = _get_status_code(error)
error_name = type(error).__name__
error_text = sanitize(str(error)).lower()

if status_code == 401 or error_name == 'AuthenticationError':
return 'invalid'
if status_code == 403 or error_name == 'PermissionDeniedError':
return 'permission'
if status_code == 429 or error_name in _QUOTA_ERROR_NAMES:
if 'insufficient_quota' in error_text or 'quota' in error_text:
return 'quota'
Comment on lines +31 to +33
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 If the rename above is applied, the 429-branch should use the updated name and keyword set to catch Anthropic credit errors.

Suggested change
if status_code == 429 or error_name in _QUOTA_ERROR_NAMES:
if 'insufficient_quota' in error_text or 'quota' in error_text:
return 'quota'
if status_code == 429 or error_name in _RATE_LIMIT_ERROR_NAMES:
if any(kw in error_text for kw in _QUOTA_KEYWORDS):
return 'quota'

return None


def handle_llm_error(
    error: Exception,
    provider: Optional[str],
    feature: Optional[str] = None,
    model: Optional[str] = None,
    operation: str = 'chat',
) -> None:
    """Emit one structured error-level log line for a failed LLM call.

    Args:
        error: The exception raised by the LLM call.
        provider: Provider name used to decide whether the failure is
            attributed to a BYOK key or to the platform.
        feature: Optional feature label; logged as 'unknown' when falsy.
        model: Optional model name; logged as 'unknown' when falsy.
        operation: Operation label, logged as given.
    """
    source = get_llm_error_source(provider)

    # A reason is only classified for BYOK failures; otherwise it stays unset.
    reason = None
    if source == 'byok':
        reason = classify_byok_llm_error(error)

    uid = get_byok_uid()
    status_code = _get_status_code(error)

    log_template = (
        'LLM error source=%s provider=%s feature=%s model=%s operation=%s uid=%s status_code=%s reason=%s '
        'error_type=%s error=%s'
    )
    # Positional order must line up with the placeholders in log_template.
    log_fields = (
        source,
        provider or 'unknown',
        feature or 'unknown',
        model or 'unknown',
        operation,
        uid or 'unknown',
        status_code or 'unknown',
        reason or 'unknown',
        type(error).__name__,
        sanitize(str(error)),
    )
    logger.error(log_template, *log_fields)


async def handle_llm_error_async(
    error: Exception,
    provider: Optional[str],
    feature: Optional[str] = None,
    model: Optional[str] = None,
    operation: str = 'chat',
) -> None:
    """Run ``handle_llm_error`` off the event loop while preserving BYOK context.

    The handler only writes a log line, so it is dispatched to the event loop's
    default executor instead of ``storage_executor``; this keeps logging work
    from occupying threads of a pool used for other workloads. The caller's
    context variables are copied explicitly so the handler still sees the
    current BYOK keys and uid from the worker thread.

    Args:
        error: The exception raised by the LLM call.
        provider: Provider name forwarded to the handler.
        feature: Optional feature label forwarded to the handler.
        model: Optional model name forwarded to the handler.
        operation: Operation label forwarded to the handler.

    Any exception raised by the handler is logged and swallowed so that error
    reporting never masks the original failure.
    """
    import contextvars

    # run_in_executor does not propagate context variables on its own, so
    # snapshot the caller's context and execute the handler inside that snapshot.
    caller_context = contextvars.copy_context()
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(
            None,
            caller_context.run,
            handle_llm_error,
            error,
            provider,
            feature,
            model,
            operation,
        )
    except Exception as e:
        logger.error('Async LLM error handler failed provider=%s feature=%s: %s', provider, feature, e)
Comment on lines +66 to +78
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 storage_executor is the wrong executor for error logging

storage_executor is documented and dimensioned (16 threads) for audio file pre-caching and GCS operations. Submitting logging work there mixes latency classes and could steal threads from ongoing audio uploads during busy periods. postprocess_executor or the default loop.run_in_executor(None, ...) would be more appropriate.



def _get_status_code(error: Exception) -> Optional[int]:
    """Return the integer HTTP status carried by *error*, or None.

    The error's own ``status_code`` attribute is preferred. Only when that is
    not an int is the ``status_code`` of the error's ``response`` attribute
    consulted. Non-integer values are ignored in both places.
    """
    direct = getattr(error, 'status_code', None)
    if not isinstance(direct, int):
        # Fall back to the attached response object, if there is one.
        via_response = getattr(getattr(error, 'response', None), 'status_code', None)
        return via_response if isinstance(via_response, int) else None
    return direct
Loading