Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 95 additions & 2 deletions tee_gateway/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,66 @@ def _patched_read_body_bytes(environ):
x402_flask._read_body_bytes = _patched_read_body_bytes


def _patched_stream_session_response(
self,
environ,
start_response,
context,
session_id,
payment_payload,
payment_requirements,
):
"""Expose x402's per-request cost context to Flask route handlers.

OHTTP requests arrive as ciphertext and return ciphertext, so the x402
middleware cannot parse request_json/response_json from the outer HTTP
bodies. The OHTTP controller decrypts the inner request and plaintext
response inside the enclave; this patch gives it a request-local dict where
it can attach those inner JSON objects for dynamic settlement.
"""
self._start_reaper()

request_body_bytes = x402_flask._read_body_bytes(environ)
request_json = x402_flask._try_parse_json(request_body_bytes)
parsed_request_json = (
request_json if isinstance(request_json, (dict, list)) else None
)

x402_flask.g.payment_payload = payment_payload
x402_flask.g.payment_requirements = payment_requirements
x402_flask.g.x402_session_id = session_id

cost_context = {
"method": context.method,
"path": context.path,
"request_body_bytes": request_body_bytes,
"request_json": parsed_request_json,
"payment_payload": payment_payload,
"payment_requirements": payment_requirements,
}
environ["x402.cost_context"] = cost_context

status_capture = x402_flask.StatusCapture(start_response)
status_capture.add_header(x402_flask.UPTO_SESSION_HEADER, session_id)

upstream_iter = self._original_wsgi(environ, status_capture)

return x402_flask.StreamingSessionResponse(
upstream_iter,
middleware=self,
session_id=session_id,
cost_context=cost_context,
status_ref=status_capture,
)


setattr(
x402_flask.PaymentMiddleware,
"_stream_session_response",
_patched_stream_session_response,
)


def _session_cost_calculator(ctx: dict) -> int:
# The chat/completions controllers compute cost in-band and embed it on
# the response as a SessionCost model BEFORE returning. We parse it back
Expand All @@ -173,7 +233,10 @@ def _session_cost_calculator(ctx: dict) -> int:
# has already logged CRITICAL in that case.
from .pricing import SessionCost

response_json = ctx.get("response_json")
if ctx.get("path") == "/v1/ohttp":
response_json = ctx.get("inner_response_json")
else:
response_json = ctx.get("response_json")
if not isinstance(response_json, dict):
raise ValueError("response_json missing or not a dict")
cost_block = response_json.get("opengradient")
Expand Down Expand Up @@ -257,8 +320,35 @@ def _init_payment_middleware(facilitator_url: str) -> None:
mime_type="application/json",
description="Completion",
),
"POST /v1/ohttp": RouteConfig(
Comment thread
dixitaniket marked this conversation as resolved.
accepts=[
PaymentOption(
scheme="upto",
pay_to=EVM_PAYMENT_ADDRESS,
price=AssetAmount(
amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND,
asset=BASE_MAINNET_OPG_ADDRESS,
extra={
"name": "OpenGradient",
"version": "1",
"assetTransferMethod": "permit2",
},
),
network=BASE_MAINNET_NETWORK,
),
],
extensions={
**declare_erc20_approval_gas_sponsoring_extension(),
},
mime_type="message/ohttp-req",
description="OHTTP-wrapped chat completion",
),
}

inner_wsgi_app = application.wsgi_app
flask_app = getattr(application, "app", application)
flask_app.config["OHTTP_INNER_WSGI_APP"] = inner_wsgi_app

# Return value intentionally discarded — PaymentMiddleware.__init__ self-wires
# by setting application.wsgi_app = self._wsgi_middleware internally.
payment_middleware(
Expand Down Expand Up @@ -499,14 +589,17 @@ def create_app():

@application.before_request
def _check_pricing_ready():
if request.path not in ("/v1/chat/completions", "/v1/completions"):
if request.path not in ("/v1/chat/completions", "/v1/completions", "/v1/ohttp"):
Comment thread
dixitaniket marked this conversation as resolved.
return
try:
_price_feed.get_price()
except ValueError as exc:
logger.warning("Rejecting inference request — price feed unavailable: %s", exc)
return jsonify({"error": f"Pricing unavailable: {exc}"}), 503

if request.path == "/v1/ohttp":
return

body = request.get_json(silent=True, cache=True) or {}
model = body.get("model")
if model:
Expand Down
2 changes: 1 addition & 1 deletion tee_gateway/controllers/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest):
openai_response["usage"] = usage
cost = compute_session_cost(chat_request.model, usage)
if cost is not None:
openai_response["opengradient"] = cost
openai_response["opengradient"] = cost.model_dump(mode="json")

# Validate schema (the extra tee_* fields are preserved by returning dict directly)
CreateChatCompletionResponse.from_dict(openai_response)
Expand Down
2 changes: 1 addition & 1 deletion tee_gateway/controllers/completions_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def create_completion(body):
if usage:
cost = compute_session_cost(body.model, usage)
if cost is not None:
completion_response["opengradient"] = cost
completion_response["opengradient"] = cost.model_dump(mode="json")
return completion_response

except Exception as e:
Expand Down
145 changes: 124 additions & 21 deletions tee_gateway/controllers/ohttp_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,12 @@ def create_anonymous_chat_completion():
return _error(400, "inner payload must be a JSON object")

chat_body = _scrub(chat_body)
_set_inner_cost_context(flask_request, request_json=chat_body)
body_bytes = json.dumps(chat_body, separators=(",", ":")).encode("utf-8")

# The relay pays — x-payment is a standard outer-request header, not
# inside the encrypted envelope. Pass it through to the inner endpoint
# so x402 verifies and settles exactly as it does for a normal call.
payment_header = flask_request.headers.get("X-Payment")

sub_status, sub_headers, sub_iter = _wsgi_subrequest(
path="/v1/chat/completions",
body_bytes=body_bytes,
payment_header=payment_header,
)

inner_content_type = next(
Expand All @@ -177,11 +172,21 @@ def create_anonymous_chat_completion():
)

if is_streaming:
return _build_streaming_response(decap, sub_status, sub_headers, sub_iter)
cost_context = flask_request.environ.get("x402.cost_context")
return _build_streaming_response(
decap,
sub_status,
sub_headers,
sub_iter,
cost_context if isinstance(cost_context, dict) else None,
)

# Non-streaming: drain into bytes (this also triggers x402's
# post-response settlement via the WSGI iterator's close()).
# Non-streaming: drain into bytes, record the inner plaintext cost block
# for outer /v1/ohttp x402 settlement, then seal the response.
body_bytes_out = _drain(sub_iter)
_set_inner_response_cost_context(
flask_request, body_bytes_out, status_code=sub_status
)
return _build_outer_response(
decap, sub_status, sub_headers, body_bytes_out, inner_content_type
)
Expand Down Expand Up @@ -234,6 +239,7 @@ def _build_streaming_response(
status: int,
headers: list[tuple[str, str]],
sub_iter: Iterator[bytes],
cost_context: dict[str, Any] | None,
) -> Response:
"""Chunked OHTTP response (draft-ietf-ohai-chunked-ohttp-08).

Expand Down Expand Up @@ -262,10 +268,12 @@ def _stream() -> Iterator[bytes]:
yield encrypter.header()

pending: bytes | None = None
plaintext_chunks: list[bytes] = []
try:
for chunk in sub_iter:
if not chunk:
continue
plaintext_chunks.append(chunk)
if pending is not None:
yield encrypter.encrypt_chunk(pending, is_final=False)
pending = chunk
Expand All @@ -274,9 +282,11 @@ def _stream() -> Iterator[bytes]:
# undetected truncation.
yield encrypter.encrypt_chunk(pending or b"", is_final=True)
finally:
_set_inner_stream_cost_context(
cost_context, plaintext_chunks, status_code=status
)
close = getattr(sub_iter, "close", None)
if callable(close):
# Triggers x402's streaming-session settlement.
close()

return Response(
Expand Down Expand Up @@ -329,17 +339,15 @@ def _extract_cost_headers(body_bytes: bytes) -> dict[str, str]:
def _wsgi_subrequest(
path: str,
body_bytes: bytes,
payment_header: str | None,
) -> tuple[int, list[tuple[str, str]], Iterator[bytes]]:
"""Issue an in-process WSGI request through the app's full middleware stack.

Returns ``(status_code, headers, body_iterator)``. The caller is
responsible for draining and closing the iterator (close() triggers
x402's post-response settlement). We invoke ``current_app.wsgi_app``
directly so the x402 payment middleware (which wraps ``wsgi_app`` at
injection time) runs the same way it would for an external HTTP
request to the same path — including the pre-inference pricing gate,
payment verification, cost settlement and TEE response signing.
responsible for draining and closing the iterator. The outer /v1/ohttp
request is the x402-paid boundary, so this inner chat dispatch uses the
pre-x402 WSGI app saved at middleware installation time. That avoids
charging/verifying the same relay payment twice while still running
connexion routing, validation, TEE signing, and provider inference.
"""
outer_env = flask_request.environ
sub_env: dict[str, Any] = {
Expand All @@ -361,9 +369,6 @@ def _wsgi_subrequest(
"wsgi.input": io.BytesIO(body_bytes),
}
)
if payment_header:
sub_env["HTTP_X_PAYMENT"] = payment_header

# The OpenAPI spec declares a global ApiKeyAuth requirement and connexion
# enforces it before our handler runs (returns 401 "No authorization
# token provided"). The security function (security_controller.py) is an
Expand All @@ -382,7 +387,8 @@ def _start_response(status: str, headers: list, exc_info: Any = None):
captured["headers"] = headers
return lambda _chunk: None

iterator = current_app.wsgi_app(sub_env, _start_response)
inner_wsgi = current_app.config.get("OHTTP_INNER_WSGI_APP") or current_app.wsgi_app
iterator = inner_wsgi(sub_env, _start_response)
status_code = int(captured["status"].split(" ", 1)[0])
# Don't wrap in iter() — that would strip the iterable's close() method,
# which the caller relies on to trigger x402's post-response settlement.
Expand Down Expand Up @@ -411,3 +417,100 @@ def _error(status: int, message: str) -> tuple[dict, int]:
also returned plaintext so the relay can surface them to the client —
they never contain user prompts."""
return {"error": message}, status


def _set_inner_cost_context(
req_or_context,
*,
request_json: dict[str, Any] | None = None,
response_json: dict[str, Any] | None = None,
status_code: int | None = None,
) -> None:
if isinstance(req_or_context, dict):
cost_context = req_or_context
elif req_or_context is None:
return
else:
cost_context = req_or_context.environ.get("x402.cost_context")
if not isinstance(cost_context, dict):
return
if request_json is not None:
cost_context["inner_request_json"] = request_json
if response_json is not None:
cost_context["inner_response_json"] = response_json
if status_code is not None:
cost_context["inner_status_code"] = status_code


def _set_inner_response_cost_context(
req,
body_bytes: bytes,
*,
status_code: int,
) -> None:
try:
response_json = json.loads(body_bytes.decode("utf-8"))
except (UnicodeDecodeError, json.JSONDecodeError):
response_json = None
_set_inner_cost_context(
req,
response_json=response_json if isinstance(response_json, dict) else None,
status_code=status_code,
)


def _set_inner_stream_cost_context(
req,
chunks: list[bytes],
*,
status_code: int,
) -> None:
body = b"".join(chunks)
response_json = _parse_final_sse_json(body)
_set_inner_cost_context(
req,
response_json=response_json,
status_code=status_code,
)


def _parse_final_sse_json(body: bytes) -> dict[str, Any] | None:
last_json: dict[str, Any] | None = None
for line in body.decode("utf-8", errors="replace").splitlines():
line = line.strip()
if not line.startswith("data:"):
continue
payload = line[len("data:") :].strip()
if not payload or payload == "[DONE]":
continue
try:
parsed = json.loads(payload)
except json.JSONDecodeError:
continue
if isinstance(parsed, dict):
last_json = parsed
return last_json


def _sealed_error(req, decap: ohttp.DecapsulatedRequest, status: int, message: str):
body = {"error": message}
_set_inner_cost_context(req, response_json=body, status_code=status)
return _sealed_json_response(decap, status, body)


def _sealed_json_response(
decap: ohttp.DecapsulatedRequest,
status: int,
body_obj: Any,
) -> Response:
inner_json = json.dumps(
{"status": status, "body": body_obj},
separators=(",", ":"),
).encode("utf-8")

sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, inner_json)
return Response(
sealed,
status=200,
mimetype=OHTTP_RESPONSE_MEDIA_TYPE,
)
7 changes: 4 additions & 3 deletions tee_gateway/test/test_ohttp_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,10 @@ def inner():
assert resp.headers.get("X-Payment-Required") == "true"
assert resp.headers.get("X-Tee-Signature") == "sig"
assert "Set-Cookie" not in resp.headers
# The relay's X-Payment was forwarded into the inner env so the
# x402 middleware can verify it.
assert captured["env"]["HTTP_X_PAYMENT"] == "client-payment-blob"
# The outer /v1/ohttp request is the paid x402 boundary. The decrypted
# in-process chat subrequest bypasses x402, so the payment blob must not
# be forwarded into the inner env.
assert "HTTP_X_PAYMENT" not in captured["env"]
# WSGI iterator was drained AND closed (drains x402 settlement).
assert captured["iter"].closed is True

Expand Down
Loading
Loading