41 changes: 26 additions & 15 deletions src/google/adk/tools/mcp_tool/mcp_toolset.py
@@ -137,6 +137,9 @@ def __init__(
     self._auth_scheme = auth_scheme
     self._auth_credential = auth_credential
 
+    # Cache for tools to avoid repeated session.list_tools() calls
+    self._cached_tools: Optional[List[BaseTool]] = None
+
   @retry_on_closed_resource
   async def get_tools(
       self,
@@ -151,24 +154,30 @@ async def get_tools(
     Returns:
      List[BaseTool]: A list of tools available under the specified context.
    """
-    # Get session from session manager
-    session = await self._mcp_session_manager.create_session()
-
-    # Fetch available tools from the MCP server
-    tools_response: ListToolsResult = await session.list_tools()
+    # Use cached tools if available, otherwise fetch and cache them
+    if self._cached_tools is None:
+      # Get session from session manager
+      session = await self._mcp_session_manager.create_session()
+
+      # Fetch available tools from the MCP server
+      tools_response: ListToolsResult = await session.list_tools()
+
+      # Create MCPTool instances for all tools and cache them
+      self._cached_tools = []
+      for tool in tools_response.tools:
+        mcp_tool = MCPTool(
+            mcp_tool=tool,
+            mcp_session_manager=self._mcp_session_manager,
+            auth_scheme=self._auth_scheme,
+            auth_credential=self._auth_credential,
+        )
+        self._cached_tools.append(mcp_tool)
 
     # Apply filtering based on context and tool_filter
     tools = []
-    for tool in tools_response.tools:
-      mcp_tool = MCPTool(
-          mcp_tool=tool,
-          mcp_session_manager=self._mcp_session_manager,
-          auth_scheme=self._auth_scheme,
-          auth_credential=self._auth_credential,
-      )
-
-      if self._is_tool_selected(mcp_tool, readonly_context):
-        tools.append(mcp_tool)
+    for tool in self._cached_tools:
+      if self._is_tool_selected(tool, readonly_context):
+        tools.append(tool)
     return tools
 
   async def close(self) -> None:
@@ -180,6 +189,8 @@ async def close(self) -> None:
     """
     try:
       await self._mcp_session_manager.close()
+      # Clear the cached tools when closing
+      self._cached_tools = None
     except Exception as e:
       # Log the error but don't re-raise to avoid blocking shutdown
       print(f"Warning: Error during MCPToolset cleanup: {e}", file=self._errlog)
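
For orientation, here is a minimal usage sketch of the new caching behavior (not part of the diff). It assumes a stdio MCP server launched with npx @modelcontextprotocol/server-filesystem; the server command, arguments, and tool names are illustrative. The intent is that only the first get_tools() call reaches session.list_tools(); later calls reuse the cached MCPTool instances until close() resets the cache.

# Hypothetical usage sketch (server command, arguments, and tool names are illustrative).
import asyncio

from mcp import StdioServerParameters

from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset


async def main():
  toolset = MCPToolset(
      connection_params=StdioServerParameters(
          command="npx",
          args=["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
      ),
      tool_filter=["read_file", "write_file"],  # optional allow-list
  )

  first = await toolset.get_tools()   # hits session.list_tools() and populates the cache
  second = await toolset.get_tools()  # served from self._cached_tools, no extra server call

  # The cached MCPTool instances are reused, so both calls see the same tools.
  assert [t.name for t in first] == [t.name for t in second]

  await toolset.close()  # closes the session and resets self._cached_tools to None


asyncio.run(main())

One consequence of caching at the toolset level is that tools added or removed on the server after the first call are not picked up until the toolset is closed or recreated; tool_filter and readonly_context filtering, by contrast, is still applied on every call over the cached list.
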
160 changes: 160 additions & 0 deletions tests/unittests/tools/mcp_tool/test_mcp_toolset.py
@@ -284,3 +284,163 @@ async def test_get_tools_retry_decorator(self):
 
     # Check that the method has the retry decorator
     assert hasattr(toolset.get_tools, "__wrapped__")
+
+  @pytest.mark.asyncio
+  async def test_tools_caching_behavior(self):
+    """Test that tools are cached and session.list_tools() is only called once."""
+    # Mock tools from MCP server
+    mock_tools = [
+        MockMCPTool("tool1"),
+        MockMCPTool("tool2"),
+    ]
+    self.mock_session.list_tools = AsyncMock(
+        return_value=MockListToolsResult(mock_tools)
+    )
+
+    toolset = MCPToolset(connection_params=self.mock_stdio_params)
+    toolset._mcp_session_manager = self.mock_session_manager
+
+    # First call should fetch tools from the server
+    tools1 = await toolset.get_tools()
+    assert len(tools1) == 2
+    assert self.mock_session.list_tools.call_count == 1
+
+    # Second call should use cached tools, not call the server again
+    tools2 = await toolset.get_tools()
+    assert len(tools2) == 2
+    assert (
+        self.mock_session.list_tools.call_count == 1
+    )  # Still only called once
+
+    # Third call should also use the cache
+    tools3 = await toolset.get_tools()
+    assert len(tools3) == 2
+    assert (
+        self.mock_session.list_tools.call_count == 1
+    )  # Still only called once
+
+    # Verify all returned tool instances are the same (cached)
+    for i in range(len(tools1)):
+      assert tools1[i] is tools2[i]  # Same object instance
+      assert tools2[i] is tools3[i]  # Same object instance
+
+  def test_cache_initialization(self):
+    """Test that the cache starts as None before any get_tools() call."""
+    toolset = MCPToolset(connection_params=self.mock_stdio_params)
+
+    # Cache should start as None
+    assert toolset._cached_tools is None
+
+  @pytest.mark.asyncio
+  async def test_cache_populated_after_first_call(self):
+    """Test that the cache gets populated after the first get_tools() call."""
+    # Mock tools from MCP server
+    mock_tools = [MockMCPTool("tool1"), MockMCPTool("tool2")]
+    self.mock_session.list_tools = AsyncMock(
+        return_value=MockListToolsResult(mock_tools)
+    )
+
+    toolset = MCPToolset(connection_params=self.mock_stdio_params)
+    toolset._mcp_session_manager = self.mock_session_manager
+
+    # Cache should start as None
+    assert toolset._cached_tools is None
+
+    # After the first call, the cache should be populated
+    await toolset.get_tools()
+    assert toolset._cached_tools is not None
+    assert len(toolset._cached_tools) == 2
+    assert all(isinstance(tool, MCPTool) for tool in toolset._cached_tools)
+
+  @pytest.mark.asyncio
+  async def test_cache_cleared_on_close(self):
+    """Test that the cache is cleared when the toolset is closed."""
+    # Mock tools from MCP server
+    mock_tools = [MockMCPTool("tool1")]
+    self.mock_session.list_tools = AsyncMock(
+        return_value=MockListToolsResult(mock_tools)
+    )
+
+    toolset = MCPToolset(connection_params=self.mock_stdio_params)
+    toolset._mcp_session_manager = self.mock_session_manager
+
+    # Populate the cache
+    await toolset.get_tools()
+    assert toolset._cached_tools is not None
+
+    # Close should clear the cache
+    await toolset.close()
+    assert toolset._cached_tools is None
+
+  @pytest.mark.asyncio
+  async def test_cache_filtering_works_correctly(self):
+    """Test that filtering works correctly with cached tools."""
+    # Mock tools from MCP server
+    mock_tools = [
+        MockMCPTool("read_file"),
+        MockMCPTool("write_file"),
+        MockMCPTool("list_directory"),
+    ]
+    self.mock_session.list_tools = AsyncMock(
+        return_value=MockListToolsResult(mock_tools)
+    )
+
+    # Create a toolset with a filter
+    tool_filter = ["read_file", "write_file"]
+    toolset = MCPToolset(
+        connection_params=self.mock_stdio_params, tool_filter=tool_filter
+    )
+    toolset._mcp_session_manager = self.mock_session_manager
+
+    # First call - should fetch and cache all tools, return filtered ones
+    tools1 = await toolset.get_tools()
+    assert len(tools1) == 2  # Only filtered tools returned
+    assert tools1[0].name == "read_file"
+    assert tools1[1].name == "write_file"
+    assert self.mock_session.list_tools.call_count == 1
+
+    # Cache should contain all tools (before filtering)
+    assert len(toolset._cached_tools) == 3
+
+    # Second call - should use the cache and apply filtering again
+    tools2 = await toolset.get_tools()
+    assert len(tools2) == 2  # Only filtered tools returned
+    assert tools2[0].name == "read_file"
+    assert tools2[1].name == "write_file"
+    assert (
+        self.mock_session.list_tools.call_count == 1
+    )  # Still only called once
+
+    # Returned tool instances should be the same (from cache)
+    assert tools1[0] is tools2[0]
+    assert tools1[1] is tools2[1]
+
+  @pytest.mark.asyncio
+  async def test_cache_with_different_readonly_contexts(self):
+    """Test that the cache works across readonly contexts."""
+    # Mock tools from MCP server
+    mock_tools = [MockMCPTool("tool1"), MockMCPTool("tool2")]
+    self.mock_session.list_tools = AsyncMock(
+        return_value=MockListToolsResult(mock_tools)
+    )
+
+    toolset = MCPToolset(connection_params=self.mock_stdio_params)
+    toolset._mcp_session_manager = self.mock_session_manager
+
+    # First call with None context
+    tools1 = await toolset.get_tools(readonly_context=None)
+    assert len(tools1) == 2
+    assert self.mock_session.list_tools.call_count == 1
+
+    # A second call should still use the cache. Since the caching behavior
+    # doesn't depend on context (the cache lives at the toolset level),
+    # calling with None context again is enough to verify the cache works.
+    tools2 = await toolset.get_tools(readonly_context=None)
+    assert len(tools2) == 2
+    assert (
+        self.mock_session.list_tools.call_count == 1
+    )  # Still only called once
+
+    # Tools should come from the same cache
+    assert tools1[0] is tools2[0]
+    assert tools1[1] is tools2[1]
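
The new tests reference MockMCPTool, MockListToolsResult, self.mock_session, self.mock_session_manager, and self.mock_stdio_params, all defined earlier in the test module and not shown in this diff. For readers following along without the full file, here is a rough, hypothetical sketch of what such stand-ins could look like; the bodies below are assumptions, not the repository's actual fixtures.

# Hypothetical stand-ins for the fixtures referenced above; the real test module
# defines its own versions outside this hunk.
from unittest.mock import AsyncMock, MagicMock


class MockMCPTool:
  """Mimics a tool entry from an MCP list_tools() response."""

  def __init__(self, name):
    self.name = name
    self.description = f"Mock description for {name}"
    self.inputSchema = {"type": "object", "properties": {}}


class MockListToolsResult:
  """Mimics mcp.types.ListToolsResult; only the .tools attribute is needed here."""

  def __init__(self, tools):
    self.tools = tools


def make_mock_session_manager(tools=()):
  """Builds a session manager whose create_session() returns a mock session."""
  session = MagicMock()
  session.list_tools = AsyncMock(return_value=MockListToolsResult(list(tools)))
  manager = MagicMock()
  manager.create_session = AsyncMock(return_value=session)
  manager.close = AsyncMock()
  return manager, session
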