Skip to content

Commit e8a3b0e

Browse files
committed
revert sse_client and add cleanup to outer task group in tests
1 parent fabf6c5 commit e8a3b0e

File tree

2 files changed

+8
-15
lines changed

2 files changed

+8
-15
lines changed

src/mcp/client/sse.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import logging
2-
from contextlib import AsyncExitStack, asynccontextmanager
2+
from contextlib import asynccontextmanager
33
from typing import Any
44
from urllib.parse import urljoin, urlparse
55

66
import anyio
77
import httpx
8-
from anyio.abc import TaskGroup, TaskStatus
8+
from anyio.abc import TaskStatus
99
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1010
from httpx_sse import aconnect_sse
1111
from httpx_sse._exceptions import SSEError
@@ -29,7 +29,6 @@ async def sse_client(
2929
sse_read_timeout: float = 60 * 5,
3030
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
3131
auth: httpx.Auth | None = None,
32-
maybe_task_group: TaskGroup | None = None,
3332
):
3433
"""
3534
Client transport for SSE.
@@ -53,15 +52,7 @@ async def sse_client(
5352
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
5453
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
5554

56-
async with AsyncExitStack() as stack:
57-
# Only create a task group if one wasn't provided
58-
if maybe_task_group is None:
59-
tg = await stack.enter_async_context(anyio.create_task_group())
60-
else:
61-
tg = maybe_task_group
62-
63-
owns_task_group = maybe_task_group is None
64-
55+
async with anyio.create_task_group() as tg:
6556
try:
6657
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
6758
async with httpx_client_factory(
@@ -151,8 +142,7 @@ async def post_writer(endpoint_url: str):
151142
try:
152143
yield read_stream, write_stream
153144
finally:
154-
if owns_task_group:
155-
tg.cancel_scope.cancel()
145+
tg.cancel_scope.cancel()
156146
finally:
157147
await read_stream_writer.aclose()
158148
await write_stream.aclose()

tests/shared/test_sse.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ def server_app() -> Starlette:
118118
@pytest.fixture()
119119
async def tg() -> AsyncGenerator[TaskGroup, None]:
120120
async with anyio.create_task_group() as tg:
121-
yield tg
121+
try:
122+
yield tg
123+
finally:
124+
tg.cancel_scope.cancel()
122125

123126

124127
@pytest.fixture()

0 commit comments

Comments
 (0)