Skip to content

Commit a358aa4

Browse files
committed
test: add wire-level invariant tests via a recording transport
A RecordingTransport wrapper tees every message crossing the client's transport boundary so the suite can assert properties that are invisible to API callers: request ids are unique and never null, notifications are never answered, and exactly one initialized notification is sent between the initialize response and the first feature request.
1 parent d6c9b63 commit a358aa4

5 files changed

Lines changed: 248 additions & 14 deletions

File tree

tests/interaction/_helpers.py

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,107 @@
1-
"""Shared type aliases for the interaction suite.
1+
"""Shared helpers for the interaction suite.
22
3-
Keep this module small: it exists only for types that every test would otherwise have to
4-
assemble from the SDK's internals to annotate a client callback. Server fixtures and assertion
5-
helpers belong in the test that uses them.
3+
Keep this module small: it exists only for (a) types that every test would otherwise have to
4+
assemble from the SDK's internals to annotate a client callback, and (b) the recording transport
5+
used by the wire-level tests. Server fixtures and assertion helpers belong in the test that uses
6+
them.
67
"""
78

9+
from types import TracebackType
10+
11+
import anyio
12+
from typing_extensions import Self
13+
14+
from mcp.client._transport import ReadStream, Transport, TransportStreams, WriteStream
15+
from mcp.shared.message import SessionMessage
816
from mcp.shared.session import RequestResponder
917
from mcp.types import ClientResult, ServerNotification, ServerRequest
1018

1119
# TODO: this union is the parameter type of every client message handler (MessageHandlerFnT),
1220
# but the SDK does not export a name for it -- writing a correctly-typed handler requires
1321
# importing RequestResponder from mcp.shared.session and assembling the union by hand. It
1422
# should be a named, exported alias next to MessageHandlerFnT (like ClientRequestContext is
15-
# for the request callbacks), at which point this module can be deleted.
23+
# for the request callbacks), at which point this alias can be deleted.
1624
IncomingMessage = RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception
1725
"""Everything a client message handler can receive."""
26+
27+
28+
class _RecordingReadStream:
29+
"""Delegates to a read stream, appending every received message to a log."""
30+
31+
def __init__(self, inner: ReadStream[SessionMessage | Exception], log: list[SessionMessage | Exception]) -> None:
32+
self._inner = inner
33+
self._log = log
34+
35+
async def receive(self) -> SessionMessage | Exception:
36+
item = await self._inner.receive()
37+
self._log.append(item)
38+
return item
39+
40+
async def aclose(self) -> None:
41+
await self._inner.aclose()
42+
43+
def __aiter__(self) -> Self:
44+
return self
45+
46+
async def __anext__(self) -> SessionMessage | Exception:
47+
try:
48+
return await self.receive()
49+
except anyio.EndOfStream:
50+
raise StopAsyncIteration from None
51+
52+
async def __aenter__(self) -> Self:
53+
return self
54+
55+
async def __aexit__(
56+
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
57+
) -> bool | None:
58+
await self.aclose()
59+
return None
60+
61+
62+
class _RecordingWriteStream:
63+
"""Delegates to a write stream, appending every sent message to a log."""
64+
65+
def __init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage]) -> None:
66+
self._inner = inner
67+
self._log = log
68+
69+
async def send(self, item: SessionMessage, /) -> None:
70+
self._log.append(item)
71+
await self._inner.send(item)
72+
73+
async def aclose(self) -> None:
74+
await self._inner.aclose()
75+
76+
async def __aenter__(self) -> Self:
77+
return self
78+
79+
async def __aexit__(
80+
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
81+
) -> bool | None:
82+
await self.aclose()
83+
return None
84+
85+
86+
class RecordingTransport:
87+
"""Wraps a Transport and records every message crossing the client's transport boundary.
88+
89+
`sent` holds everything the client wrote towards the server; `received` holds everything the
90+
client read from the server. The recording sits at the transport seam -- the exact payloads
91+
a real transport would serialise -- and never touches the session, so wire-level assertions
92+
written against it survive changes to the receive path.
93+
"""
94+
95+
def __init__(self, inner: Transport) -> None:
96+
self.inner = inner
97+
self.sent: list[SessionMessage] = []
98+
self.received: list[SessionMessage | Exception] = []
99+
100+
async def __aenter__(self) -> TransportStreams:
101+
read_stream, write_stream = await self.inner.__aenter__()
102+
return _RecordingReadStream(read_stream, self.received), _RecordingWriteStream(write_stream, self.sent)
103+
104+
async def __aexit__(
105+
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
106+
) -> bool | None:
107+
return await self.inner.__aexit__(exc_type, exc_val, exc_tb)

tests/interaction/_requirements.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,20 @@ class Requirement:
4949
# ═══════════════════════════════════════════════════════════════════════════
5050
# Protocol primitives
5151
# ═══════════════════════════════════════════════════════════════════════════
52+
"protocol:request-id:unique": Requirement(
53+
source=f"{SPEC_BASE_URL}/basic#requests",
54+
behavior=(
55+
"Every request sent on a session carries a unique, non-null integer id; ids are never reused "
56+
"within the session."
57+
),
58+
),
59+
"protocol:notifications:no-response": Requirement(
60+
source=f"{SPEC_BASE_URL}/basic#notifications",
61+
behavior=(
62+
"Notifications are never answered: every message the server delivers is either the response "
63+
"to a request the client sent or a notification carrying no id."
64+
),
65+
),
5266
"protocol:error:internal-error": Requirement(
5367
source=f"{SPEC_BASE_URL}/basic#responses",
5468
behavior="An unhandled exception in a request handler is returned to the caller as a JSON-RPC error.",
@@ -106,6 +120,13 @@ class Requirement:
106120
source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization",
107121
behavior="A request sent before the initialization handshake completes is rejected with an error.",
108122
),
123+
"lifecycle:initialized-notification": Requirement(
124+
source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization",
125+
behavior=(
126+
"The client sends exactly one initialized notification, after the initialize response and "
127+
"before its first feature request."
128+
),
129+
),
109130
# ═══════════════════════════════════════════════════════════════════════════
110131
# Cancellation
111132
# ═══════════════════════════════════════════════════════════════════════════

tests/interaction/lowlevel/test_timeouts.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara
8686

8787
@requirement("timeouts:session-default")
8888
async def test_session_level_timeout_applies_to_every_request() -> None:
89-
"""A read timeout configured on the client applies to requests that do not set their own.
90-
91-
The session default also governs the initialize handshake, so this is the one test in the
92-
suite that needs a real (50ms) timeout: it must be long enough for the in-process handshake
93-
to complete and is then waited out in full by the blocked tool call.
94-
"""
89+
"""A read timeout configured on the client applies to requests that do not set their own."""
9590

9691
async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
9792
assert params.name == "block"
@@ -100,6 +95,12 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara
10095

10196
server = Server("blocker", on_call_tool=call_tool)
10297

98+
# The one real wall-clock wait in the suite, and it cannot be made effectively zero like the
99+
# per-request timeouts: a session-level timeout also governs the initialize handshake, so the
100+
# value must be long enough for the in-process handshake to complete before the blocked tool
101+
# call waits it out in full. 50ms buys a ~50x safety margin over the handshake's actual
102+
# latency; lowering it only erodes the margin against CI scheduler jitter without saving
103+
# anything perceptible.
103104
async with Client(server, read_timeout_seconds=0.05) as client:
104105
with pytest.raises(MCPError) as exc_info:
105106
await client.call_tool("block", {})
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""Wire-level invariants observed at the client's transport boundary.
2+
3+
These behaviours are invisible to API callers -- they are properties of the raw JSON-RPC frames.
4+
The tests wrap the in-memory transport in a RecordingTransport, which tees every message crossing
5+
the transport seam into a list without touching the session, so the assertions hold for whatever
6+
the session implementation sends rather than for what its API returns.
7+
"""
8+
9+
import anyio
10+
import pytest
11+
from inline_snapshot import snapshot
12+
13+
from mcp import types
14+
from mcp.client._memory import InMemoryTransport
15+
from mcp.client.client import Client
16+
from mcp.server import Server, ServerRequestContext
17+
from mcp.shared.message import SessionMessage
18+
from mcp.types import CallToolResult, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, TextContent
19+
from tests.interaction._helpers import RecordingTransport, _RecordingReadStream
20+
from tests.interaction._requirements import requirement
21+
22+
pytestmark = pytest.mark.anyio
23+
24+
25+
def _echo_server() -> Server:
26+
"""A server with one echo tool, used by every test in this module."""
27+
28+
async def list_tools(
29+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
30+
) -> types.ListToolsResult:
31+
return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})])
32+
33+
async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
34+
assert params.name == "echo"
35+
return CallToolResult(content=[TextContent(text="ok")])
36+
37+
return Server("wire", on_list_tools=list_tools, on_call_tool=call_tool)
38+
39+
40+
@requirement("protocol:request-id:unique")
41+
async def test_request_ids_are_unique_and_never_null() -> None:
42+
"""Every request the client sends carries a distinct, non-null id.
43+
44+
The id sequence is pinned: sequential integers from zero, in send order. No schema-cache
45+
refresh appears in it, because the explicit tools/list call has already populated the cache.
46+
"""
47+
recording = RecordingTransport(InMemoryTransport(_echo_server()))
48+
49+
async with Client(recording) as client:
50+
await client.list_tools()
51+
await client.call_tool("echo", {})
52+
await client.call_tool("echo", {})
53+
await client.send_ping()
54+
55+
sent = [message.message for message in recording.sent]
56+
request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)]
57+
assert all(request_id is not None for request_id in request_ids)
58+
assert len(request_ids) == len(set(request_ids))
59+
# initialize, tools/list, tools/call, tools/call, ping -- the client does not issue a
60+
# schema-cache refresh here because the explicit tools/list already populated the cache.
61+
assert request_ids == snapshot([0, 1, 2, 3, 4])
62+
63+
64+
@requirement("protocol:notifications:no-response")
65+
async def test_notifications_are_never_answered() -> None:
66+
"""A notification produces no response: everything the server sends back answers a request.
67+
68+
The client sends two notifications (initialized and roots/list_changed) and two requests (initialize and ping);
69+
the messages received from the server must be exactly one response per request, each carrying
70+
the id of the request it answers, and nothing else.
71+
"""
72+
recording = RecordingTransport(InMemoryTransport(_echo_server()))
73+
74+
async with Client(recording) as client:
75+
await client.send_roots_list_changed()
76+
await client.send_ping()
77+
78+
sent = [message.message for message in recording.sent]
79+
sent_request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)]
80+
sent_notifications = [message for message in sent if isinstance(message, JSONRPCNotification)]
81+
received = [message.message for message in recording.received if isinstance(message, SessionMessage)]
82+
received_responses = [message for message in received if isinstance(message, JSONRPCResponse)]
83+
84+
assert len(sent_notifications) == 2 # notifications/initialized and notifications/roots/list_changed
85+
assert len(received_responses) == len(received) # nothing the server sent was anything but a response
86+
assert [message.id for message in received_responses] == sent_request_ids
87+
88+
89+
async def test_recording_read_stream_ends_iteration_when_the_sender_closes() -> None:
90+
"""The recording wrapper preserves the end-of-stream behaviour of the stream it wraps.
91+
92+
This exercises the helper itself rather than an interaction-model behaviour: a transport whose
93+
far end closes must end the client's receive loop cleanly, and the wrapper must not swallow or
94+
mistranslate that.
95+
"""
96+
send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1)
97+
log: list[SessionMessage | Exception] = []
98+
async with send_stream, _RecordingReadStream(receive_stream, log) as wrapped:
99+
await send_stream.aclose()
100+
items = [item async for item in wrapped]
101+
assert items == []
102+
assert log == []
103+
104+
105+
@requirement("lifecycle:initialized-notification")
106+
async def test_exactly_one_initialized_notification_is_sent_after_the_handshake() -> None:
107+
"""The client sends initialized exactly once, between the initialize response and its first request.
108+
109+
The full method sequence the client puts on the wire is pinned in send order.
110+
"""
111+
recording = RecordingTransport(InMemoryTransport(_echo_server()))
112+
113+
async with Client(recording) as client:
114+
await client.list_tools()
115+
116+
sent_methods = [
117+
message.message.method
118+
for message in recording.sent
119+
if isinstance(message.message, JSONRPCRequest | JSONRPCNotification)
120+
]
121+
assert sent_methods.count("notifications/initialized") == 1
122+
assert sent_methods == snapshot(["initialize", "notifications/initialized", "tools/list"])

tests/interaction/mcpserver/test_tools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,6 @@ def primes() -> list[int]:
173173
async def test_call_tool_invalid_arguments_become_error_result() -> None:
174174
"""Arguments that fail validation against the tool's signature are reported as an is_error
175175
result describing the failure, not as a protocol error.
176-
177-
The description is raw pydantic output (version-dependent and leaking the internal argument
178-
model name), so only the stable prefix is asserted rather than the full text.
179176
"""
180177
mcp = MCPServer("adder")
181178

@@ -187,6 +184,9 @@ def add(a: int, b: int) -> str:
187184
async with Client(mcp) as client:
188185
result = await client.call_tool("add", {"b": 3})
189186

187+
# The description is raw pydantic output -- it embeds a pydantic-version-specific
188+
# errors.pydantic.dev URL and the internal `addArguments` model name -- so only the stable
189+
# prefix is asserted; a full snapshot would break on every pydantic upgrade.
190190
assert result.is_error is True
191191
assert isinstance(result.content[0], TextContent)
192192
assert result.content[0].text.startswith("Error executing tool add: 1 validation error")

0 commit comments

Comments
 (0)