Skip to content

Commit b91f6d6

Browse files
committed
instrumented chat stream for observability
1 parent 795424c commit b91f6d6

1 file changed

Lines changed: 173 additions & 0 deletions

File tree

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""Async-generator wrapper that instruments a ChatCompletions stream with OTel metrics.
2+
3+
Agents using LiteLLM's ``acompletion(stream=True)`` paired with the
4+
openai-agents-sdk ``ChatCmplStreamHandler`` can wrap their stream with
5+
:func:`instrumented_chat_stream` to get TTFT, TTAT, TPS, cached-token,
6+
and reasoning-token metrics automatically — no per-agent boilerplate.
7+
8+
Usage::
9+
10+
from agentex.lib.core.observability.instrumented_chat_stream import instrumented_chat_stream
11+
12+
stream = await litellm.acompletion(**kwargs, stream=True)
13+
response = Response(...) # placeholder for ChatCmplStreamHandler
14+
async for event in instrumented_chat_stream(stream, response, model_name):
15+
yield event
16+
"""
17+
18+
from __future__ import annotations
19+
20+
import time
21+
import logging
22+
from typing import Any
23+
from collections.abc import AsyncIterator
24+
25+
from agents.items import TResponseStreamEvent
26+
from openai.types.responses import (
27+
Response,
28+
ResponseCompletedEvent,
29+
ResponseTextDeltaEvent,
30+
ResponseReasoningTextDeltaEvent,
31+
ResponseFunctionCallArgumentsDeltaEvent,
32+
)
33+
from agents.models.chatcmpl_stream_handler import ChatCmplStreamHandler
34+
35+
from agentex.lib.core.observability.llm_metrics import get_llm_metrics
36+
from agentex.lib.core.observability.llm_metrics_hooks import record_llm_failure
37+
38+
logger = logging.getLogger(__name__)
39+
40+
# Event types that produce tokens (for first_token_at / last_token_at).
41+
_TOKEN_EVENTS = (
42+
ResponseTextDeltaEvent,
43+
ResponseReasoningTextDeltaEvent,
44+
ResponseFunctionCallArgumentsDeltaEvent,
45+
)
46+
47+
# Event types that produce *answer* tokens — excludes reasoning (for first_answer_at).
48+
_ANSWER_EVENTS = (
49+
ResponseTextDeltaEvent,
50+
ResponseFunctionCallArgumentsDeltaEvent,
51+
)
52+
53+
54+
async def instrumented_chat_stream(
55+
raw_stream: AsyncIterator,
56+
response: Response,
57+
model_name: str,
58+
) -> AsyncIterator[TResponseStreamEvent]:
59+
"""Wrap a LiteLLM ChatCompletions stream with OTel metrics instrumentation.
60+
61+
Yields every ``TResponseStreamEvent`` unchanged while recording:
62+
63+
* ``agentex.llm.ttft`` — time to first content token (ms)
64+
* ``agentex.llm.ttat`` — time to first answering token, excludes reasoning (ms)
65+
* ``agentex.llm.tps`` — output tokens / second over the generation window
66+
* ``agentex.llm.cached_input_tokens`` — prompt-cache hits
67+
* ``agentex.llm.reasoning_tokens`` — reasoning output tokens
68+
69+
On exception the ``agentex.llm.requests`` failure counter is bumped via
70+
:func:`record_llm_failure`.
71+
72+
Parameters
73+
----------
74+
raw_stream:
75+
The async iterator returned by ``litellm.acompletion(stream=True)``.
76+
response:
77+
A placeholder ``Response`` object required by ``ChatCmplStreamHandler``.
78+
model_name:
79+
Model identifier used as the ``model`` metric attribute.
80+
"""
81+
# --- Usage capture wrapper ---------------------------------------------------
82+
# LiteLLM's CustomStreamWrapper strips prompt_tokens_details and
83+
# completion_tokens_details from outgoing chunks. After the stream ends,
84+
# stream_chunk_builder() reconstructs the full Usage and writes it back
85+
# into the *same* _hidden_params dict (shared by reference). We capture
86+
# both the raw per-chunk usage and the _hidden_params reference so we can
87+
# read the complete Usage after iteration.
88+
raw_usage: Any = None
89+
_last_hidden_params: dict[str, Any] | None = None
90+
91+
async def _usage_capturing_stream(): # type: ignore[return]
92+
nonlocal raw_usage, _last_hidden_params
93+
async for chunk in raw_stream:
94+
if hasattr(chunk, "usage") and chunk.usage is not None:
95+
raw_usage = chunk.usage
96+
hp = getattr(chunk, "_hidden_params", None)
97+
if isinstance(hp, dict):
98+
_last_hidden_params = hp
99+
yield chunk
100+
101+
# --- Timing bookmarks --------------------------------------------------------
102+
stream_start = time.perf_counter()
103+
first_token_at: float | None = None
104+
first_answer_at: float | None = None
105+
last_token_at: float | None = None
106+
output_tokens_count = 0
107+
108+
try:
109+
async for event in ChatCmplStreamHandler.handle_stream(response, _usage_capturing_stream()):
110+
if isinstance(event, _TOKEN_EVENTS):
111+
now = time.perf_counter()
112+
if first_token_at is None:
113+
first_token_at = now
114+
if first_answer_at is None and isinstance(event, _ANSWER_EVENTS):
115+
first_answer_at = now
116+
last_token_at = now
117+
elif isinstance(event, ResponseCompletedEvent):
118+
try:
119+
if event.response and event.response.usage:
120+
output_tokens_count = event.response.usage.output_tokens or 0
121+
except Exception:
122+
pass
123+
yield event
124+
except Exception as exc:
125+
record_llm_failure(model_name, exc)
126+
raise
127+
finally:
128+
try:
129+
m = get_llm_metrics()
130+
attrs = {"model": model_name}
131+
132+
# --- Timing metrics --------------------------------------------------
133+
if first_token_at is not None:
134+
m.ttft_ms.record((first_token_at - stream_start) * 1000, attrs)
135+
if first_answer_at is not None:
136+
m.ttat_ms.record((first_answer_at - stream_start) * 1000, attrs)
137+
if (
138+
first_token_at is not None
139+
and last_token_at is not None
140+
and last_token_at > first_token_at
141+
and output_tokens_count > 0
142+
):
143+
m.tps.record(output_tokens_count / (last_token_at - first_token_at), attrs)
144+
145+
# --- Token detail counters -------------------------------------------
146+
# Prefer _hidden_params["usage"] (reconstructed by stream_chunk_builder
147+
# with all detail fields) over raw per-chunk usage.
148+
if _last_hidden_params is not None:
149+
hp_usage = _last_hidden_params.get("usage")
150+
if hp_usage is not None:
151+
raw_usage = hp_usage
152+
153+
cached_tokens = 0
154+
reasoning_tokens = 0
155+
if raw_usage is not None:
156+
# prompt_tokens_details.cached_tokens (standard OpenAI field)
157+
ptd = getattr(raw_usage, "prompt_tokens_details", None)
158+
if ptd is not None:
159+
cached_tokens = getattr(ptd, "cached_tokens", 0) or 0
160+
# Fallback: LiteLLM PrivateAttr _cache_read_input_tokens
161+
if not cached_tokens:
162+
cached_tokens = getattr(raw_usage, "_cache_read_input_tokens", 0) or 0
163+
164+
ctd = getattr(raw_usage, "completion_tokens_details", None)
165+
if ctd is not None:
166+
reasoning_tokens = getattr(ctd, "reasoning_tokens", 0) or 0
167+
168+
if cached_tokens > 0:
169+
m.cached_input_tokens.add(cached_tokens, attrs)
170+
if reasoning_tokens > 0:
171+
m.reasoning_tokens.add(reasoning_tokens, attrs)
172+
except Exception:
173+
pass

0 commit comments

Comments
 (0)