Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [
"sentry-sdk[fastapi]==2.42.0",
"sqlalchemy==2.0.44",
"tabulate==0.9.0",
"tenacity==8.5.0",
"tiktoken==0.11.0",
"uvicorn==0.35.0",
"loguru==0.7.3",
Expand Down
191 changes: 116 additions & 75 deletions src/mlpa/core/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

import httpx
from fastapi import HTTPException
from tenacity import retry, stop_after_attempt, wait_exponential

from mlpa.core.classes import AuthorizedChatRequest, LitellmRoutingSnapshot
from mlpa.core.config import (
ERROR_CODE_BUDGET_LIMIT_EXCEEDED,
ERROR_CODE_MAX_USERS_REACHED,
ERROR_CODE_RATE_LIMIT_EXCEEDED,
ERROR_CODE_REQUEST_TOO_LARGE,
ERROR_CODE_UPSTREAM_ERROR,
LITELLM_COMPLETION_AUTH_HEADERS,
LITELLM_COMPLETIONS_URL,
env,
Expand All @@ -25,8 +27,12 @@
from mlpa.core.utils import (
get_or_create_user,
is_context_window_error,
is_litellm_upstream_rate_limit,
is_rate_limit_error,
litellm_request,
log_litellm_retry_attempt,
raise_and_log,
should_retry_on_litellm_error,
)

_RATE_LIMIT_REJECTION: dict[int, tuple[PrometheusRejectionReason, str]] = {
Expand All @@ -35,9 +41,42 @@
"86400",
),
ERROR_CODE_RATE_LIMIT_EXCEEDED: (PrometheusRejectionReason.RATE_LIMITED, "60"),
ERROR_CODE_UPSTREAM_ERROR: (PrometheusRejectionReason.UPSTREAM_ERROR, "60"),
}


def _litellm_retry_predicate(state) -> bool:
    """Return True when the last attempt raised a retryable LiteLLM error.

    Tenacity invokes this with the current ``RetryCallState``; attempts that
    returned normally are never retried.
    """
    if not state.outcome.failed:
        return False
    return should_retry_on_litellm_error(state.outcome.exception())


@retry(
    wait=wait_exponential(multiplier=1, min=1, max=4),
    stop=stop_after_attempt(5),
    retry=_litellm_retry_predicate,
    before_sleep=log_litellm_retry_attempt,
    reraise=True,
)
async def _call_litellm_with_retry(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    headers: dict,
    json: dict,
    timeout: float,
    stream: bool = False,
):
    """Issue a LiteLLM HTTP request with exponential-backoff retries.

    Thin wrapper around ``litellm_request`` so a single tenacity policy
    (up to 5 attempts, 1-4s exponential wait, final exception re-raised)
    is shared by both the streaming and non-streaming completion paths.

    Args:
        client: Shared async HTTP client used for the call.
        method: HTTP method, e.g. ``"POST"``.
        url: Target LiteLLM endpoint URL.
        headers: Request headers (including auth).
        json: JSON-serializable request body.
        timeout: Per-request timeout in seconds.
        stream: Whether to open a streaming response (default ``False``).

    Returns:
        Whatever ``litellm_request`` returns for the given call.
    """
    response = await litellm_request(
        client,
        method,
        url,
        headers,
        json=json,
        timeout=timeout,
        stream=stream,
    )
    return response


def _parse_rate_limit_error(error_text: str, user: str) -> int | None:
"""
Parse error response to detect budget or rate limit errors.
Expand All @@ -54,6 +93,8 @@ def _parse_rate_limit_error(error_text: str, user: str) -> int | None:
elif is_rate_limit_error(error_data, ["rate"]):
logger.warning(f"Rate limit exceeded for user {user}: {error_text}")
return ERROR_CODE_RATE_LIMIT_EXCEEDED
elif is_litellm_upstream_rate_limit(error_text):
return ERROR_CODE_UPSTREAM_ERROR
except (json.JSONDecodeError, AttributeError, UnicodeDecodeError):
pass

Expand Down Expand Up @@ -206,53 +247,16 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest):
)
try:
client = get_http_client()
async with client.stream(
"POST",
LITELLM_COMPLETIONS_URL,
response = await _call_litellm_with_retry(
client=client,
method="POST",
url=LITELLM_COMPLETIONS_URL,
headers=LITELLM_COMPLETION_AUTH_HEADERS,
json=body,
timeout=env.STREAMING_TIMEOUT_SECONDS,
) as response:
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
# Read the error response content for streaming responses
error_text_str = ""
try:
error_bytes = await e.response.aread()
error_text_str = error_bytes.decode("utf-8") if error_bytes else ""
except Exception:
pass

if e.response.status_code in {400, 429}:
# Check for budget or rate limit errors
error_code = _parse_rate_limit_error(
error_text_str, authorized_chat_request.user
)
if error_code in _RATE_LIMIT_REJECTION:
reason, _ = _RATE_LIMIT_REJECTION[error_code]
_record_rejection(authorized_chat_request, reason)
yield f'data: {{"error": {error_code}}}\n\n'.encode()
return

# Context window exceeded: detect by error text or upstream 413
if e.response.status_code == 413 or is_context_window_error(
error_text_str
):
logger.warning(
f"Context window exceeded for user {authorized_chat_request.user}"
)
_record_rejection(
authorized_chat_request,
PrometheusRejectionReason.PAYLOAD_TOO_LARGE,
)
yield f'data: {{"error": {ERROR_CODE_REQUEST_TOO_LARGE}}}\n\n'.encode()
return

# For other errors or if we couldn't parse the error
yield raise_and_log(e, True)
return

stream=True,
)
try:
litellm_routing_snapshot = parse_litellm_routing_headers(response.headers)

async for chunk in response.aiter_bytes():
Expand Down Expand Up @@ -342,8 +346,44 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest):
completion_tokens,
)
result = PrometheusResult.SUCCESS
finally:
await response.aclose()
except httpx.HTTPStatusError as e:
if not streaming_started:
# Read the error response content for streaming responses
error_text_str = ""
try:
error_bytes = await e.response.aread()
error_text_str = error_bytes.decode("utf-8") if error_bytes else ""
except Exception:
pass
finally:
await e.response.aclose()

if e.response.status_code in {400, 429}:
# Check for budget or rate limit errors
error_code = _parse_rate_limit_error(
error_text_str, authorized_chat_request.user
)
if error_code in _RATE_LIMIT_REJECTION:
reason, _ = _RATE_LIMIT_REJECTION[error_code]
_record_rejection(authorized_chat_request, reason)
yield f'data: {{"error": {error_code}}}\n\n'.encode()
return

# Context window exceeded: detect by error text or upstream 413
if e.response.status_code == 413 or is_context_window_error(error_text_str):
logger.warning(
f"Context window exceeded for user {authorized_chat_request.user}"
)
_record_rejection(
authorized_chat_request,
PrometheusRejectionReason.PAYLOAD_TOO_LARGE,
)
yield f'data: {{"error": {ERROR_CODE_REQUEST_TOO_LARGE}}}\n\n'.encode()
return

# For other errors or if we couldn't parse the error
yield raise_and_log(e, True)
else:
logger.error(f"Upstream service returned an error: {e}")
Expand Down Expand Up @@ -381,40 +421,14 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest):
)
try:
client = get_http_client()
response = await client.post(
LITELLM_COMPLETIONS_URL,
response = await _call_litellm_with_retry(
client=client,
method="POST",
url=LITELLM_COMPLETIONS_URL,
headers=LITELLM_COMPLETION_AUTH_HEADERS,
json=body,
timeout=env.STREAMING_TIMEOUT_SECONDS,
)
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
error_text = e.response.text
if e.response.status_code in {400, 429}:
error_code = _parse_rate_limit_error(
error_text, authorized_chat_request.user
)
if error_code in _RATE_LIMIT_REJECTION:
reason, retry_after = _RATE_LIMIT_REJECTION[error_code]
_record_rejection(authorized_chat_request, reason)
raise HTTPException(
status_code=429,
detail={"error": error_code},
headers={"Retry-After": retry_after},
)
# Context window exceeded: detect by error text or upstream 413
if e.response.status_code == 413 or is_context_window_error(error_text):
logger.warning(
f"Context window exceeded for user {authorized_chat_request.user}"
)
_record_rejection(
authorized_chat_request, PrometheusRejectionReason.PAYLOAD_TOO_LARGE
)
raise HTTPException(
status_code=413,
detail={"error": ERROR_CODE_REQUEST_TOO_LARGE},
)
raise_and_log(e)
litellm_routing_snapshot = parse_litellm_routing_headers(response.headers)
data = response.json()
usage = data.get("usage", {})
Expand Down Expand Up @@ -474,6 +488,33 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest):
)
result = PrometheusResult.SUCCESS
return data
except httpx.HTTPStatusError as e:
error_text = e.response.text
if e.response.status_code in {400, 429}:
error_code = _parse_rate_limit_error(
error_text, authorized_chat_request.user
)
if error_code in _RATE_LIMIT_REJECTION:
reason, retry_after = _RATE_LIMIT_REJECTION[error_code]
_record_rejection(authorized_chat_request, reason)
raise HTTPException(
status_code=429,
detail={"error": error_code},
headers={"Retry-After": retry_after},
)
# Context window exceeded: detect by error text or upstream 413
if e.response.status_code == 413 or is_context_window_error(error_text):
logger.warning(
f"Context window exceeded for user {authorized_chat_request.user}"
)
_record_rejection(
authorized_chat_request, PrometheusRejectionReason.PAYLOAD_TOO_LARGE
)
raise HTTPException(
status_code=413,
detail={"error": ERROR_CODE_REQUEST_TOO_LARGE},
)
raise_and_log(e)
except HTTPException:
raise
except Exception as e:
Expand Down
1 change: 1 addition & 0 deletions src/mlpa/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def __init__(self):
ERROR_CODE_RATE_LIMIT_EXCEEDED: int = 2
ERROR_CODE_REQUEST_TOO_LARGE: int = 3
ERROR_CODE_MAX_USERS_REACHED: int = 4
ERROR_CODE_UPSTREAM_ERROR: int = 5

RATE_LIMIT_ERROR_RESPONSE = {
429: {
Expand Down
1 change: 1 addition & 0 deletions src/mlpa/core/prometheus_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class PrometheusRejectionReason(StrEnum):
RATE_LIMITED = "rate_limited"
PAYLOAD_TOO_LARGE = "payload_too_large"
SIGNUP_CAP_EXCEEDED = "signup_cap_exceeded"
UPSTREAM_ERROR = "upstream_error"


BUCKETS_FAST_AUTH = (0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, float("inf"))
Expand Down
Loading
Loading