Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/opengradient/client/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@

from ..types import TEE_LLM, ResponseFormat, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode
from .opg_token import Permit2ApprovalResult, ensure_opg_approval
from .tee_connection import RegistryTEEConnection, StaticTEEConnection, TEEConnectionInterface
from .tee_connection import (
ActiveTEE,
RegistryTEEConnection,
StaticTEEConnection,
TEEConnectionInterface,
)
from .tee_registry import TEERegistry

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -140,6 +145,15 @@ async def close(self) -> None:
"""Cancel the background refresh loop and close the HTTP client."""
await self._tee.close()

def resolve_tee_connection(self, tee_id: Optional[str] = None) -> ActiveTEE:
"""Resolve the current TEE or a specific active registry TEE.

This is primarily for backend relays that need SDK-managed TEE routing,
TLS pinning, and x402 clients without using the chat/completion helpers
directly, for example when forwarding OHTTP ciphertext.
"""
return self._tee.resolve(tee_id)

# ── Request helpers ─────────────────────────────────────────────────

def _headers(self, settlement_mode: x402SettlementMode) -> Dict[str, str]:
Expand Down
71 changes: 69 additions & 2 deletions src/opengradient/client/tee_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from x402 import x402Client
from x402.http.clients import x402HttpxClient

from .tee_registry import TEE_TYPE_LLM_PROXY, TEERegistry, build_ssl_context_from_der
from .tee_registry import (
TEE_TYPE_LLM_PROXY,
TEEEndpoint,
TEERegistry,
build_ssl_context_from_der,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,11 +43,23 @@ class TEEConnectionInterface(Protocol):
"""Interface for TEE connection implementations."""

def get(self) -> ActiveTEE: ...
def resolve(self, tee_id: Optional[str] = None) -> ActiveTEE: ...
def ensure_refresh_loop(self) -> None: ...
async def reconnect(self) -> None: ...
async def close(self) -> None: ...


def _normalize_tee_id(tee_id: Optional[str]) -> Optional[str]:
if not tee_id:
return None
normalized = tee_id.strip().lower()
if not normalized:
return None
if not normalized.startswith("0x"):
normalized = f"0x{normalized}"
return normalized


class StaticTEEConnection:
"""TEE connection with a hardcoded endpoint URL.

Expand All @@ -63,6 +80,14 @@ def get(self) -> ActiveTEE:
"""Return a snapshot of the current TEE connection."""
return self._active

def resolve(self, tee_id: Optional[str] = None) -> ActiveTEE:
"""Return the static connection.

Static/dev connections do not have a registry to validate selected
TEE ids against, so they always resolve to the configured endpoint.
"""
return self._active

def _connect(self) -> ActiveTEE:
return ActiveTEE(
endpoint=self._endpoint,
Expand Down Expand Up @@ -106,6 +131,7 @@ def __init__(self, x402_client: x402Client, registry: TEERegistry):

self._refresh_lock = asyncio.Lock()
self._refresh_task: Optional[asyncio.Task] = None
self._active_by_tee_id: dict[str, ActiveTEE] = {}

self._active: ActiveTEE = self._connect()

Expand All @@ -115,9 +141,47 @@ def get(self) -> ActiveTEE:
"""Return a snapshot of the current TEE connection."""
return self._active

def resolve(self, tee_id: Optional[str] = None) -> ActiveTEE:
"""Resolve a TEE connection, optionally pinned to an active TEE id.

Backend OHTTP relays can use this when the browser encrypted to a
specific on-chain TEE config, while the backend still owns x402 payment.
"""
normalized_tee_id = _normalize_tee_id(tee_id)
if normalized_tee_id is None:
return self._active

active_tee_id = _normalize_tee_id(self._active.tee_id)
if normalized_tee_id == active_tee_id:
return self._active

for tee in self._registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY):
if _normalize_tee_id(tee.tee_id) != normalized_tee_id:
continue

cached = self._active_by_tee_id.get(normalized_tee_id)
if (
cached is not None
and cached.endpoint.rstrip("/") == tee.endpoint.rstrip("/")
):
return cached

resolved = self._connect_to_tee(tee)
self._active_by_tee_id[normalized_tee_id] = resolved
logger.info(
"Resolved selected TEE endpoint from registry: %s (teeId=%s)",
resolved.endpoint,
normalized_tee_id,
)
return resolved

raise ValueError(
f"Selected TEE is not active in the registry: {normalized_tee_id}"
)

# ── Connection management ───────────────────────────────────────────

def _resolve_tee(self):
def _resolve_tee(self) -> TEEEndpoint:
"""Resolve TEE endpoint and metadata from the on-chain registry.

Returns:
Expand All @@ -141,7 +205,10 @@ def _resolve_tee(self):
def _connect(self) -> ActiveTEE:
"""Resolve TEE from registry and create a secure HTTP client."""
tee = self._resolve_tee()
return self._connect_to_tee(tee)

def _connect_to_tee(self, tee: TEEEndpoint) -> ActiveTEE:
"""Create a pinned x402 HTTP client for a resolved registry TEE."""
ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der)
return ActiveTEE(
endpoint=tee.endpoint,
Expand Down
Loading