16 changes: 11 additions & 5 deletions homeassistant/components/anthropic/entity.py
@@ -600,6 +600,16 @@ async def _async_handle_chat_log(
system = chat_log.content[0]
if not isinstance(system, conversation.SystemContent):
raise TypeError("First message must be a system message")

# System prompt with caching enabled
system_prompt: list[TextBlockParam] = [
TextBlockParam(
type="text",
text=system.content,
cache_control={"type": "ephemeral"},
)
]

messages = _convert_content(chat_log.content[1:])

model = options.get(CONF_CHAT_MODEL, DEFAULT[CONF_CHAT_MODEL])
@@ -608,7 +618,7 @@ async def _async_handle_chat_log(
model=model,
messages=messages,
max_tokens=options.get(CONF_MAX_TOKENS, DEFAULT[CONF_MAX_TOKENS]),
system=system.content,
system=system_prompt,
stream=True,
)

@@ -695,10 +705,6 @@ async def _async_handle_chat_log(
type="auto",
)

if isinstance(model_args["system"], str):
model_args["system"] = [
TextBlockParam(type="text", text=model_args["system"])
]
model_args["system"].append( # type: ignore[union-attr]
TextBlockParam(
type="text",
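The entity.py change above relies on the fact that the Anthropic Messages API accepts `system` either as a plain string or as a list of text blocks; passing a list lets each block opt into prompt caching via `cache_control`. A minimal standalone sketch of that request shape (not the integration's code; the model name and prompt text here are placeholders for illustration):

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

response = client.messages.create(
    model="claude-3-5-haiku-latest",  # placeholder model name for illustration
    max_tokens=256,
    system=[
        {
            "type": "text",
            "text": "You are a voice assistant for Home Assistant.",
            "cache_control": {"type": "ephemeral"},  # opt this block into prompt caching
        }
    ],
    messages=[{"role": "user", "content": "Turn on the kitchen lights."}],
)
print(response.content[0].text)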
39 changes: 20 additions & 19 deletions tests/components/anthropic/test_ai_task.py
@@ -190,22 +190,23 @@ async def test_generate_data_with_attachments(
assert user_message_with_attachments is not None
assert isinstance(user_message_with_attachments["content"], list)
assert len(user_message_with_attachments["content"]) == 3 # Text + attachments
assert user_message_with_attachments["content"] == [
{"type": "text", "text": "Test prompt"},
{
"type": "image",
"source": {
"data": "ZmFrZV9pbWFnZV9kYXRh",
"media_type": "image/jpeg",
"type": "base64",
},
},
{
"type": "document",
"source": {
"data": "ZmFrZV9pbWFnZV9kYXRh",
"media_type": "application/pdf",
"type": "base64",
},
},
]

text_block, image_block, document_block = user_message_with_attachments["content"]

# Text block
assert text_block["type"] == "text"
assert text_block["text"] == "Test prompt"

# Image attachment
assert image_block["type"] == "image"
assert image_block["source"] == {
"data": "ZmFrZV9pbWFnZV9kYXRh",
"media_type": "image/jpeg",
"type": "base64",
}

# Document attachment (ignore extra metadata like cache_control)
assert document_block["type"] == "document"
assert document_block["source"]["data"] == "ZmFrZV9pbWFnZV9kYXRh"
assert document_block["source"]["media_type"] == "application/pdf"
assert document_block["source"]["type"] == "base64"
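The per-field assertions above are deliberately tolerant: since the integration may now attach prompt-caching metadata such as cache_control to a content block, comparing the whole list for equality would be brittle. A small hypothetical helper (not part of this PR) illustrating the same pattern:

def assert_source_matches(block: dict, expected_source: dict) -> None:
    # Compare only the fields the test cares about, so extra metadata
    # (e.g. cache_control added for prompt caching) does not fail the assertion.
    for key, value in expected_source.items():
        assert block["source"][key] == value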
51 changes: 43 additions & 8 deletions tests/components/anthropic/test_conversation.py
@@ -153,10 +153,13 @@ async def test_template_variables(
result.response.speech["plain"]["speech"]
== "Okay, let me take care of that for you."
)
assert (
"The user name is Test User." in mock_create_stream.call_args.kwargs["system"]
)
assert "The user id is 12345." in mock_create_stream.call_args.kwargs["system"]

system = mock_create_stream.call_args.kwargs["system"]
assert isinstance(system, list)
system_text = " ".join(block["text"] for block in system if "text" in block)

assert "The user name is Test User." in system_text
assert "The user id is 12345." in system_text


async def test_conversation_agent(
@@ -169,6 +172,38 @@ async def test_conversation_agent(
assert agent.supported_languages == "*"


async def test_system_prompt_uses_text_block_with_cache_control(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_create_stream: AsyncMock,
) -> None:
"""Ensure system prompt is sent as TextBlockParam with cache_control."""
context = Context()

mock_create_stream.return_value = [
create_content_block(0, ["ok"]),
]

with patch("anthropic.resources.models.AsyncModels.list", new_callable=AsyncMock):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
await conversation.async_converse(
hass,
"hello",
None,
context,
agent_id="conversation.claude_conversation",
)

system = mock_create_stream.call_args.kwargs["system"]
assert isinstance(system, list)
assert len(system) == 1
block = system[0]
assert block["type"] == "text"
assert "Home Assistant" in block["text"]
assert block["cache_control"] == {"type": "ephemeral"}


@patch("homeassistant.components.anthropic.entity.llm.AssistAPI._async_get_tools")
@pytest.mark.parametrize(
("tool_call_json_parts", "expected_call_tool_args"),
@@ -229,10 +264,10 @@ async def test_function_call(
agent_id=agent_id,
)

assert (
"You are a voice assistant for Home Assistant."
in mock_create_stream.mock_calls[1][2]["system"]
)
system = mock_create_stream.mock_calls[1][2]["system"]
assert isinstance(system, list)
system_text = " ".join(block["text"] for block in system if "text" in block)
assert "You are a voice assistant for Home Assistant." in system_text

assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (