Skip to content

Commit 4d4e863

Browse files
committed
test: cover stdio EOF drain and shutdown edges
1 parent b485b6a commit 4d4e863

2 files changed

Lines changed: 125 additions & 10 deletions

File tree

tests/server/test_cancel_handling.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,123 @@ async def run_server():
166166
await server_run_returned.wait()
167167

168168

169+
@pytest.mark.anyio
170+
async def test_server_reraises_handler_cancellation_when_server_is_cancelled():
171+
"""If the server task is cancelled (e.g. KeyboardInterrupt), in-flight
172+
request handlers will get cancelled too. Cancellation must be re-raised so
173+
the task group can unwind cleanly."""
174+
handler_started = anyio.Event()
175+
server_run_returned = anyio.Event()
176+
cancel_scope = anyio.CancelScope()
177+
178+
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
179+
handler_started.set()
180+
await anyio.sleep_forever()
181+
raise AssertionError # pragma: no cover
182+
183+
server = Server("test", on_call_tool=handle_call_tool)
184+
185+
to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
186+
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
187+
188+
async def run_server():
189+
try:
190+
with cancel_scope:
191+
await server.run(server_read, server_write, server.create_initialization_options())
192+
finally:
193+
server_run_returned.set()
194+
195+
init_req = JSONRPCRequest(
196+
jsonrpc="2.0",
197+
id=1,
198+
method="initialize",
199+
params=InitializeRequestParams(
200+
protocol_version=LATEST_PROTOCOL_VERSION,
201+
capabilities=ClientCapabilities(),
202+
client_info=Implementation(name="test", version="1.0"),
203+
).model_dump(by_alias=True, mode="json", exclude_none=True),
204+
)
205+
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
206+
call_req = JSONRPCRequest(
207+
jsonrpc="2.0",
208+
id=2,
209+
method="tools/call",
210+
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
211+
)
212+
213+
with anyio.fail_after(5):
214+
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
215+
tg.start_soon(run_server)
216+
217+
await to_server.send(SessionMessage(init_req))
218+
await from_server.receive() # init response
219+
await to_server.send(SessionMessage(initialized))
220+
await to_server.send(SessionMessage(call_req))
221+
222+
await handler_started.wait()
223+
cancel_scope.cancel()
224+
await server_run_returned.wait()
225+
226+
227+
@pytest.mark.anyio
228+
async def test_server_drops_response_when_write_stream_closes_mid_request():
229+
"""If the write side closes while a handler is in-flight, responding may
230+
raise (ClosedResourceError/BrokenResourceError). The handler task should
231+
exit without crashing the server."""
232+
handler_started = anyio.Event()
233+
allow_finish = anyio.Event()
234+
server_run_returned = anyio.Event()
235+
236+
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
237+
handler_started.set()
238+
await allow_finish.wait()
239+
return CallToolResult(content=[TextContent(type="text", text="ok")])
240+
241+
server = Server("test", on_call_tool=handle_call_tool)
242+
243+
to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
244+
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
245+
246+
async def run_server():
247+
await server.run(server_read, server_write, server.create_initialization_options())
248+
server_run_returned.set()
249+
250+
init_req = JSONRPCRequest(
251+
jsonrpc="2.0",
252+
id=1,
253+
method="initialize",
254+
params=InitializeRequestParams(
255+
protocol_version=LATEST_PROTOCOL_VERSION,
256+
capabilities=ClientCapabilities(),
257+
client_info=Implementation(name="test", version="1.0"),
258+
).model_dump(by_alias=True, mode="json", exclude_none=True),
259+
)
260+
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
261+
call_req = JSONRPCRequest(
262+
jsonrpc="2.0",
263+
id=2,
264+
method="tools/call",
265+
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
266+
)
267+
268+
with anyio.fail_after(5):
269+
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
270+
tg.start_soon(run_server)
271+
272+
await to_server.send(SessionMessage(init_req))
273+
await from_server.receive() # init response
274+
await to_server.send(SessionMessage(initialized))
275+
await to_server.send(SessionMessage(call_req))
276+
277+
await handler_started.wait()
278+
await server_write.aclose()
279+
280+
allow_finish.set()
281+
await to_server.aclose()
282+
283+
await server_run_returned.wait()
284+
285+
169286
@pytest.mark.anyio
170287
async def test_server_handles_transport_close_with_pending_server_to_client_requests():
171288
"""When the transport closes while handlers are blocked on server→client

tests/server/test_stdio.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
ClientCapabilities,
1616
Implementation,
1717
InitializeRequestParams,
18-
JSONRPCError,
1918
JSONRPCMessage,
2019
JSONRPCNotification,
2120
JSONRPCRequest,
@@ -147,6 +146,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar
147146
).model_dump(by_alias=True, mode="json", exclude_none=True),
148147
)
149148
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
149+
list_tools = JSONRPCRequest(jsonrpc="2.0", id=10, method="tools/list")
150150
call_1 = JSONRPCRequest(
151151
jsonrpc="2.0",
152152
id=1,
@@ -160,7 +160,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar
160160
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
161161
)
162162

163-
for message in (init_req, initialized, call_1, call_2):
163+
for message in (init_req, initialized, list_tools, call_1, call_2):
164164
stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n")
165165
stdin.seek(0)
166166

@@ -175,14 +175,12 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar
175175
allow_tools_to_finish.set()
176176

177177
stdout.seek(0)
178+
output_lines = [line.strip() for line in stdout.readlines()]
179+
messages = [jsonrpc_message_adapter.validate_json(line) for line in output_lines]
178180
ids: set[int | str] = set()
179-
for line in stdout.readlines():
180-
line = line.strip()
181-
if not line:
182-
continue
183-
message = jsonrpc_message_adapter.validate_json(line)
184-
if isinstance(message, JSONRPCResponse | JSONRPCError):
185-
assert message.id is not None
186-
ids.add(message.id)
181+
for message in messages:
182+
assert isinstance(message, JSONRPCResponse)
183+
ids.add(message.id)
184+
187185
assert 1 in ids
188186
assert 2 in ids

0 commit comments

Comments
 (0)