Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions openviking/utils/model_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import logging
import random
import re
import threading
import time
from typing import Awaitable, Callable, TypeVar
Expand Down Expand Up @@ -68,6 +69,30 @@
"connection reset",
)

# Pre-compile regex for purely numeric patterns to avoid substring false positives
# (e.g. "413" matching inside a hex request ID like "d7c9130f344...").
_NUMERIC_PATTERN_RE: dict[str, re.Pattern] = {}


def _get_numeric_pattern_re(pattern: str) -> re.Pattern:
if pattern not in _NUMERIC_PATTERN_RE:
_NUMERIC_PATTERN_RE[pattern] = re.compile(r"\b" + re.escape(pattern) + r"\b")
return _NUMERIC_PATTERN_RE[pattern]


def _pattern_matches(text_lower: str, text_compact: str, pattern: str) -> bool:
"""Check if pattern matches in text, using word-boundary for numeric-only patterns.

Numeric-only patterns (e.g. ``"413"``) are matched with ``\\b`` word boundaries
to prevent false positives inside request IDs or hex strings. Non-numeric
patterns use plain substring matching as before.
"""
if pattern.isdigit():
return bool(_get_numeric_pattern_re(pattern).search(text_lower)) or bool(
_get_numeric_pattern_re(pattern).search(text_compact)
)
return pattern in text_lower or pattern in text_compact


def classify_api_error(error: Exception) -> str:
"""Classify an API error as permanent, quota_exceeded, transient, or unknown.
Expand All @@ -89,13 +114,14 @@ def classify_api_error(error: Exception) -> str:
text_lower = text.lower()
text_compact = text_lower.replace(" ", "")
for pattern in INPUT_TOO_LARGE_PATTERNS:
if pattern in text_lower or pattern in text_compact:
if _pattern_matches(text_lower, text_compact, pattern):
return ERROR_CLASS_INPUT_TOO_LARGE

for text in texts:
text_lower = text.lower()
text_compact = text_lower.replace(" ", "")
for pattern in PERMANENT_API_ERROR_PATTERNS:
if pattern in text_lower:
if _pattern_matches(text_lower, text_compact, pattern):
return ERROR_CLASS_PERMANENT

# Check quota_exceeded *before* transient so that "429 … AccountQuotaExceeded"
Expand All @@ -108,8 +134,9 @@ def classify_api_error(error: Exception) -> str:

for text in texts:
text_lower = text.lower()
text_compact = text_lower.replace(" ", "")
for pattern in TRANSIENT_API_ERROR_PATTERNS:
if pattern in text_lower:
if _pattern_matches(text_lower, text_compact, pattern):
return ERROR_CLASS_TRANSIENT

return ERROR_CLASS_UNKNOWN
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_model_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,26 @@ def _call():
retry_sync(_call, max_retries=5)

assert attempts["count"] == 1


# --- numeric pattern word-boundary tests ---


def test_429_with_request_id_containing_413_is_transient():
    """Rate-limit (429) errors stay transient even if the request ID embeds '413'.

    Regression guard for the original bug, where the digits '413' inside the
    request ID caused a misclassification as INPUT_TOO_LARGE.
    """
    message = (
        "Volcengine hybrid embedding failed: Error code: 429 - "
        "{'error': {'code': 'ModelAccountRpmRateLimitExceeded', "
        "'message': 'RPM limit exceeded', 'param': '', "
        "'type': 'TooManyRequests'}, "
        "'request_id': '0217801248873024288fe53d7c9130f34413480585e683685bc95'}"
    )
    classification = classify_api_error(RuntimeError(message))
    assert classification == ERROR_CLASS_TRANSIENT


def test_numeric_status_code_inside_longer_number_is_not_matched():
    """A status-code pattern embedded in a longer number must not match.

    For example, '400' must not be found inside '1400'.
    """
    for message in ("status: 1400 OK", "status: 5020 OK"):
        assert classify_api_error(RuntimeError(message)) == "unknown"
Loading