Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/google/adk/tools/mcp_tool/mcp_tool.py
Copy link
Collaborator

@seanzhou1023 seanzhou1023 Sep 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel the changes in this file will cause the prefix to be added twice, given we already handled prefix prepending in base toolset, did you test it ?

I think the only changes needed in this file is

response = await session.call_tool(self._mcp_tool.name, arguments=args)

could you try and verify ?

Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
mcp_session_manager: MCPSessionManager,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
tool_name_prefix: str = "",
):
"""Initializes an MCPTool.

Expand All @@ -78,12 +79,20 @@ def __init__(
mcp_session_manager: The MCP session manager to use for communication.
auth_scheme: The authentication scheme to use.
auth_credential: The authentication credential to use.
tool_name_prefix: string to add to the start of the tool name. For example,
`prefix="ns_"` would name `my_tool` as `ns_my_tool`.

Raises:
ValueError: If mcp_tool or mcp_session_manager is None.
"""
if mcp_tool is None:
raise ValueError("mcp_tool cannot be None")
if mcp_session_manager is None:
raise ValueError("mcp_session_manager cannot be None")
raw_name = mcp_tool.name
name = tool_name_prefix + raw_name
super().__init__(
name=mcp_tool.name,
name=name,
description=mcp_tool.description if mcp_tool.description else "",
auth_config=AuthConfig(
auth_scheme=auth_scheme, raw_auth_credential=auth_credential
Expand All @@ -93,6 +102,8 @@ def __init__(
)
self._mcp_tool = mcp_tool
self._mcp_session_manager = mcp_session_manager
self._tool_name_prefix = tool_name_prefix
self._raw_name = raw_name

@override
def _get_declaration(self) -> FunctionDeclaration:
Expand Down Expand Up @@ -128,7 +139,7 @@ async def _run_async_impl(
# Get the session from the session manager
session = await self._mcp_session_manager.create_session(headers=headers)

response = await session.call_tool(self.name, arguments=args)
response = await session.call_tool(self._mcp_tool.name, arguments=args)
Copy link
Collaborator

@seanzhou1023 seanzhou1023 Sep 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

besides this line, I feel all the other changes in this file are not necessary and will cause problem as I mentioned in the comment of the whole file, would you please try and test it ?

return response

async def _get_headers(
Expand Down
19 changes: 14 additions & 5 deletions src/google/adk/tools/mcp_tool/mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
from mcp import StdioServerParameters
from mcp.types import ListToolsResult
from mcp import StdioServerParameters
except ImportError as e:
import sys

Expand Down Expand Up @@ -68,7 +68,8 @@ class MCPToolset(BaseToolset):
command='npx',
args=["-y", "@modelcontextprotocol/server-filesystem"],
),
tool_filter=['read_file', 'list_directory'] # Optional: filter specific tools
tool_filter=['read_file', 'list_directory'], # Optional: filter specific tools
tool_name_prefix="sfs_", # Optional: add_name_prefix
)

# Use in an agent
Expand Down Expand Up @@ -98,6 +99,7 @@ def __init__(
errlog: TextIO = sys.stderr,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
tool_name_prefix: str = "",
):
"""Initializes the MCPToolset.

Expand All @@ -110,12 +112,17 @@ def __init__(
mcp server (e.g. using `npx` or `python3` ), but it does not support
timeout, and we recommend to use `StdioConnectionParams` instead when
timeout is needed.
tool_filter: Optional filter to select specific tools. Can be either: - A
list of tool names to include - A ToolPredicate function for custom
filtering logic
tool_filter: Optional filter to select specific tools. Can be either:
- A list of tool names to include
- A ToolPredicate function for custom filtering logic
In both cases, the tool name WILL include the `tool_name_prefix` when
matching.
errlog: TextIO stream for error logging.
auth_scheme: The auth scheme of the tool for tool calling
auth_credential: The auth credential of the tool for tool calling
tool_name_prefix: string to add to the start of the name of all return tools.
For example, `prefix="ns_"` would change a returned tool name from
`my_tool` to `ns_my_tool`.
"""
super().__init__(tool_filter=tool_filter)

Expand All @@ -124,6 +131,7 @@ def __init__(

self._connection_params = connection_params
self._errlog = errlog
self._tool_name_prefix = tool_name_prefix

# Create the session manager that will handle the MCP connection
self._mcp_session_manager = MCPSessionManager(
Expand Down Expand Up @@ -161,6 +169,7 @@ async def get_tools(
mcp_session_manager=self._mcp_session_manager,
auth_scheme=self._auth_scheme,
auth_credential=self._auth_credential,
tool_name_prefix=self._tool_name_prefix,
)

if self._is_tool_selected(mcp_tool, readonly_context):
Expand Down
86 changes: 83 additions & 3 deletions tests/unittests/tools/mcp_tool/test_mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,21 @@
# limitations under the License.

import sys
from typing import Any
from typing import Dict
from google.genai.types import Part
from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch

from google.adk import Agent
from google.adk.tools import FunctionTool
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_credential import ServiceAccount
import pytest
from tests.unittests import testing_utils


# Skip all tests in this module if Python version is less than 3.10
pytestmark = pytest.mark.skipif(
Expand All @@ -38,6 +40,7 @@
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
from google.adk.tools.tool_context import ToolContext
from google.genai.types import FunctionDeclaration
from mcp.types import Tool as McpBaseTool
except ImportError as e:
if sys.version_info < (3, 10):
# Create dummy classes to prevent NameError during test collection
Expand All @@ -49,6 +52,7 @@ class DummyClass:
MCPTool = DummyClass
ToolContext = DummyClass
FunctionDeclaration = DummyClass
McpBaseTool = DummyClass
else:
raise e

Expand Down Expand Up @@ -358,3 +362,79 @@ def test_init_validation(self):

with pytest.raises(TypeError):
MCPTool(mcp_tool=self.mock_mcp_tool) # Missing session manager


class TestMCPSession(object):

def __init__(self, function_tool: FunctionTool):
self._function_tool = function_tool

async def call_tool(self, name, arguments):
return self._function_tool.func(**arguments)


class TestMCPSessionManager(object):

def __init__(self, function_tool: FunctionTool):
self._function_tool = function_tool

async def create_session(self, headers=None):
return TestMCPSession(self._function_tool)

async def close(self):
pass


def mcp_tool(function_tool: FunctionTool, prefix=""):
return MCPTool(
mcp_tool=McpBaseTool(
name=function_tool.name,
description=function_tool.description,
inputSchema=function_tool._get_declaration().parameters.json_schema.model_dump(
exclude_none=True
),
),
mcp_session_manager=TestMCPSessionManager(function_tool),
tool_name_prefix=prefix,
)


def test_mcp_tool():
@FunctionTool
def add(a: int, b: int):
"""Add a and b and retuirn the result"""
return a + b

mcp_add = mcp_tool(add, "mcp_")

add_call = Part.from_function_call(name="add", args={"a": 1, "b": 2})
add_response = Part.from_function_response(name="add", response={"result": 3})

mcp_add_call = Part.from_function_call(name="mcp_add", args={"a": 5, "b": 10})
mcp_add_response = Part.from_function_response(
name="mcp_add", response={"result": 15}
)

mock_model = testing_utils.MockModel.create(
responses=[
add_call,
mcp_add_call,
"response1",
]
)

root_agent = Agent(
name="root_agent",
model=mock_model,
tools=[add, mcp_add],
)

runner = testing_utils.InMemoryRunner(root_agent)

assert testing_utils.simplify_events(runner.run("test1")) == [
("root_agent", add_call),
("root_agent", add_response),
("root_agent", mcp_add_call),
("root_agent", mcp_add_response),
("root_agent", "response1"),
]