Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion test/utils/test_jinja2_chat_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from jinja2 import TemplateSyntaxError
from jinja2.sandbox import SandboxedEnvironment

from haystack.dataclasses.chat_message import ImageContent, ReasoningContent, ToolCall, ToolCallResult
from haystack.dataclasses.chat_message import ImageContent, ReasoningContent, TextContent, ToolCall, ToolCallResult
from haystack.utils.jinja2_chat_extension import ChatMessageExtension, templatize_part


Expand Down Expand Up @@ -178,6 +178,55 @@ def test_tool_message(self, jinja_env):
}
assert output == expected

def test_tool_message_tool_call_result_list(self, jinja_env, base64_image_string):
template = """
{% message role="tool" %}
{{ tool_result | templatize_part }}
{% endmessage %}
"""
tool_call = ToolCall(tool_name="find_image", arguments={"query": "a beautiful image"}, id="find_image_1")
tool_result = ToolCallResult(
result=[
TextContent(text="Here is a beautiful image"),
ImageContent(base64_image=base64_image_string, mime_type="image/png"),
],
origin=tool_call,
error=False,
)
rendered = jinja_env.from_string(template).render(tool_result=tool_result)
output = json.loads(rendered.strip())
expected = {
"role": "tool",
"content": [
{
"tool_call_result": {
"result": [
{"text": "Here is a beautiful image"},
{
"image": {
"base64_image": base64_image_string,
"mime_type": "image/png",
"detail": None,
"meta": {},
"validation": True,
}
},
],
"error": False,
"origin": {
"tool_name": "find_image",
"arguments": {"query": "a beautiful image"},
"id": "find_image_1",
"extra": None,
},
}
}
],
"name": None,
"meta": {},
}
assert output == expected

def test_user_message_with_image(self, jinja_env, base64_image_string):
template = """
{% message role="user" %}
Expand Down Expand Up @@ -417,6 +466,16 @@ def test_invalid_json_in_content_part_raises_error(self, jinja_env):
with pytest.raises(json.JSONDecodeError):
jinja_env.from_string(template).render()

def test_user_message_with_invalid_parts_raises_error(self, jinja_env):
template = """
{% message role="user" %}
{{ tool_call | templatize_part }}
{% endmessage %}
"""
tool_call = ToolCall(tool_name="search", arguments={"query": "test"}, id="search_1")
with pytest.raises(ValueError, match="User message must contain only TextContent"):
jinja_env.from_string(template).render(tool_call=tool_call)

def test_invalid_system_message_raises_error(self, jinja_env, base64_image_string):
template = """
{% message role="system" %}
Expand Down
Loading