Skip to content
27 changes: 18 additions & 9 deletions src/google/adk/tools/mcp_tool/mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
except ImportError as e:

if sys.version_info < (3, 10):
raise ImportError(
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
Expand All @@ -62,12 +61,15 @@ class StdioConnectionParams(BaseModel):
server_params: StdioServerParameters
timeout: float = 5.0

class Config:
arbitrary_types_allowed = True


class SseConnectionParams(BaseModel):
"""Parameters for the MCP SSE connection.

See MCP SSE Client documentation for more details.
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
[https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py)

Attributes:
url: URL for the MCP SSE server.
Expand All @@ -88,7 +90,7 @@ class StreamableHTTPConnectionParams(BaseModel):
"""Parameters for the MCP Streamable HTTP connection.

See MCP Streamable HTTP Client documentation for more details.
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py
[https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py)

Attributes:
url: URL for the MCP Streamable HTTP server.
Expand All @@ -111,12 +113,21 @@ class StreamableHTTPConnectionParams(BaseModel):
def retry_on_closed_resource(func):
"""Decorator to automatically retry action when MCP session is closed.

CRITICAL WARNING: This decorator is UNSAFE for non-idempotent operations.
Do NOT use with tool calls that create, update, or delete resources as
retrying can cause duplicate operations or data corruption.

Only use with read-only, idempotent operations like list_tools,
list_resources, or read_resource.

Do NOT apply to generic tool execution methods like _run_async_impl.

When MCP session was closed, the decorator will automatically retry the
action once. The create_session method will handle creating a new session
if the old one was disconnected.

Args:
func: The function to decorate.
func: The function to decorate. Must be idempotent and safe to retry.

Returns:
The decorated function.
Expand Down Expand Up @@ -177,11 +188,10 @@ def __init__(
)
else:
self._connection_params = connection_params
self._errlog = errlog

self._errlog = errlog
# Session pool: maps session keys to (session, exit_stack) tuples
self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {}

# Lock to prevent race conditions in session creation
self._session_lock = asyncio.Lock()

Expand Down Expand Up @@ -293,6 +303,7 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
' StdioServerParameters or SseServerParams, but got'
f' {self._connection_params}'
)

return client

async def create_session(
Expand All @@ -314,7 +325,6 @@ async def create_session(
"""
# Merge headers once at the beginning
merged_headers = self._merge_headers(headers)

# Generate session key using merged headers
session_key = self._generate_session_key(merged_headers)

Expand All @@ -323,7 +333,6 @@ async def create_session(
# Check if we have an existing session
if session_key in self._sessions:
session, exit_stack = self._sessions[session_key]

# Check if the existing session is still connected
if not self._is_session_disconnected(session):
# Session is still good, return it
Expand Down Expand Up @@ -369,6 +378,7 @@ async def create_session(
)
await asyncio.wait_for(session.initialize(), timeout=timeout_in_seconds)

await session.initialize()
# Store session and exit stack in the pool
self._sessions[session_key] = (session, exit_stack)
logger.debug('Created new session: %s', session_key)
Expand Down Expand Up @@ -404,5 +414,4 @@ async def close(self):


SseServerParams = SseConnectionParams

StreamableHTTPServerParams = StreamableHTTPConnectionParams