Skip to content

Commit 785d23c

Browse files
committed
feat(session): propagate callback exceptions to the awaiter
1 parent e8e6484 commit 785d23c

2 files changed

Lines changed: 87 additions & 1 deletion

File tree

src/mcp/shared/session.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from mcp.shared.response_router import ResponseRouter
2121
from mcp.types import (
2222
CONNECTION_CLOSED,
23+
INTERNAL_ERROR,
2324
INVALID_PARAMS,
2425
REQUEST_TIMEOUT,
2526
CancelledNotification,
@@ -184,6 +185,7 @@ class BaseSession(
184185
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
185186
_progress_callbacks: dict[RequestId, ProgressFnT]
186187
_response_routers: list[ResponseRouter]
188+
_propagate_errors: dict[RequestId, BaseException]
187189

188190
def __init__(
189191
self,
@@ -201,6 +203,7 @@ def __init__(
201203
self._progress_callbacks = {}
202204
self._response_routers = []
203205
self._exit_stack = AsyncExitStack()
206+
self._propagate_errors = {}
204207

205208
def add_response_router(self, router: ResponseRouter) -> None:
206209
"""Register a response router to handle responses for non-standard requests.
@@ -295,6 +298,11 @@ async def send_request(
295298
class_name = request.__class__.__name__
296299
message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds."
297300
raise MCPError(code=REQUEST_TIMEOUT, message=message)
301+
except anyio.EndOfStream:
302+
propagate = self._propagate_errors.pop(request_id, None)
303+
if propagate is not None:
304+
raise propagate from None
305+
raise
298306

299307
if isinstance(response_or_error, JSONRPCError):
300308
raise MCPError.from_jsonrpc_error(response_or_error)
@@ -374,7 +382,20 @@ async def _handle_session_message(message: SessionMessage) -> None:
374382

375383
if not responder._completed: # type: ignore[reportPrivateUsage]
376384
await self._handle_incoming(responder)
377-
except Exception:
385+
except Exception as e:
386+
if getattr(e, "__mcp_propagate__", False):
387+
error_response = JSONRPCError(
388+
jsonrpc="2.0",
389+
id=message.message.id,
390+
error=ErrorData(code=INTERNAL_ERROR, message="Handler raised", data=""),
391+
)
392+
await self._write_stream.send(SessionMessage(message=error_response))
393+
self._in_flight.pop(message.message.id, None)
394+
for in_flight_id, stream in list(self._response_streams.items()):
395+
self._propagate_errors[in_flight_id] = e
396+
await stream.aclose()
397+
return
398+
378399
# For request validation errors, send a proper JSON-RPC error
379400
# response instead of crashing the server
380401
logging.warning("Failed to validate request", exc_info=True)

tests/shared/test_session.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from mcp.shared.message import SessionMessage
1010
from mcp.shared.session import RequestResponder
1111
from mcp.types import (
12+
INTERNAL_ERROR,
1213
PARSE_ERROR,
1314
CancelledNotification,
1415
CancelledNotificationParams,
@@ -416,3 +417,67 @@ async def make_request(client_session: ClientSession):
416417
# Pending request completed successfully
417418
assert len(result_holder) == 1
418419
assert isinstance(result_holder[0], EmptyResult)
420+
421+
422+
@pytest.mark.anyio
423+
async def test_callback_exception_propagation():
424+
"""Verify that exceptions raised in callbacks with __mcp_propagate__ = True
425+
are propagated to the awaiter of send_request, and result in INTERNAL_ERROR to peer.
426+
"""
427+
class CustomPropagatedException(Exception):
428+
__mcp_propagate__ = True
429+
430+
ev_server_received_error = anyio.Event()
431+
server_error_holder: list[JSONRPCError] = []
432+
433+
async with create_client_server_memory_streams() as (client_streams, server_streams):
434+
client_read, client_write = client_streams
435+
server_read, server_write = server_streams
436+
437+
async def mock_server():
438+
# Wait for client's ping request
439+
msg = await server_read.receive()
440+
assert isinstance(msg, SessionMessage)
441+
assert isinstance(msg.message, JSONRPCRequest)
442+
443+
# Trigger list_roots callback on client by sending roots/list request
444+
roots_request = JSONRPCRequest(
445+
jsonrpc="2.0",
446+
id=1,
447+
method="roots/list",
448+
)
449+
await server_write.send(SessionMessage(message=roots_request))
450+
451+
# Receive the client's response (which should be an error due to propagated exception)
452+
response_msg = await server_read.receive()
453+
assert isinstance(response_msg, SessionMessage)
454+
assert isinstance(response_msg.message, JSONRPCError)
455+
server_error_holder.append(response_msg.message)
456+
ev_server_received_error.set()
457+
458+
async def mock_list_roots(ctx):
459+
raise CustomPropagatedException("Callback error that should propagate")
460+
461+
async def make_request(client_session: ClientSession):
462+
# Send a ping request and assert that CustomPropagatedException propagates to it
463+
with pytest.raises(CustomPropagatedException) as exc_info:
464+
await client_session.send_ping()
465+
assert "Callback error that should propagate" in str(exc_info.value)
466+
467+
async with (
468+
anyio.create_task_group() as tg,
469+
ClientSession(
470+
read_stream=client_read,
471+
write_stream=client_write,
472+
list_roots_callback=mock_list_roots,
473+
) as client_session,
474+
):
475+
tg.start_soon(mock_server)
476+
tg.start_soon(make_request, client_session)
477+
478+
with anyio.fail_after(2):
479+
await ev_server_received_error.wait()
480+
481+
assert len(server_error_holder) == 1
482+
assert server_error_holder[0].error.code == INTERNAL_ERROR
483+

0 commit comments

Comments
 (0)