Skip to content

Commit 60fb7e9

Browse files
committed
fix(server): opt-in drain on read EOF
1 parent 184a84c commit 60fb7e9

5 files changed

Lines changed: 43 additions & 19 deletions

File tree

src/mcp/server/lowlevel/server.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,10 @@ async def run(
371371
# the initialization lifecycle, but can do so with any available node
372372
# rather than requiring initialization for each connection.
373373
stateless: bool = False,
374+
# When True, treat read EOF as a half-close and allow in-flight handlers
375+
# to drain their responses via the still-open write stream (e.g. stdio
376+
# with bash-redirected stdin).
377+
drain_on_read_close: bool = False,
374378
):
375379
async with AsyncExitStack() as stack:
376380
lifespan_context = await stack.enter_async_context(self.lifespan(self))
@@ -380,6 +384,7 @@ async def run(
380384
write_stream,
381385
initialization_options,
382386
stateless=stateless,
387+
close_write_stream_on_read_close=not drain_on_read_close,
383388
)
384389
)
385390

@@ -390,22 +395,30 @@ async def run(
390395
await stack.enter_async_context(task_support.run())
391396

392397
async with anyio.create_task_group() as tg:
393-
async for message in session.incoming_messages:
394-
logger.debug("Received message: %s", message)
395-
396-
if isinstance(message, RequestResponder) and message.context is not None:
397-
context = message.context
398-
else:
399-
context = contextvars.copy_context()
400-
401-
context.run(
402-
tg.start_soon,
403-
self._handle_message,
404-
message,
405-
session,
406-
lifespan_context,
407-
raise_exceptions,
408-
)
398+
try:
399+
async for message in session.incoming_messages:
400+
logger.debug("Received message: %s", message)
401+
402+
if isinstance(message, RequestResponder) and message.context is not None:
403+
context = message.context
404+
else:
405+
context = contextvars.copy_context()
406+
407+
context.run(
408+
tg.start_soon,
409+
self._handle_message,
410+
message,
411+
session,
412+
lifespan_context,
413+
raise_exceptions,
414+
)
415+
finally:
416+
if not drain_on_read_close:
417+
# Transport closed: cancel in-flight handlers. Without this the
418+
# TG join waits for them, and when they eventually try to
419+
# respond they hit a closed write stream (the session's
420+
# _receive_loop closed it when the read stream ended).
421+
tg.cancel_scope.cancel()
409422

410423
async def _handle_message(
411424
self,

src/mcp/server/mcpserver/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,7 @@ async def run_stdio_async(self) -> None:
852852
read_stream,
853853
write_stream,
854854
self._lowlevel_server.create_initialization_options(),
855+
drain_on_read_close=True,
855856
)
856857

857858
async def run_sse_async( # pragma: no cover

src/mcp/server/session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ def __init__(
8484
write_stream: WriteStream[SessionMessage],
8585
init_options: InitializationOptions,
8686
stateless: bool = False,
87+
close_write_stream_on_read_close: bool = True,
8788
) -> None:
88-
super().__init__(read_stream, write_stream, close_write_stream_on_read_close=False)
89+
super().__init__(read_stream, write_stream, close_write_stream_on_read_close=close_write_stream_on_read_close)
8990
self._stateless = stateless
9091
self._initialization_state = (
9192
InitializationState.Initialized if stateless else InitializationState.NotInitialized

tests/server/test_cancel_handling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar
120120
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
121121

122122
async def run_server():
123-
await server.run(server_read, server_write, server.create_initialization_options())
123+
await server.run(server_read, server_write, server.create_initialization_options(), drain_on_read_close=True)
124124
server_run_returned.set()
125125

126126
init_req = JSONRPCRequest(

tests/server/test_stdio.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,16 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar
170170
):
171171
with anyio.fail_after(5):
172172
async with anyio.create_task_group() as tg: # pragma: no branch
173-
tg.start_soon(server.run, read_stream, write_stream, server.create_initialization_options())
173+
174+
async def run_server() -> None:
175+
await server.run(
176+
read_stream,
177+
write_stream,
178+
server.create_initialization_options(),
179+
drain_on_read_close=True,
180+
)
181+
182+
tg.start_soon(run_server)
174183
await both_tools_started.wait()
175184
allow_tools_to_finish.set()
176185

0 commit comments

Comments
 (0)