|
1 | 1 | """Transport-parametrized connection factories for the interaction suite. |
2 | 2 |
|
3 | 3 | The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body |
4 | | -runs over the in-memory transport and over streamable HTTP without naming either: the factory is a |
5 | | -drop-in replacement for constructing `Client(server, ...)` and yields the connected client. The |
6 | | -streamable HTTP factory drives the server's real Starlette app through the in-process streaming |
7 | | -bridge, so the full HTTP framing layer (session ids, SSE encoding, session management) runs with |
8 | | -no sockets, threads, or subprocesses. |
| 4 | +runs over each transport without naming any of them: the factory is a drop-in replacement for |
| 5 | +constructing `Client(server, ...)` and yields the connected client. The HTTP factories drive the |
| 6 | +server's real Starlette app through the in-process streaming bridge, so the full transport layer |
| 7 | +(session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses. |
9 | 8 | """ |
10 | 9 |
|
| 10 | +import gc |
| 11 | +import warnings |
11 | 12 | from collections.abc import AsyncIterator |
12 | 13 | from contextlib import AbstractAsyncContextManager, asynccontextmanager |
13 | 14 | from typing import Protocol |
14 | 15 |
|
15 | 16 | import httpx |
| 17 | +from starlette.applications import Starlette |
| 18 | +from starlette.requests import Request |
| 19 | +from starlette.responses import Response |
| 20 | +from starlette.routing import Mount, Route |
16 | 21 |
|
17 | 22 | from mcp.client.client import Client |
18 | 23 | from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT |
| 24 | +from mcp.client.sse import sse_client |
19 | 25 | from mcp.client.streamable_http import streamable_http_client |
20 | 26 | from mcp.server import Server |
21 | 27 | from mcp.server.mcpserver import MCPServer |
| 28 | +from mcp.server.sse import SseServerTransport |
22 | 29 | from mcp.server.transport_security import TransportSecuritySettings |
23 | 30 | from mcp.types import Implementation |
24 | 31 | from tests.interaction.transports._bridge import StreamingASGITransport |
@@ -115,3 +122,84 @@ async def connect_over_streamable_http( |
115 | 122 | elicitation_callback=elicitation_callback, |
116 | 123 | ) as client: |
117 | 124 | yield client |
| 125 | + |
| 126 | + |
| 127 | +def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]: |
| 128 | + """Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/. |
| 129 | +
|
| 130 | + `MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which |
| 131 | + the SSE-specific tests need; building the app explicitly here gives both server flavours the |
| 132 | + same routing while keeping that handle. |
| 133 | + """ |
| 134 | + sse = SseServerTransport( |
| 135 | + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) |
| 136 | + ) |
| 137 | + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server |
| 138 | + |
| 139 | + async def handle_sse(request: Request) -> Response: |
| 140 | + async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): |
| 141 | + await lowlevel.run(read, write, lowlevel.create_initialization_options()) |
| 142 | + return Response() |
| 143 | + |
| 144 | + app = Starlette( |
| 145 | + routes=[ |
| 146 | + Route("/sse", endpoint=handle_sse, methods=["GET"]), |
| 147 | + Mount("/messages/", app=sse.handle_post_message), |
| 148 | + ], |
| 149 | + ) |
| 150 | + return app, sse |
| 151 | + |
| 152 | + |
| 153 | +@asynccontextmanager |
| 154 | +async def connect_over_sse( |
| 155 | + server: Server | MCPServer, |
| 156 | + *, |
| 157 | + read_timeout_seconds: float | None = None, |
| 158 | + sampling_callback: SamplingFnT | None = None, |
| 159 | + list_roots_callback: ListRootsFnT | None = None, |
| 160 | + logging_callback: LoggingFnT | None = None, |
| 161 | + message_handler: MessageHandlerFnT | None = None, |
| 162 | + client_info: Implementation | None = None, |
| 163 | + elicitation_callback: ElicitationFnT | None = None, |
| 164 | +) -> AsyncIterator[Client]: |
| 165 | + """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" |
| 166 | + app, _ = build_sse_app(server) |
| 167 | + |
| 168 | + def httpx_client_factory( |
| 169 | + headers: dict[str, str] | None = None, |
| 170 | + timeout: httpx.Timeout | None = None, |
| 171 | + auth: httpx.Auth | None = None, |
| 172 | + ) -> httpx.AsyncClient: |
| 173 | + # The SSE server transport's connect_sse runs the entire MCP session inside the GET |
| 174 | + # request and only releases its streams after that request observes a disconnect, so the |
| 175 | + # bridge must let the application drain rather than cancelling at close. |
| 176 | + return httpx.AsyncClient( |
| 177 | + transport=StreamingASGITransport(app, cancel_on_close=False), |
| 178 | + base_url=_BASE_URL, |
| 179 | + headers=headers, |
| 180 | + timeout=timeout, |
| 181 | + auth=auth, |
| 182 | + ) |
| 183 | + |
| 184 | + transport = sse_client(f"{_BASE_URL}/sse", httpx_client_factory=httpx_client_factory) |
| 185 | + try: |
| 186 | + async with Client( |
| 187 | + transport, |
| 188 | + read_timeout_seconds=read_timeout_seconds, |
| 189 | + sampling_callback=sampling_callback, |
| 190 | + list_roots_callback=list_roots_callback, |
| 191 | + logging_callback=logging_callback, |
| 192 | + message_handler=message_handler, |
| 193 | + client_info=client_info, |
| 194 | + elicitation_callback=elicitation_callback, |
| 195 | + ) as client: |
| 196 | + yield client |
| 197 | + finally: |
| 198 | + # SseServerTransport.connect_sse hands its internal SSE-chunk receive stream to |
| 199 | + # sse_starlette's EventSourceResponse, which never closes it when its task group is |
| 200 | + # cancelled on disconnect (see notes/findings.md). Collect the orphan here so its |
| 201 | + # ResourceWarning fires deterministically inside this fixture instead of at an |
| 202 | + # arbitrary later GC. |
| 203 | + with warnings.catch_warnings(): |
| 204 | + warnings.simplefilter("ignore", ResourceWarning) |
| 205 | + gc.collect() |
0 commit comments