Skip to content

Commit 1e83583

Browse files
committed
refactor: use async with instead of try/finally for stream cleanup
Replace explicit try/finally + aclose() chains with async with on all 4 stream ends. Memory stream context managers are idempotent and have no checkpoints in __aexit__, so this is semantically identical to the try/finally form with the same teardown ordering — but the ownership is stated once at creation time and can't drift. For websocket_client, also move stream creation inside ws_connect so a connection failure never creates streams in the first place. Matches the existing pattern in shared/memory.py.
1 parent 6ed6261 commit 1e83583

File tree

3 files changed

+7
-22
lines changed

3 files changed

+7
-22
lines changed

src/mcp/client/sse.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ async def sse_client(
6161
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
6262

6363
async with anyio.create_task_group() as tg:
64-
try:
64+
async with read_stream_writer, read_stream, write_stream, write_stream_reader:
6565
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
6666
async with httpx_client_factory(
6767
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
@@ -157,8 +157,3 @@ async def post_writer(endpoint_url: str):
157157
yield read_stream, write_stream
158158
finally:
159159
tg.cancel_scope.cancel()
160-
finally:
161-
await read_stream_writer.aclose()
162-
await write_stream.aclose()
163-
await read_stream.aclose()
164-
await write_stream_reader.aclose()

src/mcp/client/streamable_http.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ async def streamable_http_client(
547547
transport = StreamableHTTPTransport(url)
548548

549549
async with anyio.create_task_group() as tg:
550-
try:
550+
async with read_stream_writer, read_stream, write_stream, write_stream_reader:
551551
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
552552

553553
async with contextlib.AsyncExitStack() as stack:
@@ -574,8 +574,3 @@ def start_get_stream() -> None:
574574
if transport.session_id and terminate_on_close:
575575
await transport.terminate_session(client)
576576
tg.cancel_scope.cancel()
577-
finally:
578-
await read_stream_writer.aclose()
579-
await write_stream.aclose()
580-
await read_stream.aclose()
581-
await write_stream_reader.aclose()

src/mcp/client/websocket.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ async def websocket_client(
3838
write_stream: MemoryObjectSendStream[SessionMessage]
3939
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
4040

41-
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
42-
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
41+
# Connect using websockets, requesting the "mcp" subprotocol
42+
async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:
43+
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
44+
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
4345

44-
try:
45-
# Connect using websockets, requesting the "mcp" subprotocol
46-
async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:
46+
async with read_stream_writer, read_stream, write_stream, write_stream_reader:
4747

4848
async def ws_reader():
4949
"""Reads text messages from the WebSocket, parses them as JSON-RPC messages,
@@ -79,8 +79,3 @@ async def ws_writer():
7979

8080
# Once the caller's 'async with' block exits, we shut down
8181
tg.cancel_scope.cancel()
82-
finally:
83-
await read_stream_writer.aclose()
84-
await write_stream.aclose()
85-
await read_stream.aclose()
86-
await write_stream_reader.aclose()

0 commit comments

Comments
 (0)