Skip to content

Commit d64f525

Browse files
committed
test: run the interaction suite over both in-memory and streamable HTTP transports
1 parent c1eab9d commit d64f525

24 files changed

Lines changed: 395 additions & 351 deletions

src/mcp/server/streamable_http.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
374374
await error_response(scope, receive, send)
375375
return
376376

377-
if self._terminated: # pragma: no cover
377+
if self._terminated: # pragma: lax no cover
378378
# If the session has been terminated, return 404 Not Found
379379
response = self._create_error_response(
380380
"Not Found: Session has been terminated",
@@ -635,7 +635,7 @@ async def sse_writer(): # pragma: lax no cover
635635
finally:
636636
await sse_stream_reader.aclose()
637637

638-
except Exception as err: # pragma: no cover
638+
except Exception as err: # pragma: lax no cover
639639
logger.exception("Error handling POST request")
640640
response = self._create_error_response(
641641
f"Error handling POST request: {err}",

tests/interaction/README.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,31 @@ The whole suite is in-memory and event-driven; it runs in about a second.
3737
tests/interaction/
3838
_requirements.py the requirements manifest (see below)
3939
_helpers.py shared type aliases + the wire-recording transport
40+
_connect.py the transport-parametrized connection factories
41+
conftest.py the connect fixture (the transport matrix)
4042
test_coverage.py enforces the manifest ↔ test contract
4143
lowlevel/ one file per feature area, against the low-level Server
4244
mcpserver/ the same feature areas in MCPServer's natural idiom
43-
transports/ a smoke subset over the streamable HTTP framing
45+
transports/ behaviour specific to one transport (modes, streams, framing)
4446
```
4547

4648
The two server APIs produce genuinely different wire output for the same conceptual feature
4749
(`MCPServer` generates schemas, converts exceptions to `isError` results, attaches structured
4850
content), so they get parallel directories with mirrored file names rather than one parametrized
4951
test body — each directory pins its flavour's true output exactly.
5052

53+
### The transport matrix
54+
55+
Transport-agnostic tests take the `connect` fixture instead of constructing `Client(server)`
56+
directly, and therefore run once per transport: over the in-memory transport and over the
57+
server's real streamable HTTP app driven in process through the streaming bridge. A test connects
58+
the same way in either case — `async with connect(server, ...) as client:` — and asserts the same
59+
output, because the transport is not supposed to change observable behaviour. Tests that are tied
60+
to one transport do not use the fixture: the wire-recording tests (their seam is the in-memory
61+
stream pair), the bare-`ClientSession` lifecycle tests, the real-clock timeout tests (the timeout
62+
machinery is transport-independent and must not race transport latency), and everything under
63+
`transports/`, which pins behaviour only observable on that transport.
64+
5165
## The requirements manifest
5266

5367
`_requirements.py` maps every behaviour the suite covers to the reason it must hold:

tests/interaction/_connect.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Transport-parametrized connection factories for the interaction suite.
2+
3+
The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body
4+
runs over the in-memory transport and over streamable HTTP without naming either: the factory is a
5+
drop-in replacement for constructing `Client(server, ...)` and yields the connected client. The
6+
streamable HTTP factory drives the server's real Starlette app through the in-process streaming
7+
bridge, so the full HTTP framing layer (session ids, SSE encoding, session management) runs with
8+
no sockets, threads, or subprocesses.
9+
"""
10+
11+
from collections.abc import AsyncIterator
12+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
13+
from typing import Protocol
14+
15+
import httpx
16+
17+
from mcp.client.client import Client
18+
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
19+
from mcp.client.streamable_http import streamable_http_client
20+
from mcp.server import Server
21+
from mcp.server.mcpserver import MCPServer
22+
from mcp.server.transport_security import TransportSecuritySettings
23+
from mcp.types import Implementation
24+
from tests.interaction.transports._bridge import StreamingASGITransport
25+
26+
# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here.
27+
_BASE_URL = "http://127.0.0.1:8000"
28+
29+
30+
class Connect(Protocol):
31+
"""Connect a Client to a server over the transport selected by the `connect` fixture.
32+
33+
Accepts the same keyword arguments as `Client` and yields the connected client.
34+
"""
35+
36+
def __call__(
37+
self,
38+
server: Server | MCPServer,
39+
*,
40+
read_timeout_seconds: float | None = None,
41+
sampling_callback: SamplingFnT | None = None,
42+
list_roots_callback: ListRootsFnT | None = None,
43+
logging_callback: LoggingFnT | None = None,
44+
message_handler: MessageHandlerFnT | None = None,
45+
client_info: Implementation | None = None,
46+
elicitation_callback: ElicitationFnT | None = None,
47+
) -> AbstractAsyncContextManager[Client]: ...
48+
49+
50+
@asynccontextmanager
51+
async def connect_in_memory(
52+
server: Server | MCPServer,
53+
*,
54+
read_timeout_seconds: float | None = None,
55+
sampling_callback: SamplingFnT | None = None,
56+
list_roots_callback: ListRootsFnT | None = None,
57+
logging_callback: LoggingFnT | None = None,
58+
message_handler: MessageHandlerFnT | None = None,
59+
client_info: Implementation | None = None,
60+
elicitation_callback: ElicitationFnT | None = None,
61+
) -> AsyncIterator[Client]:
62+
"""Yield a Client connected to the server over the in-memory transport."""
63+
async with Client(
64+
server,
65+
read_timeout_seconds=read_timeout_seconds,
66+
sampling_callback=sampling_callback,
67+
list_roots_callback=list_roots_callback,
68+
logging_callback=logging_callback,
69+
message_handler=message_handler,
70+
client_info=client_info,
71+
elicitation_callback=elicitation_callback,
72+
) as client:
73+
yield client
74+
75+
76+
@asynccontextmanager
77+
async def connect_over_streamable_http(
78+
server: Server | MCPServer,
79+
*,
80+
stateless_http: bool = False,
81+
json_response: bool = False,
82+
read_timeout_seconds: float | None = None,
83+
sampling_callback: SamplingFnT | None = None,
84+
list_roots_callback: ListRootsFnT | None = None,
85+
logging_callback: LoggingFnT | None = None,
86+
message_handler: MessageHandlerFnT | None = None,
87+
client_info: Implementation | None = None,
88+
elicitation_callback: ElicitationFnT | None = None,
89+
) -> AsyncIterator[Client]:
90+
"""Yield a Client connected to the server's streamable HTTP app, entirely in process.
91+
92+
With the defaults this is the matrix leg (stateful sessions, SSE responses); the
93+
transport-specific tests pass `stateless_http` or `json_response` to select the other
94+
server modes.
95+
"""
96+
# DNS-rebinding protection validates Host/Origin headers against a real network attack that
97+
# cannot exist for an in-process ASGI app; leaving it on would also pull the origin-validation
98+
# branch (deliberately uncovered in src) into coverage.
99+
app = server.streamable_http_app(
100+
stateless_http=stateless_http,
101+
json_response=json_response,
102+
transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False),
103+
)
104+
async with server.session_manager.run():
105+
async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http_client:
106+
transport = streamable_http_client(f"{_BASE_URL}/mcp", http_client=http_client)
107+
async with Client(
108+
transport,
109+
read_timeout_seconds=read_timeout_seconds,
110+
sampling_callback=sampling_callback,
111+
list_roots_callback=list_roots_callback,
112+
logging_callback=logging_callback,
113+
message_handler=message_handler,
114+
client_info=client_info,
115+
elicitation_callback=elicitation_callback,
116+
) as client:
117+
yield client

tests/interaction/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""Shared fixtures for the interaction suite."""
2+
3+
import pytest
4+
5+
from tests.interaction._connect import Connect, connect_in_memory, connect_over_streamable_http
6+
7+
_FACTORIES: dict[str, Connect] = {
8+
"in-memory": connect_in_memory,
9+
"streamable-http": connect_over_streamable_http,
10+
}
11+
12+
13+
@pytest.fixture(params=sorted(_FACTORIES))
14+
def connect(request: pytest.FixtureRequest) -> Connect:
15+
"""The transport-parametrized connection factory: a test using it runs once per transport.
16+
17+
Tests that are tied to one transport (the wire-recording tests, the bare-ClientSession tests,
18+
the transport-specific tests under transports/) do not use this fixture and connect directly.
19+
"""
20+
transport_name = request.param
21+
assert isinstance(transport_name, str)
22+
return _FACTORIES[transport_name]

tests/interaction/lowlevel/test_cancellation.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@
1111
from inline_snapshot import snapshot
1212

1313
from mcp import MCPError, types
14-
from mcp.client.client import Client
1514
from mcp.server import Server, ServerRequestContext
1615
from mcp.types import CallToolResult, ErrorData, TextContent
16+
from tests.interaction._connect import Connect
1717
from tests.interaction._requirements import requirement
1818

1919
pytestmark = pytest.mark.anyio
2020

2121

2222
@requirement("protocol:cancel:in-flight")
2323
@requirement("protocol:cancel:handler-abort-propagates")
24-
async def test_cancellation_stops_in_flight_handler() -> None:
24+
async def test_cancellation_stops_in_flight_handler(connect: Connect) -> None:
2525
"""Cancelling an in-flight request interrupts its handler and fails the pending call.
2626
2727
The server answers the cancelled request with an error response (the spec says it should
@@ -47,7 +47,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara
4747

4848
server = Server("blocker", on_call_tool=call_tool)
4949

50-
async with Client(server) as client:
50+
async with connect(server) as client:
5151
with anyio.fail_after(5):
5252
async with anyio.create_task_group() as task_group:
5353

@@ -70,7 +70,7 @@ async def call_and_capture_error() -> None:
7070

7171

7272
@requirement("protocol:cancel:server-survives")
73-
async def test_session_serves_requests_after_cancellation() -> None:
73+
async def test_session_serves_requests_after_cancellation(connect: Connect) -> None:
7474
"""A request cancelled mid-flight does not poison the session: the next request succeeds."""
7575
started = anyio.Event()
7676
request_ids: list[types.RequestId] = []
@@ -96,7 +96,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara
9696

9797
server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool)
9898

99-
async with Client(server) as client:
99+
async with connect(server) as client:
100100
with anyio.fail_after(5):
101101
async with anyio.create_task_group() as task_group:
102102

@@ -116,7 +116,7 @@ async def call_and_swallow_cancellation_error() -> None:
116116

117117

118118
@requirement("protocol:cancel:unknown-id-ignored")
119-
async def test_cancellation_for_unknown_request_is_ignored() -> None:
119+
async def test_cancellation_for_unknown_request_is_ignored(connect: Connect) -> None:
120120
"""A cancellation referencing a request id that is not in flight is ignored without error."""
121121

122122
async def list_tools(
@@ -130,7 +130,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara
130130

131131
server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool)
132132

133-
async with Client(server) as client:
133+
async with connect(server) as client:
134134
await client.session.send_notification(
135135
types.CancelledNotification(params=types.CancelledNotificationParams(request_id=9999))
136136
)

tests/interaction/lowlevel/test_completion.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from inline_snapshot import snapshot
55

66
from mcp import MCPError, types
7-
from mcp.client.client import Client
87
from mcp.server import Server, ServerRequestContext
98
from mcp.types import (
109
METHOD_NOT_FOUND,
@@ -14,14 +13,15 @@
1413
PromptReference,
1514
ResourceTemplateReference,
1615
)
16+
from tests.interaction._connect import Connect
1717
from tests.interaction._requirements import requirement
1818

1919
pytestmark = pytest.mark.anyio
2020

2121

2222
@requirement("completion:prompt-arg")
2323
@requirement("completion:result-shape")
24-
async def test_complete_prompt_argument() -> None:
24+
async def test_complete_prompt_argument(connect: Connect) -> None:
2525
"""Completing a prompt argument delivers the ref, argument name, and current value to the handler.
2626
2727
The returned values are filtered by the argument's value, proving the value reached the handler.
@@ -37,7 +37,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar
3737

3838
server = Server("completer", on_completion=completion)
3939

40-
async with Client(server) as client:
40+
async with connect(server) as client:
4141
result = await client.complete(
4242
PromptReference(name="code_review"), argument={"name": "language", "value": "py"}
4343
)
@@ -48,7 +48,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar
4848

4949

5050
@requirement("completion:resource-template-arg")
51-
async def test_complete_resource_template_variable() -> None:
51+
async def test_complete_resource_template_variable(connect: Connect) -> None:
5252
"""Completing a URI template variable delivers the template URI and variable name to the handler."""
5353

5454
async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult:
@@ -59,7 +59,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar
5959

6060
server = Server("completer", on_completion=completion)
6161

62-
async with Client(server) as client:
62+
async with connect(server) as client:
6363
result = await client.complete(
6464
ResourceTemplateReference(uri="github://repos/{owner}/{repo}"),
6565
argument={"name": "owner", "value": "model"},
@@ -69,7 +69,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar
6969

7070

7171
@requirement("completion:context-arguments")
72-
async def test_complete_receives_context_arguments() -> None:
72+
async def test_complete_receives_context_arguments(connect: Connect) -> None:
7373
"""Previously-resolved arguments passed as completion context reach the handler.
7474
7575
The returned value is derived from the context, proving it arrived.
@@ -83,7 +83,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar
8383

8484
server = Server("completer", on_completion=completion)
8585

86-
async with Client(server) as client:
86+
async with connect(server) as client:
8787
result = await client.complete(
8888
ResourceTemplateReference(uri="github://repos/{owner}/{repo}"),
8989
argument={"name": "repo", "value": ""},
@@ -95,11 +95,11 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar
9595

9696
@requirement("completion:complete:not-supported")
9797
@requirement("protocol:error:method-not-found")
98-
async def test_complete_without_handler_is_method_not_found() -> None:
98+
async def test_complete_without_handler_is_method_not_found(connect: Connect) -> None:
9999
"""A server with no completion handler advertises no completions capability and rejects the request."""
100100
server = Server("incomplete")
101101

102-
async with Client(server) as client:
102+
async with connect(server) as client:
103103
assert client.initialize_result.capabilities.completions is None
104104

105105
with pytest.raises(MCPError) as exc_info:

0 commit comments

Comments
 (0)