Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions python/packages/kagent-adk/src/kagent/adk/models/_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def _convert_tools_to_ollama(tools: list[types.Tool]) -> list[ollama_sdk.Tool]:
return ollama_tools


def _convert_tool_call_to_part(tc: OllamaMessage.ToolCall) -> types.Part:
    """Convert an Ollama tool call into a function-call ``types.Part``.

    The part is built from the tool call's function name and a ``dict`` copy
    of its arguments. When the resulting part carries a function call, a
    freshly generated UUID4 string is assigned as that call's ``id``.

    Args:
        tc: Tool call whose ``function.name`` and ``function.arguments``
            are read.

    Returns:
        The part produced by ``types.Part.from_function_call``.
    """
    function = tc.function
    converted = types.Part.from_function_call(
        name=function.name,
        args=dict(function.arguments),
    )
    call = converted.function_call
    if call:
        # Every converted call gets its own random identifier.
        call.id = str(uuid.uuid4())
    return converted


class KAgentOllamaLlm(KAgentTLSMixin, BaseLlm):
"""Ollama model via the native Ollama SDK.

Expand Down Expand Up @@ -190,6 +197,7 @@ async def generate_content_async(
try:
if stream:
aggregated_text = ""
tool_calls = []
response: AsyncIterator[ollama_sdk.ChatResponse] = await self._client.chat(
model=llm_request.model or self.model,
messages=messages,
Expand All @@ -198,6 +206,7 @@ async def generate_content_async(
stream=True,
)
async for chunk in response:
tool_calls.extend(chunk.message.tool_calls or [])
if chunk.message.content:
aggregated_text += chunk.message.content
yield LlmResponse(
Expand All @@ -211,13 +220,7 @@ async def generate_content_async(
final_parts = []
if aggregated_text:
final_parts.append(types.Part.from_text(text=aggregated_text))
for tc in chunk.message.tool_calls or []:
part = types.Part.from_function_call(
name=tc.function.name, args=dict(tc.function.arguments)
)
if part.function_call:
part.function_call.id = str(uuid.uuid4())
final_parts.append(part)
final_parts.extend(_convert_tool_call_to_part(tc) for tc in tool_calls)
finish_reason = _done_reason_to_finish_reason(chunk.done_reason) if chunk.done_reason else None
usage_metadata = None
if chunk.prompt_eval_count is not None or chunk.eval_count is not None:
Expand Down Expand Up @@ -245,10 +248,7 @@ async def generate_content_async(
if response.message.content:
parts.append(types.Part.from_text(text=response.message.content))
for tc in response.message.tool_calls or []:
part = types.Part.from_function_call(name=tc.function.name, args=dict(tc.function.arguments))
if part.function_call:
part.function_call.id = str(uuid.uuid4())
parts.append(part)
parts.append(_convert_tool_call_to_part(tc))
finish_reason = _done_reason_to_finish_reason(response.done_reason) if response.done_reason else None
usage_metadata = None
if response.prompt_eval_count is not None or response.eval_count is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,51 @@ async def test_generate_content_forwards_ollama_options(self):

assert mock_client.chat.call_args.kwargs["options"] == opts

@pytest.mark.asyncio
async def test_generate_content_streaming_accumulates_tool_calls_before_done_chunk(self):
    """A tool call streamed before the done chunk appears in the final response."""
    llm = KAgentOllamaLlm(model="llama3.2:latest")

    # The tool call arrives in a chunk that precedes the terminal one.
    weather_call = mock.MagicMock()
    weather_call.function.name = "get_weather"
    weather_call.function.arguments = {"city": "Tokyo"}

    early_chunk = mock.MagicMock()
    early_chunk.message.content = ""
    early_chunk.message.tool_calls = [weather_call]
    early_chunk.done = False

    # The terminal chunk has no text and no tool calls of its own.
    last_chunk = mock.MagicMock()
    last_chunk.message.content = ""
    last_chunk.message.tool_calls = None
    last_chunk.done = True
    last_chunk.done_reason = "stop"
    last_chunk.prompt_eval_count = 10
    last_chunk.eval_count = 0

    async def fake_stream():
        for streamed in (early_chunk, last_chunk):
            yield streamed

    mock_client = mock.AsyncMock()
    mock_client.chat = mock.AsyncMock(return_value=fake_stream())

    llm_request = mock.MagicMock()
    llm_request.model = "llama3.2:latest"
    llm_request.contents = []
    llm_request.config = None

    # Replace the class-level ``_client`` with a property returning the mock.
    with mock.patch.object(type(llm), "_client", new_callable=lambda: property(lambda self: mock_client)):
        collected = []
        async for item in llm.generate_content_async(llm_request, stream=True):
            collected.append(item)

    # Neither chunk has text, so only the final response is expected.
    assert len(collected) == 1
    final = collected[0]
    assert final.partial is False
    assert final.turn_complete is True

    final_parts = final.content.parts
    assert len(final_parts) == 1
    emitted_call = final_parts[0].function_call
    assert emitted_call.name == "get_weather"
    assert dict(emitted_call.args) == {"city": "Tokyo"}


class TestConvertContentToOllamaMessages:
def test_image_inline_data_included(self):
Expand Down
Loading