
Commit bdbce57

Use OpenAI schema dataclasses for cloud stream responses (home-assistant#161663)
1 parent 8536472 commit bdbce57

7 files changed

Lines changed: 75 additions & 73 deletions
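
This commit migrates homeassistant/components/cloud/entity.py off litellm's string-compared stream events (ResponsesAPIStreamEvents) and onto the typed LLMResponse* event and output-item dataclasses exported by hass_nabucasa.llm, bumping hass-nabucasa to 1.12.0 and pinning openai==2.15.0. A minimal sketch of the dispatch pattern the new code adopts, using simplified stand-ins for the hass_nabucasa.llm classes (the real dataclasses carry more fields):

    from dataclasses import dataclass

    # Simplified stand-ins for two of the hass_nabucasa.llm event types.
    @dataclass
    class LLMResponseOutputTextDeltaEvent:
        delta: str

    @dataclass
    class LLMResponseErrorEvent:
        message: str

    def handle(event: object) -> str | None:
        # isinstance() on typed dataclasses replaces string comparisons
        # such as event_type == ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA.
        if isinstance(event, LLMResponseOutputTextDeltaEvent):
            return event.delta
        if isinstance(event, LLMResponseErrorEvent):
            raise RuntimeError(event.message)
        return None

Type checkers can then narrow event inside each branch, which is what lets the diff below drop the getattr() and cast() calls.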

homeassistant/components/cloud/entity.py

Lines changed: 66 additions & 67 deletions
@@ -12,14 +12,25 @@
 from hass_nabucasa.llm import (
     LLMAuthenticationError,
     LLMRateLimitError,
+    LLMResponseCompletedEvent,
     LLMResponseError,
+    LLMResponseErrorEvent,
+    LLMResponseFailedEvent,
+    LLMResponseFunctionCallArgumentsDeltaEvent,
+    LLMResponseFunctionCallArgumentsDoneEvent,
+    LLMResponseFunctionCallOutputItem,
+    LLMResponseImageOutputItem,
+    LLMResponseIncompleteEvent,
+    LLMResponseMessageOutputItem,
+    LLMResponseOutputItemAddedEvent,
+    LLMResponseOutputItemDoneEvent,
+    LLMResponseOutputTextDeltaEvent,
+    LLMResponseReasoningOutputItem,
+    LLMResponseReasoningSummaryTextDeltaEvent,
+    LLMResponseWebSearchCallOutputItem,
+    LLMResponseWebSearchCallSearchingEvent,
     LLMServiceError,
 )
-from litellm import (
-    ResponseFunctionToolCall,
-    ResponseInputParam,
-    ResponsesAPIStreamEvents,
-)
 from openai.types.responses import (
     FunctionToolParam,
     ResponseInputItemParam,
@@ -60,9 +71,9 @@ class ResponseItemType(str, Enum):
 
 def _convert_content_to_param(
     chat_content: Iterable[conversation.Content],
-) -> ResponseInputParam:
+) -> list[ResponseInputItemParam]:
     """Convert any native chat message for this agent to the native format."""
-    messages: ResponseInputParam = []
+    messages: list[ResponseInputItemParam] = []
     reasoning_summary: list[str] = []
     web_search_calls: dict[str, dict[str, Any]] = {}
 
@@ -238,7 +249,7 @@ async def _transform_stream( # noqa: C901 - This is complex, but better to have
     """Transform stream result into HA format."""
     last_summary_index = None
     last_role: Literal["assistant", "tool_result"] | None = None
-    current_tool_call: ResponseFunctionToolCall | None = None
+    current_tool_call: LLMResponseFunctionCallOutputItem | None = None
 
     # Non-reasoning models don't follow our request to remove citations, so we remove
     # them manually here. They always follow the same pattern: the citation is always
@@ -248,31 +259,22 @@ async def _transform_stream( # noqa: C901 - This is complex, but better to have
     citation_regexp = re.compile(r"\(\[([^\]]+)\]\((https?:\/\/[^\)]+)\)")
 
     async for event in stream:
-        event_type = getattr(event, "type", None)
-        event_item = getattr(event, "item", None)
-        event_item_type = getattr(event_item, "type", None) if event_item else None
-
-        _LOGGER.debug(
-            "Event[%s] | item: %s",
-            event_type,
-            event_item_type,
-        )
+        _LOGGER.debug("Event[%s]", getattr(event, "type", None))
 
-        if event_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED:
-            # Detect function_call even when it's a BaseLiteLLMOpenAIResponseObject
-            if event_item_type == ResponseItemType.FUNCTION_CALL:
+        if isinstance(event, LLMResponseOutputItemAddedEvent):
+            if isinstance(event.item, LLMResponseFunctionCallOutputItem):
                 # OpenAI has tool calls as individual events
                 # while HA puts tool calls inside the assistant message.
                 # We turn them into individual assistant content for HA
                 # to ensure that tools are called as soon as possible.
                 yield {"role": "assistant"}
                 last_role = "assistant"
                 last_summary_index = None
-                current_tool_call = cast(ResponseFunctionToolCall, event.item)
+                current_tool_call = event.item
             elif (
-                event_item_type == ResponseItemType.MESSAGE
+                isinstance(event.item, LLMResponseMessageOutputItem)
                 or (
-                    event_item_type == ResponseItemType.REASONING
+                    isinstance(event.item, LLMResponseReasoningOutputItem)
                     and last_summary_index is not None
                 )  # Subsequent ResponseReasoningItem
                 or last_role != "assistant"
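
The citation_regexp kept in this hunk's context strips citations of the form ([title](https://url)) that non-reasoning models leave in their output. A standalone sketch of what the pattern captures, with made-up sample text:

    import re

    # Same pattern as in _transform_stream above.
    citation_regexp = re.compile(r"\(\[([^\]]+)\]\((https?:\/\/[^\)]+)\)")

    text = "The sky is blue ([NASA](https://nasa.gov))."
    match = citation_regexp.search(text)
    assert match is not None
    assert match.group(1) == "NASA"  # citation title
    assert match.group(2) == "https://nasa.gov"  # citation URL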
@@ -281,29 +283,23 @@ async def _transform_stream( # noqa: C901 - This is complex, but better to have
                 last_role = "assistant"
                 last_summary_index = None
 
-        elif event_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE:
-            if event_item_type == ResponseItemType.REASONING:
-                encrypted_content = getattr(event.item, "encrypted_content", None)
-                summary = getattr(event.item, "summary", []) or []
+        elif isinstance(event, LLMResponseOutputItemDoneEvent):
+            if isinstance(event.item, LLMResponseReasoningOutputItem):
+                encrypted_content = event.item.encrypted_content
+                summary = event.item.summary
 
                 yield {
-                    "native": ResponseReasoningItem(
-                        type="reasoning",
+                    "native": LLMResponseReasoningOutputItem(
+                        type=event.item.type,
                         id=event.item.id,
                         summary=[],
                         encrypted_content=encrypted_content,
                     )
                 }
 
                 last_summary_index = len(summary) - 1 if summary else None
-            elif event_item_type == ResponseItemType.WEB_SEARCH_CALL:
-                action = getattr(event.item, "action", None)
-                if isinstance(action, dict):
-                    action_dict = action
-                elif action is not None:
-                    action_dict = action.to_dict()
-                else:
-                    action_dict = {}
+            elif isinstance(event.item, LLMResponseWebSearchCallOutputItem):
+                action_dict = event.item.action
                 yield {
                     "tool_calls": [
                         llm.ToolInput(
@@ -321,11 +317,11 @@ async def _transform_stream( # noqa: C901 - This is complex, but better to have
                     "tool_result": {"status": event.item.status},
                 }
                 last_role = "tool_result"
-            elif event_item_type == ResponseItemType.IMAGE:
-                yield {"native": event.item}
+            elif isinstance(event.item, LLMResponseImageOutputItem):
+                yield {"native": event.item.raw}
                 last_summary_index = -1  # Trigger new assistant message on next turn
 
-        elif event_type == ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA:
+        elif isinstance(event, LLMResponseOutputTextDeltaEvent):
             data = event.delta
             if remove_parentheses:
                 data = data.removeprefix(")")
@@ -344,7 +340,7 @@ async def _transform_stream( # noqa: C901 - This is complex, but better to have
             if data:
                 yield {"content": data}
 
-        elif event_type == ResponsesAPIStreamEvents.REASONING_SUMMARY_TEXT_DELTA:
+        elif isinstance(event, LLMResponseReasoningSummaryTextDeltaEvent):
             # OpenAI can output several reasoning summaries
             # in a single ResponseReasoningItem. We split them as separate
             # AssistantContent messages. Only last of them will have
@@ -358,14 +354,14 @@ async def _transform_stream( # noqa: C901 - This is complex, but better to have
             last_summary_index = event.summary_index
             yield {"thinking_content": event.delta}
 
-        elif event_type == ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DELTA:
+        elif isinstance(event, LLMResponseFunctionCallArgumentsDeltaEvent):
             if current_tool_call is not None:
                 current_tool_call.arguments += event.delta
 
-        elif event_type == ResponsesAPIStreamEvents.WEB_SEARCH_CALL_SEARCHING:
+        elif isinstance(event, LLMResponseWebSearchCallSearchingEvent):
            yield {"role": "assistant"}
 
-        elif event_type == ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DONE:
+        elif isinstance(event, LLMResponseFunctionCallArgumentsDoneEvent):
             if current_tool_call is not None:
                 current_tool_call.status = "completed"
@@ -385,35 +381,36 @@ async def _transform_stream( # noqa: C901 - This is complex, but better to have
                 ]
             }
 
-        elif event_type == ResponsesAPIStreamEvents.RESPONSE_COMPLETED:
-            if event.response.usage is not None:
+        elif isinstance(event, LLMResponseCompletedEvent):
+            response = event.response
+            if response and "usage" in response:
+                usage = response["usage"]
                 chat_log.async_trace(
                     {
                         "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
+                            "input_tokens": usage.get("input_tokens"),
+                            "output_tokens": usage.get("output_tokens"),
                         }
                     }
                 )
 
-        elif event_type == ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE:
-            if event.response.usage is not None:
+        elif isinstance(event, LLMResponseIncompleteEvent):
+            response = event.response
+            if response and "usage" in response:
+                usage = response["usage"]
                 chat_log.async_trace(
                     {
                         "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
+                            "input_tokens": usage.get("input_tokens"),
+                            "output_tokens": usage.get("output_tokens"),
                         }
                     }
                 )
 
-            if (
-                event.response.incomplete_details
-                and event.response.incomplete_details.reason
-            ):
-                reason: str = event.response.incomplete_details.reason
-            else:
-                reason = "unknown reason"
+            incomplete_details = response.get("incomplete_details")
+            reason = "unknown reason"
+            if incomplete_details is not None and incomplete_details.get("reason"):
+                reason = incomplete_details["reason"]
 
             if reason == "max_output_tokens":
                 reason = "max output tokens reached"
@@ -422,22 +419,24 @@ async def _transform_stream( # noqa: C901 - This is complex, but better to have
 
             raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
 
-        elif event_type == ResponsesAPIStreamEvents.RESPONSE_FAILED:
-            if event.response.usage is not None:
+        elif isinstance(event, LLMResponseFailedEvent):
+            response = event.response
+            if response and "usage" in response:
+                usage = response["usage"]
                 chat_log.async_trace(
                     {
                         "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
+                            "input_tokens": usage.get("input_tokens"),
+                            "output_tokens": usage.get("output_tokens"),
                         }
                     }
                 )
             reason = "unknown reason"
-            if event.response.error is not None:
-                reason = event.response.error.message
+            if isinstance(error := response.get("error"), dict):
+                reason = error.get("message") or reason
             raise HomeAssistantError(f"OpenAI response failed: {reason}")
 
-        elif event_type == ResponsesAPIStreamEvents.ERROR:
+        elif isinstance(event, LLMResponseErrorEvent):
             raise HomeAssistantError(f"OpenAI response error: {event.message}")
 
 
@@ -452,7 +451,7 @@ def __init__(self, cloud: Cloud[CloudClient], config_entry: ConfigEntry) -> None
     async def _prepare_chat_for_generation(
         self,
         chat_log: conversation.ChatLog,
-        messages: ResponseInputParam,
+        messages: list[ResponseInputItemParam],
         response_format: dict[str, Any] | None = None,
     ) -> dict[str, Any]:
         """Prepare kwargs for Cloud LLM from the chat log."""

homeassistant/components/cloud/manifest.json

Lines changed: 1 addition & 1 deletion
@@ -13,6 +13,6 @@
   "integration_type": "system",
   "iot_class": "cloud_push",
   "loggers": ["acme", "hass_nabucasa", "snitun"],
-  "requirements": ["hass-nabucasa==1.11.0"],
+  "requirements": ["hass-nabucasa==1.12.0", "openai==2.15.0"],
   "single_config_entry": true
 }

homeassistant/package_constraints.txt

Lines changed: 2 additions & 1 deletion
@@ -36,7 +36,7 @@ fnv-hash-fast==1.6.0
 go2rtc-client==0.4.0
 ha-ffmpeg==3.2.2
 habluetooth==5.8.0
-hass-nabucasa==1.11.0
+hass-nabucasa==1.12.0
 hassil==3.5.0
 home-assistant-bluetooth==1.13.1
 home-assistant-frontend==20260128.1
@@ -46,6 +46,7 @@ ifaddr==0.2.0
 Jinja2==3.1.6
 lru-dict==1.3.0
 mutagen==1.47.0
+openai==2.15.0
 orjson==3.11.5
 packaging>=23.1
 paho-mqtt==2.1.0

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ dependencies = [
     "fnv-hash-fast==1.6.0",
     # hass-nabucasa is imported by helpers which don't depend on the cloud
     # integration
-    "hass-nabucasa==1.11.0",
+    "hass-nabucasa==1.12.0",
     # When bumping httpx, please check the version pins of
     # httpcore, anyio, and h11 in gen_requirements_all
     "httpx==0.28.1",

requirements.txt

Lines changed: 1 addition & 1 deletion
Generated file; diff not rendered.

requirements_all.txt

Lines changed: 2 additions & 1 deletion
Generated file; diff not rendered.

requirements_test_all.txt

Lines changed: 2 additions & 1 deletion
Generated file; diff not rendered.
