Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def _convert_ollama_response_to_chatmessage(ollama_response: ChatResponse) -> Ch
for ollama_tc in ollama_tool_calls:
tool_calls.append(
ToolCall(
id=ollama_tc.get("id"),
tool_name=ollama_tc["function"]["name"],
arguments=ollama_tc["function"]["arguments"],
)
Expand Down Expand Up @@ -208,6 +209,7 @@ def _build_chunk(
tool_calls_list.append(
ToolCallDelta(
index=tool_call_index,
id=tool_call.get("id"),
tool_name=tool_call["function"]["name"],
arguments=json.dumps(tool_call["function"]["arguments"])
if tool_call["function"]["arguments"]
Expand Down Expand Up @@ -370,10 +372,11 @@ def _handle_streaming_response(
component_info = ComponentInfo.from_component(self)
chunks: list[StreamingChunk] = []

# Accumulators
arg_by_id: dict[str, str] = {}
name_by_id: dict[str, str] = {}
id_order: list[str] = []
# Accumulators keyed by tool_call.index (always unique per call, even for repeated tool names)
Comment thread
sjrl marked this conversation as resolved.
arg_by_index: dict[str, str] = {}
name_by_index: dict[str, str] = {}
id_by_index: dict[str, str | None] = {}
index_order: list[str] = []
tool_call_index: int = 0

# track reasoning and content blocks to correctly set start=True on the first chunk of each block
Expand All @@ -400,29 +403,14 @@ def _handle_streaming_response(

if chunk.tool_calls:
for tool_call in chunk.tool_calls:
# the Ollama server doesn't guarantee an id field in every tool_calls entry.
# OpenAI-compatible endpoint (/v1/chat/completions) - recent releases do add an auto-generated id
# when the model produces multiple tool calls, so that clients can map results back.
# Native Ollama endpoint (/api/chat) and older builds
# - the JSON often contains only function.name + arguments;
# many users have reported that id is missing even with several calls,
# making client-side resolution harder:
# https://github.com/ollama/ollama/issues/6708
# https://github.com/ollama/ollama/issues/7510
# - If id is provided → we can distinguish multiple calls to the same tool.

# - If id is missing → fallback to function.name works only when there's one call.
# - That's why the deduplication logic is cautious and assumes one logical
# call per name when id is absent.
tool_call_id = tool_call.id or tool_call.tool_name or ""
key = str(tool_call.index)
args = tool_call.arguments or ""

# Remember first-seen order and tool name
if tool_call_id not in id_order:
id_order.append(tool_call_id)
name_by_id[tool_call_id] = tool_call.tool_name or ""
# Update the argument accumulator for this tool_call_id.
arg_by_id[tool_call_id] = args
if key not in index_order:
index_order.append(key)
name_by_index[key] = tool_call.tool_name or ""
id_by_index[key] = tool_call.id
arg_by_index[key] = args

if callback:
callback(chunk)
Expand All @@ -435,9 +423,11 @@ def _handle_streaming_response(
reasoning += c.reasoning.reasoning_text if c.reasoning else ""

tool_calls = []
for tool_call_id in id_order:
arguments: str = arg_by_id.get(tool_call_id, "")
tool_calls.append(ToolCall(tool_name=name_by_id[tool_call_id], arguments=json.loads(arguments)))
for key in index_order:
arguments: str = arg_by_index.get(key, "")
tool_calls.append(
ToolCall(id=id_by_index[key], tool_name=name_by_index[key], arguments=json.loads(arguments))
)

# We can't use _convert_streaming_chunks_to_chat_message because
# we need to map tool_call name and args by order.
Expand All @@ -463,10 +453,11 @@ async def _handle_streaming_response_async(
component_info = ComponentInfo.from_component(self)
chunks: list[StreamingChunk] = []

# Accumulators
arg_by_id: dict[str, str] = {}
name_by_id: dict[str, str] = {}
id_order: list[str] = []
# Accumulators keyed by tool_call.index (always unique per call, even for repeated tool names)
arg_by_index: dict[str, str] = {}
name_by_index: dict[str, str] = {}
id_by_index: dict[str, str | None] = {}
index_order: list[str] = []
tool_call_index: int = 0

# track reasoning and content blocks to correctly set start=True on the first chunk of each block
Expand Down Expand Up @@ -494,15 +485,14 @@ async def _handle_streaming_response_async(

if chunk.tool_calls:
for tool_call in chunk.tool_calls:
tool_call_id = tool_call.id or tool_call.tool_name or ""
key = str(tool_call.index)
args = tool_call.arguments or ""

# Remember first-seen order and tool name
if tool_call_id not in id_order:
id_order.append(tool_call_id)
name_by_id[tool_call_id] = tool_call.tool_name or ""
# Update the argument accumulator for this tool_call_id
arg_by_id[tool_call_id] = args
if key not in index_order:
index_order.append(key)
name_by_index[key] = tool_call.tool_name or ""
id_by_index[key] = tool_call.id
arg_by_index[key] = args

if callback is not None:
await callback(chunk)
Expand All @@ -517,9 +507,11 @@ async def _handle_streaming_response_async(
reasoning += c.reasoning.reasoning_text if c.reasoning else ""

tool_calls = []
for tool_call_id in id_order:
arguments: str = arg_by_id.get(tool_call_id, "")
tool_calls.append(ToolCall(tool_name=name_by_id[tool_call_id], arguments=json.loads(arguments)))
for key in index_order:
arguments: str = arg_by_index.get(key, "")
tool_calls.append(
ToolCall(id=id_by_index[key], tool_name=name_by_index[key], arguments=json.loads(arguments))
)

# We can't use _convert_streaming_chunks_to_chat_message because
# we need to map tool_call name and args by order.
Expand Down
174 changes: 174 additions & 0 deletions integrations/ollama/tests/test_chat_generator.py
Comment thread
SyedShahmeerAli12 marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,33 @@ def test_convert_ollama_response_to_chatmessage_with_tools(self):
arguments={"format": "celsius", "location": "Paris, FR"},
)

def test_convert_ollama_response_to_chatmessage_with_repeated_tool(self):
    """Check that two non-streamed calls to the same tool are both converted, in order."""
    cities = ("Paris", "London")
    # Neither entry carries an "id": the two calls differ only by their arguments.
    raw_tool_calls = [{"function": {"name": "weather", "arguments": {"city": city}}} for city in cities]
    ollama_response = ChatResponse(
        model="some_model",
        created_at="2023-12-12T14:13:43.416799Z",
        message={"role": "assistant", "content": "", "tool_calls": raw_tool_calls},
        done=True,
        total_duration=5191566416,
        load_duration=2154458,
        prompt_eval_count=26,
        prompt_eval_duration=383809000,
        eval_count=298,
        eval_duration=4799921000,
    )

    observed = _convert_ollama_response_to_chatmessage(ollama_response)

    assert len(observed.tool_calls) == 2
    for position, city in enumerate(cities):
        assert observed.tool_calls[position] == ToolCall(tool_name="weather", arguments={"city": city})

def test_build_chunk(self):
generator = OllamaChatGenerator()

Expand Down Expand Up @@ -386,8 +413,10 @@ def test_callback(chunk: StreamingChunk):
assert result["replies"][0].text is None
assert result["replies"][0].tool_calls[0].tool_name == "calculator"
assert result["replies"][0].tool_calls[0].arguments == {"expression": "7 * (4 + 2)"}
assert result["replies"][0].tool_calls[0].id is None
assert result["replies"][0].tool_calls[1].tool_name == "factorial"
assert result["replies"][0].tool_calls[1].arguments == {"n": 5}
assert result["replies"][0].tool_calls[1].id is None
assert result["replies"][0].meta["finish_reason"] == "stop"
assert result["replies"][0].meta["model"] == "qwen3:0.6b"

Expand Down Expand Up @@ -422,6 +451,123 @@ def test_callback(chunk: StreamingChunk):
assert streaming_chunks[1].tool_calls[0].to_dict() == expected
assert len(streaming_chunks[2].tool_calls) == 0

def test_handle_streaming_response_repeated_tool_calls(self):
    """Check that streamed calls to the same tool are kept as separate tool calls."""

    def weather_chunk(created_at, city):
        # One intermediate streaming chunk carrying a single id-less "weather" call.
        function = Message.ToolCall.Function(name="weather", arguments={"city": city})
        return ChatResponse(
            model="qwen3:0.6b",
            created_at=created_at,
            done=False,
            message=Message(
                role="assistant",
                content="",
                tool_calls=[Message.ToolCall(function=function)],
            ),
        )

    # Final chunk: no tool calls, only the completion statistics.
    closing_chunk = ChatResponse(
        model="qwen3:0.6b",
        created_at="2025-07-31T14:48:03.678729Z",
        done=True,
        done_reason="stop",
        total_duration=774786292,
        load_duration=43608375,
        prompt_eval_count=217,
        prompt_eval_duration=312974541,
        eval_count=46,
        eval_duration=417069750,
        message=Message(role="assistant", content=""),
    )
    ollama_chunks = [
        weather_chunk("2025-07-31T14:48:03.471292Z", "Paris"),
        weather_chunk("2025-07-31T14:48:03.660179Z", "London"),
        closing_chunk,
    ]

    generator = OllamaChatGenerator()
    result = generator._handle_streaming_response(ollama_chunks, None)

    assert len(result["replies"][0].tool_calls) == 2
    for position, city in enumerate(("Paris", "London")):
        tool_call = result["replies"][0].tool_calls[position]
        assert tool_call.tool_name == "weather"
        assert tool_call.arguments == {"city": city}
        assert tool_call.id is None

@pytest.mark.asyncio
async def test_handle_streaming_response_async_repeated_tool_calls(self):
    """Async counterpart: streamed calls to the same tool are kept as separate tool calls."""

    def weather_chunk(created_at, city):
        # One intermediate streaming chunk carrying a single id-less "weather" call.
        function = Message.ToolCall.Function(name="weather", arguments={"city": city})
        return ChatResponse(
            model="qwen3:0.6b",
            created_at=created_at,
            done=False,
            message=Message(
                role="assistant",
                content="",
                tool_calls=[Message.ToolCall(function=function)],
            ),
        )

    # Final chunk: no tool calls, only the completion statistics.
    closing_chunk = ChatResponse(
        model="qwen3:0.6b",
        created_at="2025-07-31T14:48:03.678729Z",
        done=True,
        done_reason="stop",
        total_duration=774786292,
        load_duration=43608375,
        prompt_eval_count=217,
        prompt_eval_duration=312974541,
        eval_count=46,
        eval_duration=417069750,
        message=Message(role="assistant", content=""),
    )
    ollama_chunks = [
        weather_chunk("2025-07-31T14:48:03.471292Z", "Paris"),
        weather_chunk("2025-07-31T14:48:03.660179Z", "London"),
        closing_chunk,
    ]

    async def async_chunks():
        # Replay the prepared chunks through an async generator.
        for item in ollama_chunks:
            yield item

    generator = OllamaChatGenerator()
    result = await generator._handle_streaming_response_async(async_chunks(), None)

    assert len(result["replies"][0].tool_calls) == 2
    for position, city in enumerate(("Paris", "London")):
        tool_call = result["replies"][0].tool_calls[position]
        assert tool_call.tool_name == "weather"
        assert tool_call.arguments == {"city": city}
        assert tool_call.id is None

def test_handle_streaming_response_tool_calls_with_thinking(self):
ollama_chunks = [
ChatResponse(
Expand Down Expand Up @@ -536,6 +682,7 @@ def test_callback(chunk: StreamingChunk):
assert result["replies"][0].text is None
assert result["replies"][0].tool_calls[0].tool_name == "add_two_numbers"
assert result["replies"][0].tool_calls[0].arguments == {"a": 2, "b": 2}
assert result["replies"][0].tool_calls[0].id is None
assert result["replies"][0].reasoning.reasoning_text == "Okay, the user is asking 2 plus 2."
assert result["replies"][0].meta["finish_reason"] == "stop"
assert result["replies"][0].meta["model"] == "qwen3:0.6b"
Expand Down Expand Up @@ -1306,6 +1453,33 @@ def multiply(a: int, b: int) -> int:
assert new_response.tool_calls[0].tool_name == "multiply"
assert new_response.tool_calls[0].arguments == {"a": 5, "b": 10}

@pytest.mark.parametrize("streaming_callback", [None, print_streaming_chunk])
def test_live_run_with_repeated_tool_calls(self, tools, streaming_callback):
    """
    Run the generator end to end on a prompt that should trigger the same tool twice.

    Parametrized to run both without a streaming callback and with `print_streaming_chunk`.
    NOTE(review): judging by the test name this presumably needs a reachable Ollama server
    serving `qwen3:0.6b` - confirm against the test markers/fixtures.
    """
    component = OllamaChatGenerator(model="qwen3:0.6b", tools=tools, streaming_callback=streaming_callback)
    tool_invoker = ToolInvoker(tools=tools)
    messages = [ChatMessage.from_user("What is the weather in Paris and London?")]

    response = component.run(messages)
    replies = response["replies"]
    assert len(replies) == 1

    assistant_msg = replies[0]
    assert assistant_msg.tool_calls
    assert len(assistant_msg.tool_calls) == 2
    for tool_call in assistant_msg.tool_calls:
        assert isinstance(tool_call, ToolCall)
        assert tool_call.tool_name == "weather"
        assert "city" in tool_call.arguments

    # Substring match on lower-cased names: the model may return e.g. "Paris, France".
    cities = {tool_call.arguments["city"].lower() for tool_call in assistant_msg.tool_calls}
    for expected_city in ("paris", "london"):
        assert any(expected_city in city for city in cities)

    # Feed the tool results back and expect a textual final answer.
    tool_messages = tool_invoker.run(messages=[assistant_msg])["tool_messages"]
    final_response = component.run([*messages, assistant_msg, *tool_messages])
    assert len(final_response["replies"]) == 1
    assert final_response["replies"][0].text

def test_live_run_with_tools_and_format(self, tools):
response_format = {
"type": "object",
Expand Down
Loading