Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 10 additions & 18 deletions evolution/core/hermes_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,24 +721,16 @@ def _maybe_resolve_nous_lm(
target_model: str,
role: Role,
) -> Optional[ResolvedLM]:
"""Build a NousLM-backed ResolvedLM when the auth.json pool entry
looks OAuth-managed; return None to let the caller fall through to
the generic OpenAI-wire handler when the entry is just an env-var-
style API key.

Nous uses a two-stage credential model: an OAuth access_token
(long-lived) is exchanged for a short-lived agent_key that's the
actual inference Bearer. NousLM handles both: refresh access_token
in-memory when expiring, mint a fresh agent_key from it, re-mint on
inference 401. See evolution/core/nous_lm.py.

The "looks OAuth-managed" signal: pool entry has a refresh_token. A
pool entry without refresh_token is either env-var-only (NOUS_API_KEY
set, no real OAuth state) or hand-edited; let the caller fall
through to direct pass-through so we don't break that setup.

The CodexLM-equivalent NousLM import is lazy to avoid a circular
dependency: nous_lm imports HermesProviderError from this module.
"""Build a NousLM-backed ResolvedLM when the pool entry has a
refresh_token (the OAuth-managed signal). Returns None for env-var
or hand-edited entries with an agent_key already present (caller
falls through to the generic OpenAI-wire handler), and raises for
partial OAuth setups (access_token without refresh_token or
agent_key) so the operator gets a `hermes model` recovery hint
instead of a silent inference 401.

See ``evolution/core/nous_lm.py`` for the two-stage credential
model and the in-memory refresh + mint flow.
"""
pool_entry = _pick_pool_entry(auth_store, "nous")
if pool_entry is None:
Expand Down
86 changes: 58 additions & 28 deletions evolution/core/nous_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ def __init__(
agent_key_expires_at=agent_key_expires_at,
)

# Initial mint if the constructor-supplied agent_key is missing or
# already expiring. Cheap on the happy path; one POST otherwise.
# Pay the mint cost at construction so the first forward() doesn't
# see a synchronous round-trip surprise.
self._ensure_credentials()

# ------------------------------------------------------------------
Expand All @@ -207,8 +207,13 @@ def __init__(

def _oauth_needs_refresh(self) -> bool:
if self._shared_state.oauth_expires_at is None:
# Unknown expiry → don't speculatively refresh; let the mint
# call surface a 401 if the access_token is actually dead.
# Unknown expiry → don't speculatively refresh; the mint
# call's own 401-triggers-refresh-retry path catches a
# genuinely-dead access_token. Note _agent_key_needs_mint
# makes the opposite choice (defaults True on unknown
# expiry) because there's no equivalent recovery for a
# missing agent_key — inference would just 401 with no
# built-in retry.
return False
return (
time.time() + OAUTH_REFRESH_SKEW_SECONDS
Expand All @@ -219,16 +224,15 @@ def _agent_key_needs_mint(self) -> bool:
if not self._shared_state.agent_key:
return True
if self._shared_state.agent_key_expires_at is None:
# Have a key but no expiry — treat as needing re-mint to be
# safe. Cheaper than letting it 401 mid-run.
# Have a key but no expiry — re-mint defensively. See
# _oauth_needs_refresh for the asymmetric reasoning.
return True
return (
time.time() + AGENT_KEY_REFRESH_SKEW_SECONDS
>= self._shared_state.agent_key_expires_at
)

def _sync_from_shared_state(self) -> None:
    """Copy the shared state's current agent_key into ``self.kwargs["api_key"]``.

    A missing or empty agent_key is stored as the empty string.
    """
    latest_key = self._shared_state.agent_key
    self.kwargs["api_key"] = latest_key if latest_key else ""

def _ensure_credentials(self) -> None:
Expand All @@ -254,8 +258,16 @@ def _force_remint(self) -> None:
"""Skip skew check and re-mint immediately. Called when an inference
call returned 401 — the cached agent_key is bad and we don't want
to wait for the skew window.

Pre-checks the OAuth expiry too. Without this, a stale OAuth +
revoked agent_key combo takes three round-trips (mint→401→
refresh→mint); with the pre-check it's two (refresh→mint). The
mint's 401-triggers-refresh path still backstops the case where
OAuth looks fresh by skew but the portal has revoked it.
"""
with self._shared_state.lock:
if self._oauth_needs_refresh():
self._refresh_oauth()
self._mint_agent_key(allow_oauth_retry=True)
self._sync_from_shared_state()

Expand Down Expand Up @@ -383,6 +395,15 @@ def _mint_agent_key(self, *, allow_oauth_retry: bool) -> None:
raise HermesProviderError(_format_mint_error(response))

def _absorb_mint_response(self, response: httpx.Response) -> None:
"""Parse a 200 mint response into shared state.

Tolerates both the current ``api_key`` field and the older
``agent_key`` shape, and prefers a server-supplied ``expires_at``
ISO 8601 timestamp over the relative ``expires_in``. When neither
expiry field is parseable, falls back to the requested floor TTL
with a warning so portal protocol drift doesn't silently cache a
key for longer than the server intended.
"""
try:
payload = response.json()
except ValueError as exc:
Expand All @@ -391,9 +412,6 @@ def _absorb_mint_response(self, response: httpx.Response) -> None:
"Run `hermes model` to re-authenticate."
) from exc

# Hermes uses both ``api_key`` (current portal field) and falls back
# to ``agent_key`` (older shape). Mirror both so a portal protocol
# rev doesn't break us.
agent_key = payload.get("api_key") or payload.get("agent_key")
if not isinstance(agent_key, str) or not agent_key.strip():
raise HermesProviderError(
Expand Down Expand Up @@ -489,7 +507,10 @@ def _format_oauth_error(response: httpx.Response) -> str:
"""
code, detail = _parse_error_body(response)

if code == "refresh_token_reused" or "reuse" in detail.lower():
# Match the explicit code field, not the free-form detail string —
# a substring search on detail would false-positive on unrelated
# portal messages like "this is not a reusable connection".
if "reused" in code.lower():
return (
"Nous Portal refresh token was already consumed by another "
"client (the portal enforces single-use refresh-token rotation). "
Expand Down Expand Up @@ -529,27 +550,36 @@ def _format_mint_error(response: httpx.Response) -> str:
def _parse_error_body(response: httpx.Response) -> tuple[str, str]:
    """Best-effort parse of an OAuth-style error body into ``(code, detail)``.

    Two JSON shapes are recognized:

    * OpenAI shape: ``{"error": {"code": ..., "message": ...}}`` — ``type``
      is consulted when ``code`` is missing or falsy.
    * OAuth-spec shape: ``{"error": "code", "error_description": "..."}`` —
      ``message`` is consulted when ``error_description`` is missing or
      falsy.

    Args:
        response: The HTTP response whose body should be inspected.

    Returns:
        A ``(code, detail)`` tuple. ``code`` defaults to ``"unknown"`` and
        ``detail`` to ``"status <status_code>"`` whenever the corresponding
        field is absent, empty, or not a string.

        On JSON parse failure (e.g., a CDN returning an HTML error page,
        or a portal outage returning text), ``detail`` falls back to
        ``"status <status_code>: <snippet>"`` where the snippet is the
        stripped raw body truncated to 512 characters, so the operator can
        correlate the failure with what the upstream actually sent. An
        empty body keeps the status-only default.
    """
    code = "unknown"
    detail = f"status {response.status_code}"

    # Only the decode step is guarded: a ValueError here means "body is not
    # JSON", which has its own fallback and must not be confused with
    # anything raised while interpreting a successfully decoded payload.
    try:
        body = response.json()
    except ValueError:
        snippet = (response.text or "").strip()
        if snippet:
            detail = f"status {response.status_code}: {snippet[:512]}"
        return code, detail

    if isinstance(body, dict):
        err = body.get("error")
        if isinstance(err, dict):
            # OpenAI shape: {"error": {"code": ..., "message": ...}}
            nested_code = err.get("code") or err.get("type")
            if isinstance(nested_code, str) and nested_code.strip():
                code = nested_code.strip()
            nested_msg = err.get("message")
            if isinstance(nested_msg, str) and nested_msg.strip():
                detail = nested_msg.strip()
        elif isinstance(err, str) and err.strip():
            # OAuth-spec shape: {"error": "code", "error_description": "..."}
            code = err.strip()
            desc = body.get("error_description") or body.get("message")
            if isinstance(desc, str) and desc.strip():
                detail = desc.strip()
    return code, detail
110 changes: 110 additions & 0 deletions tests/core/test_nous_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,116 @@ def test_aforward_propagates_second_401_as_hermes_provider_error(self):
)


class TestForceRemintPreChecksOAuth:
    """A forced re-mint must refresh a near-expiry OAuth token FIRST.

    When inference 401s and the OAuth token is also expiring, the expected
    sequence is refresh→mint (two round-trips), not mint→401→refresh→mint
    (three). The mint's own 401-retry still backstops the case where
    OAuth looks fresh by skew but the portal has revoked it.
    """

    def test_force_remint_refreshes_oauth_when_also_expiring(self):
        # Construction with a near-expiry OAuth token: the initial
        # credential pass is expected to refresh, then mint.
        with patch("evolution.core.nous_lm.httpx.Client") as client_factory:
            construction_responses = [
                _mock_response(
                    json_body={"access_token": "init-refresh", "expires_in": 86400}
                ),
                _mock_response(json_body={"api_key": "init-mint", "expires_in": 1800}),
            ]
            client_factory.return_value = _mock_httpx_post(construction_responses)
            lm = NousLM(
                model="openai/test-model",
                access_token="stale",
                refresh_token="r",
                oauth_expires_at=time.time() + 30,  # forces refresh
            )

        # Put the OAuth token back inside the expiry window, then force a
        # re-mint as an inference 401 would.
        lm._shared_state.oauth_expires_at = time.time() + 30
        with patch("evolution.core.nous_lm.httpx.Client") as client_factory:
            remint_responses = [
                _mock_response(
                    json_body={"access_token": "force-refresh", "expires_in": 86400}
                ),
                _mock_response(json_body={"api_key": "force-mint", "expires_in": 1800}),
            ]
            client_factory.return_value = _mock_httpx_post(remint_responses)
            lm._force_remint()
            http_client = client_factory.return_value

        # Exactly two POSTs, in order: refresh THEN mint.
        assert http_client.post.call_count == 2
        refresh_call, mint_call = http_client.post.call_args_list
        assert refresh_call.args[0].endswith("/api/oauth/token")
        assert mint_call.args[0].endswith("/api/oauth/agent-key")

        # The mint is authorized with the freshly-refreshed token, not the
        # stale one.
        mint_bearer = mint_call.kwargs["headers"]["Authorization"]
        assert mint_bearer == "Bearer force-refresh"


class TestParseErrorBodyTextFallback:
    """Non-JSON error bodies must surface a snippet of the raw text.

    When the OAuth/mint response body isn't JSON (CDN HTML page, portal
    outage HTML, etc.), the parsed detail should show what the upstream
    actually sent rather than only a generic status line.
    """

    @staticmethod
    def _non_json_response(status_code, text):
        """Build a response stand-in whose ``json()`` raises ValueError."""
        response = MagicMock(spec=httpx.Response)
        response.status_code = status_code
        response.json = MagicMock(side_effect=ValueError("not json"))
        response.text = text
        return response

    def test_html_body_appears_in_error_detail(self):
        from evolution.core.nous_lm import _parse_error_body

        html_page = "<html><body>Cloudflare 1020 Access Denied</body></html>"
        parsed_code, parsed_detail = _parse_error_body(
            self._non_json_response(502, html_page)
        )

        assert parsed_code == "unknown"
        assert "Cloudflare 1020" in parsed_detail
        assert "status 502" in parsed_detail

    def test_empty_body_falls_back_to_status_only(self):
        from evolution.core.nous_lm import _parse_error_body

        parsed_code, parsed_detail = _parse_error_body(
            self._non_json_response(503, "")
        )

        assert parsed_code == "unknown"
        assert parsed_detail == "status 503"


class TestReuseSubstringMatchesCodeNotDetail:
    """Regression guard for the refresh-token-reuse message.

    The 'reused' check is made against the error code field, not the
    free-form detail string. A server-error body whose message happens to
    read 'this is not a reusable connection' must NOT produce the
    refresh_token_reused user-facing message.
    """

    def test_reusable_in_detail_does_not_trigger_reuse_message(self):
        from evolution.core.nous_lm import _format_oauth_error

        error_payload = {
            "error": {
                "code": "internal_error",
                "message": "this is not a reusable connection",
            }
        }
        response = MagicMock(spec=httpx.Response)
        response.status_code = 500
        response.json = MagicMock(return_value=error_payload)

        rendered = _format_oauth_error(response)

        # Neither phrase from the reuse-specific message may appear.
        for reuse_marker in ("another client", "single-use refresh-token rotation"):
            assert reuse_marker not in rendered


class TestSharedStateInvariants:
def test_post_init_rejects_partial_agent_key_state(self):
# _SharedNousState __post_init__ catches the construction-time
Expand Down
Loading