Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,15 @@ async def _process_agent_tools(
instances, and calls ``process_llm_request`` on each to register
tool declarations in the request.

Tool-union resolution is dispatched concurrently via ``asyncio.gather``
to overlap I/O-bound listings (e.g. MCP ``list_tools`` over the
network). The subsequent ``process_llm_request`` calls are kept
serial in the original ``agent.tools`` order: some tools read/write
``llm_request`` state (e.g. ``GoogleSearchTool`` writes
``llm_request.model``; ``ComputerUseToolset`` performs an idempotency
check on ``llm_request.config.tools``) and rely on observing the
post-state of earlier tools.

After this function returns, ``llm_request.tools_dict`` maps tool
names to ``BaseTool`` instances ready for function call dispatch.

Expand All @@ -437,7 +446,29 @@ async def _process_agent_tools(

multiple_tools = len(agent.tools) > 1
model = agent.canonical_model
for tool_union in agent.tools:

from ...agents.llm_agent import _convert_tool_union_to_tools

# Resolve tool_unions in parallel. ``asyncio.gather`` preserves
# input order in the returned list, so the serial commit phase below
# still observes ``agent.tools`` order. If any resolution raises,
# gather cancels the siblings and propagates -- same observable
# behavior as the previous serial loop, which would propagate the
# first exception and abandon the rest.
resolved_tools_per_union = await asyncio.gather(*(
_convert_tool_union_to_tools(
tool_union,
ReadonlyContext(invocation_context),
model,
multiple_tools,
)
for tool_union in agent.tools
))

# Serial commit phase, in original ``agent.tools`` order. Mutations
# to ``llm_request`` and reads of its state (model, config.tools,
# tools_dict) preserve today's ordering semantics exactly.
for tool_union, tools in zip(agent.tools, resolved_tools_per_union):
tool_context = ToolContext(invocation_context)

# If it's a toolset, process it first
Expand All @@ -446,15 +477,7 @@ async def _process_agent_tools(
tool_context=tool_context, llm_request=llm_request
)

from ...agents.llm_agent import _convert_tool_union_to_tools

# Then process all tools from this tool union
tools = await _convert_tool_union_to_tools(
tool_union,
ReadonlyContext(invocation_context),
model,
multiple_tools,
)
for tool in tools:
await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request
Expand Down
125 changes: 125 additions & 0 deletions tests/unittests/flows/llm_flows/test_base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Unit tests for BaseLlmFlow toolset integration."""

import asyncio
from unittest import mock
from unittest.mock import AsyncMock

Expand Down Expand Up @@ -243,6 +244,130 @@ def _my_tool(sides: int) -> int:
)


@pytest.mark.asyncio
async def test_process_agent_tools_resolves_unions_in_parallel():
"""``_convert_tool_union_to_tools`` is dispatched for every tool_union concurrently.

Each mocked resolution blocks until ``all_started`` is set; the event
is only set once every call has been entered. If
``_process_agent_tools`` were still serial, the first call would
block forever waiting for the event the second call hasn't yet
entered to set.
"""
num_tools = 5
started_count = 0
all_started = asyncio.Event()
release = asyncio.Event()

async def blocking_convert(tool_union, *args, **kwargs):
del args, kwargs
nonlocal started_count
started_count += 1
if started_count == num_tools:
all_started.set()
await release.wait()
return [_AsyncProcessLlmRequestTool(name=tool_union.__name__)]

def _make_func(i):
def _f():
"""Test function."""
return i

_f.__name__ = f'fn_{i}'
return _f

funcs = [_make_func(i) for i in range(num_tools)]

with mock.patch(
'google.adk.agents.llm_agent._convert_tool_union_to_tools',
side_effect=blocking_convert,
):
agent = Agent(name='test_agent', tools=funcs)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test message'
)
flow = BaseLlmFlowForTesting()
llm_request = LlmRequest()

async def drive():
async for _ in flow._preprocess_async(invocation_context, llm_request):
pass

drive_task = asyncio.create_task(drive())
try:
# If resolution were serial this would hang; release the gate as
# soon as every coroutine has entered.
await asyncio.wait_for(all_started.wait(), timeout=5.0)
finally:
release.set()
await asyncio.wait_for(drive_task, timeout=5.0)

assert started_count == num_tools


@pytest.mark.asyncio
async def test_process_agent_tools_preserves_order_when_later_unions_resolve_first():
"""``process_llm_request`` is called in original ``agent.tools`` order even when later unions resolve first."""

resolution_started_evt = [asyncio.Event(), asyncio.Event()]
process_call_order: list[str] = []

async def staggered_convert(tool_union, *args, **kwargs):
del args, kwargs
if tool_union.__name__ == 'fn_slow':
# Resolve only after fn_fast's resolution has completed.
await resolution_started_evt[1].wait()
tool_name = 'slow_tool'
else:
tool_name = 'fast_tool'
resolution_started_evt[1].set()
return [
_AsyncProcessLlmRequestTool(
name=tool_name, on_process=process_call_order.append
)
]

def fn_slow():
"""Slow-resolving function."""
return 0

def fn_fast():
"""Fast-resolving function."""
return 0

with mock.patch(
'google.adk.agents.llm_agent._convert_tool_union_to_tools',
side_effect=staggered_convert,
):
# agent.tools order is [slow, fast]; resolution completes [fast, slow].
agent = Agent(name='test_agent', tools=[fn_slow, fn_fast])
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test message'
)
flow = BaseLlmFlowForTesting()
llm_request = LlmRequest()

async for _ in flow._preprocess_async(invocation_context, llm_request):
pass

# Even though fast_tool was resolved first, process_llm_request must
# be invoked in agent.tools order (slow_tool first).
assert process_call_order == ['slow_tool', 'fast_tool']


class _AsyncProcessLlmRequestTool:
"""Minimal stand-in for a BaseTool that records process_llm_request calls."""

def __init__(self, name: str, on_process=None):
self.name = name
self._on_process = on_process

async def process_llm_request(self, *, tool_context, llm_request):
del tool_context, llm_request
if self._on_process is not None:
self._on_process(self.name)


# TODO(b/448114567): Remove the following
# test_handle_after_model_callback_grounding tests once the workaround
# is no longer needed.
Expand Down
Loading