Skip to content

Commit f1c0e22

Browse files
committed
Drain stdio responses after redirected stdin EOF
1 parent e8e6484 commit f1c0e22

5 files changed

Lines changed: 121 additions & 12 deletions

File tree

src/mcp/server/lowlevel/server.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ async def main():
3939
import contextvars
4040
import logging
4141
import warnings
42-
from collections.abc import AsyncIterator, Awaitable, Callable
42+
from collections.abc import AsyncGenerator, Awaitable, Callable
4343
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
4444
from importlib.metadata import version as importlib_version
4545
from typing import Any, Generic, cast
@@ -85,7 +85,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals
8585

8686

8787
@asynccontextmanager
88-
async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]:
88+
async def lifespan(_: Server[LifespanResultT]) -> AsyncGenerator[dict[str, Any]]:
8989
"""Default lifespan context manager that does nothing.
9090
9191
Returns:
@@ -371,6 +371,10 @@ async def run(
371371
# the initialization lifecycle, but can do so with any available node
372372
# rather than requiring initialization for each connection.
373373
stateless: bool = False,
374+
# When True, stdin/file-style EOF is treated as "no more inbound messages";
375+
# accepted request handlers are allowed to finish and flush their responses.
376+
drain_in_flight_on_read_eof: bool = False,
377+
drain_in_flight_on_read_eof_timeout_seconds: float = 5.0,
374378
):
375379
async with AsyncExitStack() as stack:
376380
lifespan_context = await stack.enter_async_context(self.lifespan(self))
@@ -380,6 +384,7 @@ async def run(
380384
write_stream,
381385
initialization_options,
382386
stateless=stateless,
387+
close_write_stream_on_read_end=not drain_in_flight_on_read_eof,
383388
)
384389
)
385390

@@ -408,11 +413,14 @@ async def run(
408413
raise_exceptions,
409414
)
410415
finally:
411-
# Transport closed: cancel in-flight handlers. Without this the
412-
# TG join waits for them, and when they eventually try to
413-
# respond they hit a closed write stream (the session's
414-
# _receive_loop closed it when the read stream ended).
415-
tg.cancel_scope.cancel()
416+
if not drain_in_flight_on_read_eof:
417+
# Transport closed: cancel in-flight handlers. Without this the
418+
# TG join waits for them, and when they eventually try to
419+
# respond they hit a closed write stream (the session's
420+
# _receive_loop closed it when the read stream ended).
421+
tg.cancel_scope.cancel()
422+
else:
423+
tg.cancel_scope.deadline = anyio.current_time() + drain_in_flight_on_read_eof_timeout_seconds
416424

417425
async def _handle_message(
418426
self,

src/mcp/server/mcpserver/server.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import inspect
77
import json
88
import re
9-
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
9+
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
1010
from contextlib import AbstractAsyncContextManager, asynccontextmanager
1111
from typing import Any, Generic, Literal, TypeVar, overload
1212

@@ -74,6 +74,8 @@
7474

7575
logger = get_logger(__name__)
7676

77+
STDIO_EOF_DRAIN_TIMEOUT_SECONDS = 5.0
78+
7779
_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
7880

7981

@@ -119,7 +121,7 @@ def lifespan_wrapper(
119121
lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]],
120122
) -> Callable[[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]]:
121123
@asynccontextmanager
122-
async def wrap(_: Server[LifespanResultT]) -> AsyncIterator[LifespanResultT]:
124+
async def wrap(_: Server[LifespanResultT]) -> AsyncGenerator[LifespanResultT]:
123125
async with lifespan(app) as context:
124126
yield context
125127

@@ -852,6 +854,8 @@ async def run_stdio_async(self) -> None:
852854
read_stream,
853855
write_stream,
854856
self._lowlevel_server.create_initialization_options(),
857+
drain_in_flight_on_read_eof=True,
858+
drain_in_flight_on_read_eof_timeout_seconds=STDIO_EOF_DRAIN_TIMEOUT_SECONDS,
855859
)
856860

857861
async def run_sse_async( # pragma: no cover

src/mcp/server/session.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,13 @@ def __init__(
8484
write_stream: WriteStream[SessionMessage],
8585
init_options: InitializationOptions,
8686
stateless: bool = False,
87+
close_write_stream_on_read_end: bool = True,
8788
) -> None:
88-
super().__init__(read_stream, write_stream)
89+
super().__init__(
90+
read_stream,
91+
write_stream,
92+
close_write_stream_on_read_end=close_write_stream_on_read_end,
93+
)
8994
self._stateless = stateless
9095
self._initialization_state = (
9196
InitializationState.Initialized if stateless else InitializationState.NotInitialized

src/mcp/shared/session.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,11 @@ def __init__(
191191
write_stream: WriteStream[SessionMessage],
192192
# If none, reading will never time out
193193
read_timeout_seconds: float | None = None,
194+
close_write_stream_on_read_end: bool = True,
194195
) -> None:
195196
self._read_stream = read_stream
196197
self._write_stream = write_stream
198+
self._close_write_stream_on_read_end = close_write_stream_on_read_end
197199
self._response_streams = {}
198200
self._request_id = 0
199201
self._session_read_timeout_seconds = read_timeout_seconds
@@ -234,7 +236,11 @@ async def __aexit__(
234236
# would be very surprising behavior), so make sure to cancel the tasks
235237
# in the task group.
236238
self._task_group.cancel_scope.cancel()
237-
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
239+
try:
240+
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
241+
finally:
242+
if not self._close_write_stream_on_read_end:
243+
await self._write_stream.aclose()
238244

239245
async def send_request(
240246
self,
@@ -349,7 +355,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
349355
raise NotImplementedError
350356

351357
async def _receive_loop(self) -> None:
352-
async with self._read_stream, self._write_stream:
358+
async with AsyncExitStack() as stack:
359+
await stack.enter_async_context(self._read_stream)
360+
if self._close_write_stream_on_read_end:
361+
await stack.enter_async_context(self._write_stream)
353362
try:
354363

355364
async def _handle_session_message(message: SessionMessage) -> None:
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import json
2+
import subprocess
3+
import sys
4+
import textwrap
5+
from pathlib import Path
6+
7+
8+
def test_stdio_redirected_stdin_eof_drains_accepted_tool_responses(tmp_path: Path) -> None:
9+
server_py = tmp_path / "server.py"
10+
payload_jsonl = tmp_path / "payload.jsonl"
11+
response_jsonl = tmp_path / "response.jsonl"
12+
13+
server_py.write_text(
14+
textwrap.dedent(
15+
"""
16+
import asyncio
17+
18+
from mcp.server.mcpserver import MCPServer
19+
20+
mcp = MCPServer("repro")
21+
22+
@mcp.tool()
23+
async def slow_echo(text: str) -> str:
24+
await asyncio.sleep(0.05)
25+
return text
26+
27+
if __name__ == "__main__":
28+
mcp.run(transport="stdio")
29+
"""
30+
),
31+
encoding="utf-8",
32+
)
33+
payload_jsonl.write_text(
34+
"\n".join(
35+
[
36+
json.dumps(
37+
{
38+
"jsonrpc": "2.0",
39+
"id": 0,
40+
"method": "initialize",
41+
"params": {
42+
"protocolVersion": "2024-11-05",
43+
"capabilities": {},
44+
"clientInfo": {"name": "repro", "version": "0.1"},
45+
},
46+
}
47+
),
48+
json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}),
49+
json.dumps(
50+
{
51+
"jsonrpc": "2.0",
52+
"id": 1,
53+
"method": "tools/call",
54+
"params": {"name": "slow_echo", "arguments": {"text": "first"}},
55+
}
56+
),
57+
json.dumps(
58+
{
59+
"jsonrpc": "2.0",
60+
"id": 2,
61+
"method": "tools/call",
62+
"params": {"name": "slow_echo", "arguments": {"text": "second"}},
63+
}
64+
),
65+
]
66+
)
67+
+ "\n",
68+
encoding="utf-8",
69+
)
70+
71+
with payload_jsonl.open("rb") as stdin, response_jsonl.open("wb") as stdout:
72+
completed = subprocess.run(
73+
[sys.executable, str(server_py)],
74+
stdin=stdin,
75+
stdout=stdout,
76+
stderr=subprocess.PIPE,
77+
timeout=10,
78+
check=False,
79+
)
80+
81+
assert completed.returncode == 0, completed.stderr.decode("utf-8", errors="replace")
82+
response_ids = {json.loads(line)["id"] for line in response_jsonl.read_text(encoding="utf-8").splitlines()}
83+
assert {0, 1, 2}.issubset(response_ids)

0 commit comments

Comments
 (0)