Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 104 additions & 1 deletion src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import inspect
import sys
from collections.abc import Awaitable
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast
Expand Down Expand Up @@ -67,6 +67,12 @@ class _UnsetType:

_UNSET = _UnsetType()


class _InnerMCPRequestCancelled(Exception):
    """Marker error: a request on the shared MCP session was cancelled inside the transport.

    It is a regular ``Exception`` subclass (not an ``asyncio.CancelledError``), so code that
    catches it can tell this situation apart from cancellation of the calling task itself.
    """

if TYPE_CHECKING:
from ..agent import AgentBase

Expand Down Expand Up @@ -1108,6 +1114,103 @@ def create_streams(
terminate_on_close=self.params.get("terminate_on_close", True),
)

@asynccontextmanager
async def _isolated_client_session(self):
    """Yield a freshly initialized ``ClientSession`` that uses its own transport streams.

    New streams are opened through ``self.create_streams()`` and wrapped in a
    ``ClientSession`` that is initialized before being yielded. Both context managers are
    entered on a single ``AsyncExitStack``, so they are exited when the caller's
    ``async with`` block ends.
    """
    async with AsyncExitStack() as stack:
        streams = await stack.enter_async_context(self.create_streams())
        # Only the first two stream objects are used; any extra items are ignored.
        read, write, *_ = streams

        # A falsy timeout setting (``None`` or ``0``) results in no timeout being passed.
        timeout_seconds = self.client_session_timeout_seconds
        read_timeout = timedelta(seconds=timeout_seconds) if timeout_seconds else None

        session = await stack.enter_async_context(
            ClientSession(read, write, read_timeout, message_handler=self.message_handler)
        )
        await session.initialize()
        yield session

async def _call_tool_with_session(
    self,
    session: ClientSession,
    tool_name: str,
    arguments: dict[str, Any] | None,
    meta: dict[str, Any] | None = None,
) -> CallToolResult:
    """Call ``tool_name`` on the given ``session`` and return whatever it returns.

    Args:
        session: The session object whose ``call_tool`` coroutine is awaited.
        tool_name: Name of the tool to invoke.
        arguments: Tool arguments, passed through positionally (may be ``None``).
        meta: Optional metadata. Forwarded as the ``meta`` keyword only when it is not
            ``None``; otherwise the keyword is omitted from the call entirely.
    """
    optional_kwargs: dict[str, Any] = {} if meta is None else {"meta": meta}
    return await session.call_tool(tool_name, arguments, **optional_kwargs)

async def _call_tool_with_shared_session(
    self,
    tool_name: str,
    arguments: dict[str, Any] | None,
    meta: dict[str, Any] | None = None,
) -> CallToolResult:
    """Call a tool on the shared session, turning cancellation into a marker exception.

    Args:
        tool_name: Name of the tool to invoke.
        arguments: Tool arguments (may be ``None``).
        meta: Optional metadata forwarded to the underlying call.

    Returns:
        The result of the tool call on ``self.session``.

    Raises:
        UserError: If ``self.session`` is ``None`` when this coroutine runs.
        _InnerMCPRequestCancelled: If ``asyncio.CancelledError`` is raised while awaiting
            the call; the original cancellation is chained as ``__cause__``.
    """
    session = self.session
    if session is None:
        # Explicit check rather than ``assert``: assertions are removed under ``python -O``,
        # which would turn a missing session into an opaque AttributeError on ``None``.
        raise UserError("Server not initialized. Make sure you call `connect()` first.")
    try:
        return await self._call_tool_with_session(session, tool_name, arguments, meta)
    except asyncio.CancelledError as exc:
        # Re-raise as a regular exception so a caller awaiting this coroutine in a separate
        # task can distinguish it from its own cancellation.
        raise _InnerMCPRequestCancelled from exc

async def _call_tool_with_isolated_retry(
    self,
    tool_name: str,
    arguments: dict[str, Any] | None,
    meta: dict[str, Any] | None = None,
) -> CallToolResult:
    """Call a tool on the shared session, retrying once on an isolated session if needed.

    The shared-session call runs in its own task and is awaited through
    ``asyncio.shield``. Three outcomes are handled:

    * The task finishes normally: its result is returned.
    * The task fails with ``_InnerMCPRequestCancelled``: a warning is logged and the call
      is repeated on a session obtained from ``self._isolated_client_session()``.
    * This coroutine itself is cancelled: a still-running task is cancelled and awaited,
      then the ``asyncio.CancelledError`` is re-raised.
    """
    shared_call = asyncio.create_task(
        self._call_tool_with_shared_session(tool_name, arguments, meta)
    )
    try:
        return await asyncio.shield(shared_call)
    except _InnerMCPRequestCancelled:
        logger.warning(
            "Retrying streamable-http MCP tool '%s' on isolated session after shared "
            "request cancellation.",
            tool_name,
        )
        async with self._isolated_client_session() as isolated:
            return await self._call_tool_with_session(isolated, tool_name, arguments, meta)
    except asyncio.CancelledError:
        # Cancellation of this coroutine: nothing to clean up if the task already finished.
        if shared_call.done():
            raise
        shared_call.cancel()
        try:
            await shared_call
        except (asyncio.CancelledError, Exception):
            # Best-effort teardown: whatever the task ends with is deliberately ignored,
            # because the cancellation below is what must reach the caller.
            pass
        raise

async def call_tool(
    self,
    tool_name: str,
    arguments: dict[str, Any] | None,
    meta: dict[str, Any] | None = None,
) -> CallToolResult:
    """Invoke a tool on the server.

    Args:
        tool_name: Name of the tool to invoke.
        arguments: Tool arguments (may be ``None``); validated by
            ``self._validate_required_parameters`` before the call is made.
        meta: Optional metadata forwarded with the request.

    Raises:
        UserError: If there is no session, or if the call fails with
            ``httpx.HTTPStatusError`` or ``httpx.ConnectError`` (chained as the cause).
    """
    if not self.session:
        raise UserError("Server not initialized. Make sure you call `connect()` first.")

    def _attempt() -> Awaitable[CallToolResult]:
        # Each invocation creates a fresh coroutine for a single attempt.
        return self._call_tool_with_isolated_retry(tool_name, arguments, meta)

    try:
        self._validate_required_parameters(tool_name=tool_name, arguments=arguments)
        return await self._run_with_retries(_attempt)
    except httpx.HTTPStatusError as http_err:
        message = (
            f"Failed to call tool '{tool_name}' on MCP server '{self.name}': "
            f"HTTP error {http_err.response.status_code}"
        )
        raise UserError(message) from http_err
    except httpx.ConnectError as conn_err:
        message = (
            f"Failed to call tool '{tool_name}' on MCP server '{self.name}': Connection lost. "
            f"The server may have disconnected."
        )
        raise UserError(message) from conn_err

@property
def name(self) -> str:
"""A readable name for the server."""
Expand Down
99 changes: 98 additions & 1 deletion tests/mcp/test_client_session_retries.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
from contextlib import asynccontextmanager
from typing import cast

import pytest
from mcp import ClientSession, Tool as MCPTool
from mcp.types import CallToolResult, ListToolsResult

from agents.exceptions import UserError
from agents.mcp.server import _MCPServerWithClientSession
from agents.mcp.server import MCPServerStreamableHttp, _MCPServerWithClientSession


class DummySession:
Expand Down Expand Up @@ -148,3 +150,98 @@ async def test_call_tool_rejects_non_object_arguments_before_remote_call():
await server.call_tool("tool", cast(dict[str, object] | None, ["bad"]))

assert session.call_tool_attempts == 0


class ConcurrentCancellationSession:
    """Fake session in which any non-"slow" call cancels the in-flight "slow" call.

    The "slow" call records the task it runs in and signals that it has started; every
    other call waits for that signal, cancels the recorded task, and then fails with a
    ``RuntimeError``.
    """

    def __init__(self):
        self._slow_task: asyncio.Task[CallToolResult] | None = None
        self._slow_started = asyncio.Event()

    async def call_tool(self, tool_name, arguments, meta=None):
        if tool_name != "slow":
            # Wait until the slow request is running, then cancel it from the outside.
            await self._slow_started.wait()
            assert self._slow_task is not None
            self._slow_task.cancel()
            raise RuntimeError("synthetic request failure")

        self._slow_task = cast(asyncio.Task[CallToolResult], asyncio.current_task())
        self._slow_started.set()
        await asyncio.sleep(0.1)
        return CallToolResult(content=[])


class IsolatedRetrySession:
    """Fake isolated session that counts calls; only the "slow" tool succeeds."""

    def __init__(self):
        self.call_tool_attempts = 0

    async def call_tool(self, tool_name, arguments, meta=None):
        self.call_tool_attempts += 1
        if tool_name != "slow":
            raise RuntimeError("synthetic request failure")
        return CallToolResult(content=[])


class HangingSession:
    """Fake session whose ``call_tool`` sleeps for ten seconds and returns nothing."""

    async def call_tool(self, tool_name, arguments, meta=None):
        # Long enough that callers in these tests are expected to cancel before it ends.
        await asyncio.sleep(10)


class DummyStreamableHttpServer(MCPServerStreamableHttp):
    """Streamable-HTTP server stub that uses injected fake sessions instead of a network.

    ``shared_session`` is installed as ``self.session`` and ``isolated_session`` is what
    ``_isolated_client_session()`` yields. ``create_streams`` is never expected to run.
    """

    def __init__(
        self,
        shared_session: ConcurrentCancellationSession,
        isolated_session: IsolatedRetrySession,
    ):
        super().__init__(
            params={"url": "https://example.test/mcp"},
            client_session_timeout_seconds=None,
            max_retry_attempts=0,
        )
        # The fakes only implement ``call_tool``; the casts are for the type checker.
        self.session = cast(ClientSession, shared_session)
        self._isolated_session = cast(ClientSession, isolated_session)

    def create_streams(self):
        raise NotImplementedError

    @asynccontextmanager
    async def _isolated_client_session(self):
        # Hand back the injected fake rather than opening a real transport.
        yield self._isolated_session


@pytest.mark.asyncio
async def test_streamable_http_retries_cancelled_request_on_isolated_session():
    """A "slow" call cancelled by a concurrent failing call succeeds via the isolated session.

    The concurrent "fail" call still surfaces its ``RuntimeError``, and the isolated session
    is called exactly once.
    """
    shared = ConcurrentCancellationSession()
    isolated = IsolatedRetrySession()
    server = DummyStreamableHttpServer(shared_session=shared, isolated_session=isolated)

    slow_outcome, fail_outcome = await asyncio.gather(
        server.call_tool("slow", None),
        server.call_tool("fail", None),
        return_exceptions=True,
    )

    assert isinstance(slow_outcome, CallToolResult)
    assert isinstance(fail_outcome, RuntimeError)
    assert shared._slow_task is not None
    assert isolated.call_tool_attempts == 1


@pytest.mark.asyncio
async def test_streamable_http_preserves_outer_cancellation():
    """Cancelling the caller's task raises ``CancelledError`` and triggers no isolated call."""
    isolated = IsolatedRetrySession()
    # HangingSession is cast only to satisfy the constructor's annotation.
    hanging = cast(ConcurrentCancellationSession, HangingSession())
    server = DummyStreamableHttpServer(shared_session=hanging, isolated_session=isolated)

    pending_call = asyncio.create_task(server.call_tool("slow", None))
    # Yield once so the call task gets to run before it is cancelled.
    await asyncio.sleep(0)
    pending_call.cancel()

    with pytest.raises(asyncio.CancelledError):
        await pending_call

    assert isolated.call_tool_attempts == 0
Loading