Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 66 additions & 28 deletions src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,38 +138,39 @@ async def _convert_tool_union_to_tools(
model: Union[str, BaseLlm],
multiple_tools: bool = False,
) -> list[BaseTool]:
from ..tools.enterprise_search_tool import EnterpriseWebSearchTool
from ..tools.google_search_tool import GoogleSearchTool
from ..tools.vertex_ai_search_tool import VertexAiSearchTool

# Wrap google_search tool with AgentTool if there are multiple tools because
# the built-in tools cannot be used together with other tools.
# Handle built-in tool workarounds when multiple tools are present.
# Built-in tools cannot be used together with other tools, so we wrap or
# replace them with compatible alternatives.
# TODO(b/448114567): Remove once the workaround is no longer needed.
if multiple_tools and isinstance(tool_union, GoogleSearchTool):
from ..tools.google_search_agent_tool import create_google_search_agent
from ..tools.google_search_agent_tool import GoogleSearchAgentTool

search_tool = cast(GoogleSearchTool, tool_union)
if search_tool.bypass_multi_tools_limit:
return [GoogleSearchAgentTool(create_google_search_agent(model))]

# Replace VertexAiSearchTool with DiscoveryEngineSearchTool if there are
# multiple tools because the built-in tools cannot be used together with
# other tools.
# TODO(b/448114567): Remove once the workaround is no longer needed.
if multiple_tools and isinstance(tool_union, VertexAiSearchTool):
from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool

vais_tool = cast(VertexAiSearchTool, tool_union)
if vais_tool.bypass_multi_tools_limit:
return [
DiscoveryEngineSearchTool(
data_store_id=vais_tool.data_store_id,
data_store_specs=vais_tool.data_store_specs,
search_engine_id=vais_tool.search_engine_id,
filter=vais_tool.filter,
max_results=vais_tool.max_results,
)
]
if multiple_tools:
tool_workarounds = [
# GoogleSearchTool: wrap with AgentTool
{
'tool_class': GoogleSearchTool,
'handler': lambda: _handle_google_search_tool(tool_union, model),
},
# VertexAiSearchTool: replace with DiscoveryEngineSearchTool
{
'tool_class': VertexAiSearchTool,
'handler': lambda: _handle_vertex_ai_search_tool(tool_union),
},
# EnterpriseWebSearchTool: wrap with AgentTool
{
'tool_class': EnterpriseWebSearchTool,
'handler': lambda: _handle_enterprise_search_tool(
tool_union, model
),
},
]

for workaround in tool_workarounds:
if isinstance(tool_union, workaround['tool_class']):
if tool_union.bypass_multi_tools_limit:
return workaround['handler']()

if isinstance(tool_union, BaseTool):
return [tool_union]
Expand All @@ -180,6 +181,43 @@ async def _convert_tool_union_to_tools(
return await tool_union.get_tools_with_prefix(ctx)


def _handle_google_search_tool(
tool_union: ToolUnion, model: Union[str, BaseLlm]
) -> list[BaseTool]:
"""Handle GoogleSearchTool workaround by wrapping with AgentTool."""
from ..tools.google_search_agent_tool import create_google_search_agent
from ..tools.google_search_agent_tool import GoogleSearchAgentTool

return [GoogleSearchAgentTool(create_google_search_agent(model))]


def _handle_vertex_ai_search_tool(tool_union: ToolUnion) -> list[BaseTool]:
"""Handle VertexAiSearchTool workaround by replacing with DiscoveryEngineSearchTool."""
from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool
from ..tools.vertex_ai_search_tool import VertexAiSearchTool

vais_tool = cast(VertexAiSearchTool, tool_union)
return [
DiscoveryEngineSearchTool(
data_store_id=vais_tool.data_store_id,
data_store_specs=vais_tool.data_store_specs,
search_engine_id=vais_tool.search_engine_id,
filter=vais_tool.filter,
max_results=vais_tool.max_results,
)
]


def _handle_enterprise_search_tool(
tool_union: ToolUnion, model: Union[str, BaseLlm]
) -> list[BaseTool]:
"""Handle EnterpriseWebSearchTool workaround by wrapping with AgentTool."""
from ..tools.enterprise_search_agent_tool import create_enterprise_search_agent
from ..tools.enterprise_search_agent_tool import EnterpriseSearchAgentTool

return [EnterpriseSearchAgentTool(create_enterprise_search_agent(model))]


class LlmAgent(BaseAgent):
"""LLM-based Agent."""

Expand Down
5 changes: 4 additions & 1 deletion src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,10 @@ async def _maybe_add_grounding_metadata(
tools = await agent.canonical_tools(readonly_context)
invocation_context.canonical_tools_cache = tools

if not any(tool.name == 'google_search_agent' for tool in tools):
if not any(
tool.name in {'google_search_agent', 'enterprise_search_agent'}
for tool in tools
):
return response
ground_metadata = invocation_context.session.state.get(
'temp:_adk_grounding_metadata', None
Expand Down
112 changes: 112 additions & 0 deletions src/google/adk/tools/_search_agent_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

from google.genai import types
from typing_extensions import override

from ..agents.llm_agent import LlmAgent
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..runners import Runner
from ..sessions.in_memory_session_service import InMemorySessionService
from ..utils.context_utils import Aclosing
from ._forwarding_artifact_service import ForwardingArtifactService
from .agent_tool import AgentTool
from .tool_context import ToolContext


class _SearchAgentTool(AgentTool):
"""A base class for search agent tools."""

@override
async def run_async(
self,
*,
args: dict[str, Any],
tool_context: ToolContext,
) -> Any:

if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
input_value = self.agent.input_schema.model_validate(args)
content = types.Content(
role='user',
parts=[
types.Part.from_text(
text=input_value.model_dump_json(exclude_none=True)
)
],
)
else:
content = types.Content(
role='user',
parts=[types.Part.from_text(text=args['request'])],
)
runner = Runner(
app_name=self.agent.name,
agent=self.agent,
artifact_service=ForwardingArtifactService(tool_context),
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
credential_service=tool_context._invocation_context.credential_service,
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
)
try:
state_dict = {
k: v
for k, v in tool_context.state.to_dict().items()
if not k.startswith('_adk') and not k.startswith('temp:')
}
session = await runner.session_service.create_session(
app_name=self.agent.name,
user_id=tool_context._invocation_context.user_id,
state=state_dict,
)

last_content = None
last_grounding_metadata = None
async with Aclosing(
runner.run_async(
user_id=session.user_id,
session_id=session.id,
new_message=content,
)
) as agen:
async for event in agen:
# Forward state delta to parent session.
if event.actions.state_delta:
tool_context.state.update(event.actions.state_delta)
if event.content:
last_content = event.content
last_grounding_metadata = event.grounding_metadata

if not last_content:
return ''
merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
tool_result = self.agent.output_schema.model_validate_json(
merged_text
).model_dump(exclude_none=True)
else:
tool_result = merged_text

if last_grounding_metadata:
tool_context.state['temp:_adk_grounding_metadata'] = (
last_grounding_metadata
)
return tool_result
finally:
await runner.close()
55 changes: 55 additions & 0 deletions src/google/adk/tools/enterprise_search_agent_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Union

from ..agents.llm_agent import LlmAgent
from ..models.base_llm import BaseLlm
from ._search_agent_tool import _SearchAgentTool
from .enterprise_search_tool import enterprise_web_search_tool


def create_enterprise_search_agent(model: Union[str, BaseLlm]) -> LlmAgent:
"""Create a sub-agent that only uses enterprise_web_search tool."""
return LlmAgent(
name='enterprise_search_agent',
model=model,
description=(
'An agent for performing Enterprise search using the'
' `enterprise_web_search` tool'
),
instruction="""
You are a specialized Enterprise search agent.

When given a search query, use the `enterprise_web_search` tool to find the related information.
""",
tools=[enterprise_web_search_tool],
)


class EnterpriseSearchAgentTool(_SearchAgentTool):
"""A tool that wraps a sub-agent that only uses enterprise_web_search tool.

This is a workaround to support using enterprise_web_search tool with other tools.
TODO(b/448114567): Remove once the workaround is no longer needed.

Attributes:
agent: The sub-agent that this tool wraps.
"""

def __init__(self, agent: LlmAgent):
self.agent = agent
super().__init__(agent=self.agent)
10 changes: 8 additions & 2 deletions src/google/adk/tools/enterprise_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,18 @@ class EnterpriseWebSearchTool(BaseTool):
https://cloud.google.com/vertex-ai/generative-ai/docs/grounding/web-grounding-enterprise.
"""

def __init__(self):
"""Initializes the Vertex AI Search tool."""
def __init__(self, *, bypass_multi_tools_limit: bool = False):
"""Initializes the Enterprise web search tool.

Args:
bypass_multi_tools_limit: Whether to bypass the multi tools limitation,
so that the tool can be used with other tools in the same agent.
"""
# Name and description are not used because this is a model built-in tool.
super().__init__(
name='enterprise_web_search', description='enterprise_web_search'
)
self.bypass_multi_tools_limit = bypass_multi_tools_limit

@override
async def process_llm_request(
Expand Down
Loading