Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions tee_gateway/controllers/ohttp_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@

OHTTP_RESPONSE_MEDIA_TYPE = "message/ohttp-res"
OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE = "message/ohttp-chunked-res"
OHTTP_BILLING_FRAME_MAGIC = b"\n--opengradient-ohttp-billing-v1--\n"
_SSE_CONTENT_TYPE = "text/event-stream"

# Cap on the encapsulated request size. The inner payload is a chat-completion
Expand Down Expand Up @@ -250,10 +251,9 @@ def _build_streaming_response(
``pending`` buffer below.

Cost can't be exposed as outer headers (those are already flushed
before the body); the relay bills via x402 settlement metadata
(X-Upto-Session header, set up-front). The client reads cost from the
``opengradient`` block on the final SSE event inside the decrypted
stream.
before the body), so the gateway emits a private plaintext billing
frame for the relay to strip and process before forwarding bytes to
the browser.
"""
# See _build_outer_response: keep as a list so duplicate HTTP header
# values (e.g. multiple WWW-Authenticate challenges) survive forwarding.
Expand All @@ -277,6 +277,16 @@ def _stream() -> Iterator[bytes]:
if pending is not None:
yield encrypter.encrypt_chunk(pending, is_final=False)
pending = chunk
response_json = _parse_final_sse_json(b"".join(plaintext_chunks))
_set_inner_cost_context(
cost_context,
response_json=response_json,
status_code=status,
)
billing_frame = _build_billing_frame(response_json)
if billing_frame:
yield billing_frame
Comment on lines +286 to +288
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where do we consume this?


# Always emit exactly one final chunk so the AAD=b"final"
# marker is present — that's what protects clients from
# undetected truncation.
Expand Down Expand Up @@ -336,6 +346,26 @@ def _extract_cost_headers(body_bytes: bytes) -> dict[str, str]:
}


def _build_billing_frame(response_json: dict[str, Any] | None) -> bytes:
"""Build the private gateway-to-relay billing frame for streaming OHTTP.

This plaintext frame carries only the same billing fields projected as
outer headers for non-streaming OHTTP. The relay strips it before forwarding
bytes to the browser.
"""
if not isinstance(response_json, dict):
return b""
try:
cost = SessionCost.model_validate(response_json.get("opengradient"))
except Exception:
return b""
payload = json.dumps(
cost.model_dump(mode="json"),
separators=(",", ":"),
).encode("utf-8")
return OHTTP_BILLING_FRAME_MAGIC + len(payload).to_bytes(4, "big") + payload


def _wsgi_subrequest(
path: str,
body_bytes: bytes,
Expand Down
Loading