|
14 | 14 | try: |
15 | 15 | from builtins import ExceptionGroup |
16 | 16 | except ImportError: |
17 | | - from exceptiongroup import ExceptionGroup |
| 17 | + from exceptiongroup import ExceptionGroup # type: ignore |
18 | 18 |
|
19 | 19 | import anyio |
20 | 20 | import httpx |
@@ -1615,62 +1615,64 @@ async def test_client_unexpected_content_type_raises_mcp_error(): |
1615 | 1615 | async def html_endpoint(request: Request): |
1616 | 1616 | return HTMLResponse("<html><body>Not an MCP server</body></html>") |
1617 | 1617 |
|
1618 | | - app = Starlette(routes=[ |
1619 | | - Route("/mcp", html_endpoint, methods=["GET", "POST"]), |
1620 | | - ]) |
1621 | | - |
| 1618 | + app = Starlette( |
| 1619 | + routes=[ |
| 1620 | + Route("/mcp", html_endpoint, methods=["GET", "POST"]), |
| 1621 | + ] |
| 1622 | + ) |
| 1623 | + |
1622 | 1624 | # Start server on a random port using a simpler approach |
1623 | 1625 | with socket.socket() as s: |
1624 | 1626 | s.bind(("127.0.0.1", 0)) |
1625 | 1627 | port = s.getsockname()[1] |
1626 | | - |
| 1628 | + |
1627 | 1629 | # Use a thread instead of multiprocessing to avoid pickle issues |
1628 | 1630 | import asyncio |
1629 | 1631 | import threading |
1630 | | - |
| 1632 | + |
1631 | 1633 | def run_server(): |
1632 | 1634 | loop = asyncio.new_event_loop() |
1633 | 1635 | asyncio.set_event_loop(loop) |
1634 | 1636 | uvicorn.run(app, host="127.0.0.1", port=port, log_level="error") |
1635 | | - |
| 1637 | + |
1636 | 1638 | server_thread = threading.Thread(target=run_server, daemon=True) |
1637 | 1639 | server_thread.start() |
1638 | | - |
| 1640 | + |
1639 | 1641 | try: |
1640 | 1642 | # Give server time to start |
1641 | 1643 | await asyncio.sleep(0.5) |
1642 | | - |
| 1644 | + |
1643 | 1645 | server_url = f"http://127.0.0.1:{port}" |
1644 | | - |
| 1646 | + |
1645 | 1647 | # Test that the client raises McpError when server returns HTML |
1646 | | - with pytest.raises(ExceptionGroup) as exc_info: |
| 1648 | + with pytest.raises(ExceptionGroup) as exc_info: # type: ignore |
1647 | 1649 | async with streamablehttp_client(f"{server_url}/mcp") as ( |
1648 | 1650 | read_stream, |
1649 | 1651 | write_stream, |
1650 | 1652 | _, |
1651 | 1653 | ): |
1652 | 1654 | async with ClientSession(read_stream, write_stream) as session: |
1653 | 1655 | await session.initialize() |
1654 | | - |
| 1656 | + |
1655 | 1657 | # Extract the McpError from the ExceptionGroup (handle nested groups) |
1656 | 1658 | mcp_error = None |
1657 | | - |
1658 | | - def find_mcp_error(exc_group): |
1659 | | - for exc in exc_group.exceptions: |
| 1659 | + |
| 1660 | + def find_mcp_error(exc_group: ExceptionGroup) -> McpError | None: # type: ignore |
| 1661 | + for exc in exc_group.exceptions: # type: ignore |
1660 | 1662 | if isinstance(exc, McpError): |
1661 | 1663 | return exc |
1662 | | - elif isinstance(exc, ExceptionGroup): |
| 1664 | + elif isinstance(exc, ExceptionGroup): # type: ignore |
1663 | 1665 | result = find_mcp_error(exc) |
1664 | 1666 | if result: |
1665 | 1667 | return result |
1666 | 1668 | return None |
1667 | | - |
| 1669 | + |
1668 | 1670 | mcp_error = find_mcp_error(exc_info.value) |
1669 | | - |
| 1671 | + |
1670 | 1672 | assert mcp_error is not None, "Expected McpError in ExceptionGroup hierarchy" |
1671 | 1673 | assert "Unexpected content type" in str(mcp_error) |
1672 | 1674 | assert "text/html" in str(mcp_error) |
1673 | | - |
| 1675 | + |
1674 | 1676 | finally: |
1675 | 1677 | # Server thread will be cleaned up automatically as daemon |
1676 | 1678 | pass |
0 commit comments