|
9 | 9 | from mcp.shared.message import SessionMessage |
10 | 10 | from mcp.shared.session import RequestResponder |
11 | 11 | from mcp.types import ( |
| 12 | + INTERNAL_ERROR, |
12 | 13 | PARSE_ERROR, |
13 | 14 | CancelledNotification, |
14 | 15 | CancelledNotificationParams, |
@@ -416,3 +417,67 @@ async def make_request(client_session: ClientSession): |
416 | 417 | # Pending request completed successfully |
417 | 418 | assert len(result_holder) == 1 |
418 | 419 | assert isinstance(result_holder[0], EmptyResult) |
| 420 | + |
| 421 | + |
| 422 | +@pytest.mark.anyio |
| 423 | +async def test_callback_exception_propagation(): |
| 424 | + """Verify that exceptions raised in callbacks with __mcp_propagate__ = True |
| 425 | + are propagated to the awaiter of send_request, and result in INTERNAL_ERROR to peer. |
| 426 | + """ |
| 427 | + class CustomPropagatedException(Exception): |
| 428 | + __mcp_propagate__ = True |
| 429 | + |
| 430 | + ev_server_received_error = anyio.Event() |
| 431 | + server_error_holder: list[JSONRPCError] = [] |
| 432 | + |
| 433 | + async with create_client_server_memory_streams() as (client_streams, server_streams): |
| 434 | + client_read, client_write = client_streams |
| 435 | + server_read, server_write = server_streams |
| 436 | + |
| 437 | + async def mock_server(): |
| 438 | + # Wait for client's ping request |
| 439 | + msg = await server_read.receive() |
| 440 | + assert isinstance(msg, SessionMessage) |
| 441 | + assert isinstance(msg.message, JSONRPCRequest) |
| 442 | + |
| 443 | + # Trigger list_roots callback on client by sending roots/list request |
| 444 | + roots_request = JSONRPCRequest( |
| 445 | + jsonrpc="2.0", |
| 446 | + id=1, |
| 447 | + method="roots/list", |
| 448 | + ) |
| 449 | + await server_write.send(SessionMessage(message=roots_request)) |
| 450 | + |
| 451 | + # Receive the client's response (which should be an error due to propagated exception) |
| 452 | + response_msg = await server_read.receive() |
| 453 | + assert isinstance(response_msg, SessionMessage) |
| 454 | + assert isinstance(response_msg.message, JSONRPCError) |
| 455 | + server_error_holder.append(response_msg.message) |
| 456 | + ev_server_received_error.set() |
| 457 | + |
| 458 | + async def mock_list_roots(ctx): |
| 459 | + raise CustomPropagatedException("Callback error that should propagate") |
| 460 | + |
| 461 | + async def make_request(client_session: ClientSession): |
| 462 | + # Send a ping request and assert that CustomPropagatedException propagates to it |
| 463 | + with pytest.raises(CustomPropagatedException) as exc_info: |
| 464 | + await client_session.send_ping() |
| 465 | + assert "Callback error that should propagate" in str(exc_info.value) |
| 466 | + |
| 467 | + async with ( |
| 468 | + anyio.create_task_group() as tg, |
| 469 | + ClientSession( |
| 470 | + read_stream=client_read, |
| 471 | + write_stream=client_write, |
| 472 | + list_roots_callback=mock_list_roots, |
| 473 | + ) as client_session, |
| 474 | + ): |
| 475 | + tg.start_soon(mock_server) |
| 476 | + tg.start_soon(make_request, client_session) |
| 477 | + |
| 478 | + with anyio.fail_after(2): |
| 479 | + await ev_server_received_error.wait() |
| 480 | + |
| 481 | + assert len(server_error_holder) == 1 |
| 482 | + assert server_error_holder[0].error.code == INTERNAL_ERROR |
| 483 | + |
0 commit comments