Skip to content

Commit b485b6a

Browse files
committed
fix(stdio): drain responses after stdin EOF
1 parent e8e6484 commit b485b6a

5 files changed

Lines changed: 140 additions & 46 deletions

File tree

src/mcp/server/lowlevel/server.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -390,29 +390,22 @@ async def run(
390390
await stack.enter_async_context(task_support.run())
391391

392392
async with anyio.create_task_group() as tg:
393-
try:
394-
async for message in session.incoming_messages:
395-
logger.debug("Received message: %s", message)
396-
397-
if isinstance(message, RequestResponder) and message.context is not None:
398-
context = message.context
399-
else:
400-
context = contextvars.copy_context()
401-
402-
context.run(
403-
tg.start_soon,
404-
self._handle_message,
405-
message,
406-
session,
407-
lifespan_context,
408-
raise_exceptions,
409-
)
410-
finally:
411-
# Transport closed: cancel in-flight handlers. Without this the
412-
# TG join waits for them, and when they eventually try to
413-
# respond they hit a closed write stream (the session's
414-
# _receive_loop closed it when the read stream ended).
415-
tg.cancel_scope.cancel()
393+
async for message in session.incoming_messages:
394+
logger.debug("Received message: %s", message)
395+
396+
if isinstance(message, RequestResponder) and message.context is not None:
397+
context = message.context
398+
else:
399+
context = contextvars.copy_context()
400+
401+
context.run(
402+
tg.start_soon,
403+
self._handle_message,
404+
message,
405+
session,
406+
lifespan_context,
407+
raise_exceptions,
408+
)
416409

417410
async def _handle_message(
418411
self,

src/mcp/server/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
init_options: InitializationOptions,
8686
stateless: bool = False,
8787
) -> None:
88-
super().__init__(read_stream, write_stream)
88+
super().__init__(read_stream, write_stream, close_write_stream_on_read_close=False)
8989
self._stateless = stateless
9090
self._initialization_state = (
9191
InitializationState.Initialized if stateless else InitializationState.NotInitialized

src/mcp/shared/session.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,16 +191,26 @@ def __init__(
191191
write_stream: WriteStream[SessionMessage],
192192
# If none, reading will never time out
193193
read_timeout_seconds: float | None = None,
194+
# When True, closing/EOF on the read stream closes the write stream too.
195+
#
196+
# For full-duplex transports (e.g., stdio), an input EOF can be a
197+
# half-close: the peer is done sending requests but still expects
198+
# responses on the output stream. In that case, callers may opt out so
199+
# in-flight handlers can drain their responses before shutdown.
200+
close_write_stream_on_read_close: bool = True,
194201
) -> None:
195202
self._read_stream = read_stream
196203
self._write_stream = write_stream
197204
self._response_streams = {}
198205
self._request_id = 0
199206
self._session_read_timeout_seconds = read_timeout_seconds
207+
self._close_write_stream_on_read_close = close_write_stream_on_read_close
200208
self._in_flight = {}
201209
self._progress_callbacks = {}
202210
self._response_routers = []
203211
self._exit_stack = AsyncExitStack()
212+
self._exit_stack.push_async_callback(self._read_stream.aclose)
213+
self._exit_stack.push_async_callback(self._write_stream.aclose)
204214

205215
def add_response_router(self, router: ResponseRouter) -> None:
206216
"""Register a response router to handle responses for non-standard requests.
@@ -349,7 +359,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
349359
raise NotImplementedError
350360

351361
async def _receive_loop(self) -> None:
352-
async with self._read_stream, self._write_stream:
362+
async with AsyncExitStack() as stack:
363+
await stack.enter_async_context(self._read_stream)
364+
if self._close_write_stream_on_read_close:
365+
await stack.enter_async_context(self._write_stream)
353366
try:
354367

355368
async def _handle_session_message(message: SessionMessage) -> None:

tests/server/test_cancel_handling.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
InitializeRequestParams,
2020
JSONRPCNotification,
2121
JSONRPCRequest,
22+
JSONRPCResponse,
2223
ListToolsResult,
2324
PaginatedRequestParams,
2425
TextContent,
@@ -100,29 +101,18 @@ async def first_request():
100101

101102

102103
@pytest.mark.anyio
103-
async def test_server_cancels_in_flight_handlers_on_transport_close():
104-
"""When the transport closes mid-request, server.run() must cancel in-flight
105-
handlers rather than join on them.
106-
107-
Without the cancel, the task group waits for the handler, which then tries
108-
to respond through a write stream that _receive_loop already closed,
109-
raising ClosedResourceError and crashing server.run() with exit code 1.
110-
111-
This drives server.run() with raw memory streams because InMemoryTransport
112-
wraps it in its own finally-cancel (_memory.py) which masks the bug.
113-
"""
104+
async def test_server_drains_in_flight_handlers_on_transport_read_eof():
105+
"""When the transport's read side hits EOF (e.g., stdio stdin closes), the
106+
server must drain already-started handlers so their responses reach the
107+
peer via the still-open write side."""
114108
handler_started = anyio.Event()
115-
handler_cancelled = anyio.Event()
109+
handler_allowed_to_finish = anyio.Event()
116110
server_run_returned = anyio.Event()
117111

118112
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
119113
handler_started.set()
120-
try:
121-
await anyio.sleep_forever()
122-
finally:
123-
handler_cancelled.set()
124-
# unreachable: sleep_forever only exits via cancellation
125-
raise AssertionError # pragma: no cover
114+
await handler_allowed_to_finish.wait()
115+
return CallToolResult(content=[TextContent(type="text", text="ok")])
126116

127117
server = Server("test", on_call_tool=handle_call_tool)
128118

@@ -167,9 +157,13 @@ async def run_server():
167157
# handler gets CancelledError, server.run() returns.
168158
await to_server.aclose()
169159

170-
await server_run_returned.wait()
160+
handler_allowed_to_finish.set()
161+
162+
response = await from_server.receive()
163+
assert isinstance(response.message, JSONRPCResponse)
164+
assert response.message.id == 2
171165

172-
assert handler_cancelled.is_set()
166+
await server_run_returned.wait()
173167

174168

175169
@pytest.mark.anyio

tests/server/test_stdio.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,27 @@
55
import anyio
66
import pytest
77

8+
from mcp.server import Server, ServerRequestContext
89
from mcp.server.stdio import stdio_server
910
from mcp.shared.message import SessionMessage
10-
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
11+
from mcp.types import (
12+
LATEST_PROTOCOL_VERSION,
13+
CallToolRequestParams,
14+
CallToolResult,
15+
ClientCapabilities,
16+
Implementation,
17+
InitializeRequestParams,
18+
JSONRPCError,
19+
JSONRPCMessage,
20+
JSONRPCNotification,
21+
JSONRPCRequest,
22+
JSONRPCResponse,
23+
ListToolsResult,
24+
PaginatedRequestParams,
25+
TextContent,
26+
Tool,
27+
jsonrpc_message_adapter,
28+
)
1129

1230

1331
@pytest.mark.anyio
@@ -92,3 +110,79 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch):
92110
second = await read_stream.receive()
93111
assert isinstance(second, SessionMessage)
94112
assert second.message == valid
113+
114+
115+
@pytest.mark.anyio
116+
async def test_stdio_server_drains_in_flight_responses_on_stdin_eof():
117+
"""When stdin reaches EOF (e.g., bash-redirected input), already-received
118+
requests must still be able to emit their responses on stdout."""
119+
stdin = io.StringIO()
120+
stdout = io.StringIO()
121+
122+
tool_started_count = 0
123+
both_tools_started = anyio.Event()
124+
allow_tools_to_finish = anyio.Event()
125+
126+
async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
127+
return ListToolsResult(tools=[Tool(name="slow", description="test", input_schema={})])
128+
129+
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
130+
nonlocal tool_started_count
131+
tool_started_count += 1
132+
if tool_started_count == 2:
133+
both_tools_started.set()
134+
await allow_tools_to_finish.wait()
135+
return CallToolResult(content=[TextContent(type="text", text="ok")])
136+
137+
server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool)
138+
139+
init_req = JSONRPCRequest(
140+
jsonrpc="2.0",
141+
id=0,
142+
method="initialize",
143+
params=InitializeRequestParams(
144+
protocol_version=LATEST_PROTOCOL_VERSION,
145+
capabilities=ClientCapabilities(),
146+
client_info=Implementation(name="test", version="1.0"),
147+
).model_dump(by_alias=True, mode="json", exclude_none=True),
148+
)
149+
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
150+
call_1 = JSONRPCRequest(
151+
jsonrpc="2.0",
152+
id=1,
153+
method="tools/call",
154+
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
155+
)
156+
call_2 = JSONRPCRequest(
157+
jsonrpc="2.0",
158+
id=2,
159+
method="tools/call",
160+
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
161+
)
162+
163+
for message in (init_req, initialized, call_1, call_2):
164+
stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n")
165+
stdin.seek(0)
166+
167+
async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
168+
read_stream,
169+
write_stream,
170+
):
171+
with anyio.fail_after(5):
172+
async with anyio.create_task_group() as tg:
173+
tg.start_soon(server.run, read_stream, write_stream, server.create_initialization_options())
174+
await both_tools_started.wait()
175+
allow_tools_to_finish.set()
176+
177+
stdout.seek(0)
178+
ids: set[int | str] = set()
179+
for line in stdout.readlines():
180+
line = line.strip()
181+
if not line:
182+
continue
183+
message = jsonrpc_message_adapter.validate_json(line)
184+
if isinstance(message, JSONRPCResponse | JSONRPCError):
185+
assert message.id is not None
186+
ids.add(message.id)
187+
assert 1 in ids
188+
assert 2 in ids

0 commit comments

Comments
 (0)