-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Python: fix(python): prevent MCP message_handler deadlock on notification reload #4866
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |||||||||||||||||||||||
| import re | ||||||||||||||||||||||||
| import sys | ||||||||||||||||||||||||
| from abc import abstractmethod | ||||||||||||||||||||||||
| from collections.abc import Callable, Collection, Sequence | ||||||||||||||||||||||||
| from collections.abc import Callable, Collection, Coroutine, Sequence | ||||||||||||||||||||||||
| from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore | ||||||||||||||||||||||||
| from datetime import timedelta | ||||||||||||||||||||||||
| from functools import partial | ||||||||||||||||||||||||
|
|
@@ -491,6 +491,7 @@ def __init__( | |||||||||||||||||||||||
| self.is_connected: bool = False | ||||||||||||||||||||||||
| self._tools_loaded: bool = False | ||||||||||||||||||||||||
| self._prompts_loaded: bool = False | ||||||||||||||||||||||||
| self._pending_reload_tasks: set[asyncio.Task[None]] = set() | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def __str__(self) -> str: | ||||||||||||||||||||||||
| return f"MCPTool(name={self.name}, description={self.description})" | ||||||||||||||||||||||||
|
|
@@ -799,17 +800,51 @@ async def message_handler( | |||||||||||||||||||||||
| message: The message from the MCP server (request responder, notification, or exception). | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| if isinstance(message, Exception): | ||||||||||||||||||||||||
| logger.error("Error from MCP server: %s", message, exc_info=message) | ||||||||||||||||||||||||
| logger.error("Error from MCP server: %s", message, exc_info=True) | ||||||||||||||||||||||||
| return | ||||||||||||||||||||||||
| if isinstance(message, types.ServerNotification): | ||||||||||||||||||||||||
| match message.root.method: | ||||||||||||||||||||||||
| case "notifications/tools/list_changed": | ||||||||||||||||||||||||
| await self.load_tools() | ||||||||||||||||||||||||
| self._schedule_reload(self.load_tools()) | ||||||||||||||||||||||||
| case "notifications/prompts/list_changed": | ||||||||||||||||||||||||
| await self.load_prompts() | ||||||||||||||||||||||||
| self._schedule_reload(self.load_prompts()) | ||||||||||||||||||||||||
| case _: | ||||||||||||||||||||||||
| logger.debug("Unhandled notification: %s", message.root.method) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def _schedule_reload(self, coro: Coroutine[Any, Any, None]) -> None: | ||||||||||||||||||||||||
| """Schedule a reload coroutine as a background task. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Reloads (load_tools / load_prompts) triggered by MCP server | ||||||||||||||||||||||||
| notifications must NOT be awaited inside the message handler because | ||||||||||||||||||||||||
| the handler runs on the MCP SDK's single-threaded receive loop. | ||||||||||||||||||||||||
| Awaiting a session request (e.g. ``list_tools``) from within that loop | ||||||||||||||||||||||||
| deadlocks: the receive loop cannot read the response while it is | ||||||||||||||||||||||||
| blocked waiting for the handler to return. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Instead we fire the reload as an independent ``asyncio.Task`` and keep | ||||||||||||||||||||||||
| a strong reference in ``_pending_reload_tasks`` so it is not garbage- | ||||||||||||||||||||||||
| collected before completion. Only one reload per kind (tools / prompts) | ||||||||||||||||||||||||
| is kept in flight; a new notification cancels the previous pending task | ||||||||||||||||||||||||
| for the same coroutine name to avoid unbounded growth. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| # Cancel-and-replace: only one reload per kind should be in flight. | ||||||||||||||||||||||||
| reload_name = f"mcp-reload:{self.name}:{coro.__qualname__}" | ||||||||||||||||||||||||
| for existing in list(self._pending_reload_tasks): | ||||||||||||||||||||||||
| if existing.get_name() == reload_name and not existing.done(): | ||||||||||||||||||||||||
| existing.cancel() | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| async def _safe_reload() -> None: | ||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||
| await coro | ||||||||||||||||||||||||
| except asyncio.CancelledError: | ||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||
| logger.warning("Background MCP reload failed", exc_info=True) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| task = asyncio.create_task(_safe_reload(), name=reload_name) | ||||||||||||||||||||||||
| self._pending_reload_tasks.add(task) | ||||||||||||||||||||||||
| task.add_done_callback(self._pending_reload_tasks.discard) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def _determine_approval_mode( | ||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||
| *candidate_names: str, | ||||||||||||||||||||||||
|
|
@@ -931,6 +966,11 @@ async def load_tools(self) -> None: | |||||||||||||||||||||||
| params = types.PaginatedRequestParams(cursor=tool_list.nextCursor) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| async def _close_on_owner(self) -> None: | ||||||||||||||||||||||||
| # Cancel any pending reload tasks before tearing down the session. | ||||||||||||||||||||||||
| for task in list(self._pending_reload_tasks): | ||||||||||||||||||||||||
| task.cancel() | ||||||||||||||||||||||||
| self._pending_reload_tasks.clear() | ||||||||||||||||||||||||
|
Comment on lines
968
to
+972
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reliability concern:
Suggested change
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| await self._safe_close_exit_stack() | ||||||||||||||||||||||||
| self._exit_stack = AsyncExitStack() | ||||||||||||||||||||||||
| self.session = None | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| # Copyright (c) Microsoft. All rights reserved. | ||
| # type: ignore[reportPrivateUsage] | ||
| import asyncio | ||
| import logging | ||
| import os | ||
| from contextlib import _AsyncGeneratorContextManager # type: ignore | ||
|
|
@@ -1478,7 +1479,7 @@ async def call_tool_with_error(*args, **kwargs): | |
|
|
||
| async def test_mcp_tool_message_handler_notification(): | ||
| """Test that message_handler correctly processes tools/list_changed and prompts/list_changed | ||
| notifications.""" | ||
| notifications by scheduling reloads as background tasks.""" | ||
| tool = MCPStdioTool(name="test_tool", command="python") | ||
|
|
||
| # Mock the load_tools and load_prompts methods | ||
|
|
@@ -1492,6 +1493,8 @@ async def test_mcp_tool_message_handler_notification(): | |
|
|
||
| result = await tool.message_handler(tools_notification) | ||
| assert result is None | ||
| # The reload is scheduled as a background task; let it run. | ||
| await asyncio.sleep(0) | ||
| tool.load_tools.assert_called_once() | ||
|
|
||
| # Reset mock | ||
|
|
@@ -1504,6 +1507,7 @@ async def test_mcp_tool_message_handler_notification(): | |
|
|
||
| result = await tool.message_handler(prompts_notification) | ||
| assert result is None | ||
| await asyncio.sleep(0) | ||
| tool.load_prompts.assert_called_once() | ||
|
|
||
| # Test unhandled notification | ||
|
|
@@ -1527,6 +1531,71 @@ async def test_mcp_tool_message_handler_error(): | |
| assert result is None | ||
|
|
||
|
|
||
| async def test_mcp_tool_message_handler_does_not_block_receive_loop(): | ||
| """Test that message_handler does not deadlock the MCP receive loop. | ||
|
|
||
| Regression test for https://github.com/microsoft/agent-framework/issues/4828. | ||
| When the MCP server sends a ``notifications/tools/list_changed`` | ||
| notification, the handler must NOT await ``load_tools()`` synchronously | ||
| because that would block the single-threaded MCP receive loop, preventing | ||
| it from delivering the ``list_tools`` response — a classic deadlock. | ||
| """ | ||
| tool = MCPStdioTool(name="test_tool", command="python") | ||
|
|
||
| # Use an event to make load_tools block until we release it. | ||
| # This simulates load_tools waiting for a session response that the | ||
| # receive loop would need to deliver. | ||
| release = asyncio.Event() | ||
|
|
||
| async def slow_load_tools(): | ||
| await release.wait() | ||
|
|
||
| tool.load_tools = slow_load_tools # type: ignore[assignment] | ||
|
|
||
| tools_notification = Mock(spec=types.ServerNotification) | ||
| tools_notification.root = Mock() | ||
| tools_notification.root.method = "notifications/tools/list_changed" | ||
|
|
||
| # message_handler must return immediately even though load_tools blocks. | ||
| await tool.message_handler(tools_notification) | ||
|
|
||
| # If the handler had awaited load_tools synchronously, we would never | ||
| # reach this line (deadlock). Verify the reload task is pending. | ||
| assert len(tool._pending_reload_tasks) == 1 | ||
|
|
||
| # Unblock the reload so the background task finishes cleanly. | ||
| release.set() | ||
| # Wait for the pending reload task(s) to complete so their done-callbacks | ||
| # have a chance to remove them from _pending_reload_tasks. | ||
| await asyncio.wait_for(asyncio.gather(*tool._pending_reload_tasks), timeout=1) | ||
| assert len(tool._pending_reload_tasks) == 0 | ||
|
|
||
|
|
||
| async def test_mcp_tool_message_handler_reload_failure_is_logged(caplog: pytest.LogCaptureFixture): | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test verifies the exception doesn't propagate, but (in its current on-disk form without |
||
| """Background reload errors are logged, not raised into the receive loop.""" | ||
| tool = MCPStdioTool(name="test_tool", command="python") | ||
| tool.load_tools = AsyncMock(side_effect=RuntimeError("connection lost")) | ||
|
|
||
| tools_notification = Mock(spec=types.ServerNotification) | ||
| tools_notification.root = Mock() | ||
| tools_notification.root.method = "notifications/tools/list_changed" | ||
|
|
||
| await tool.message_handler(tools_notification) | ||
| # Let the background task run — it should not propagate the exception. | ||
| # Snapshot tasks and await them to ensure done-callbacks fire. | ||
| pending = list(tool._pending_reload_tasks) | ||
| if pending: | ||
| await asyncio.wait_for(asyncio.gather(*pending, return_exceptions=True), timeout=1) | ||
| tool.load_tools.assert_called_once() | ||
giles17 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| assert len(tool._pending_reload_tasks) == 0 | ||
|
|
||
| # Verify the warning was actually logged with exception info. | ||
| reload_warnings = [r for r in caplog.records if "Background MCP reload failed" in r.message] | ||
| assert len(reload_warnings) == 1 | ||
| assert reload_warnings[0].levelname == "WARNING" | ||
| assert reload_warnings[0].exc_info is not None | ||
|
|
||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider adding a companion test for the cancel-and-replace logic: send two |
||
| async def test_mcp_tool_sampling_callback_no_client(): | ||
| """Test sampling callback error path when no chat client is available.""" | ||
| tool = MCPStdioTool(name="test_tool", command="python") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug:
`message_handler` is called by the MCP SDK with the exception as a parameter — NOT from within an `except` block (see `mcp/shared/session.py:359` and `:516`). With `exc_info=True`, `sys.exc_info()` returns `(None, None, None)`, so the exception type and traceback are silently lost. The original `exc_info=message` correctly passes the exception instance, letting Python's logging extract the full traceback directly.