11import logging
2- from contextlib import AsyncExitStack , asynccontextmanager
2+ from contextlib import asynccontextmanager
33from typing import Any
44from urllib .parse import urljoin , urlparse
55
66import anyio
77import httpx
8- from anyio .abc import TaskGroup , TaskStatus
8+ from anyio .abc import TaskStatus
99from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
1010from httpx_sse import aconnect_sse
1111from httpx_sse ._exceptions import SSEError
@@ -29,7 +29,6 @@ async def sse_client(
2929 sse_read_timeout : float = 60 * 5 ,
3030 httpx_client_factory : McpHttpClientFactory = create_mcp_http_client ,
3131 auth : httpx .Auth | None = None ,
32- maybe_task_group : TaskGroup | None = None ,
3332):
3433 """
3534 Client transport for SSE.
@@ -53,15 +52,7 @@ async def sse_client(
5352 read_stream_writer , read_stream = anyio .create_memory_object_stream (0 )
5453 write_stream , write_stream_reader = anyio .create_memory_object_stream (0 )
5554
56- async with AsyncExitStack () as stack :
57- # Only create a task group if one wasn't provided
58- if maybe_task_group is None :
59- tg = await stack .enter_async_context (anyio .create_task_group ())
60- else :
61- tg = maybe_task_group
62-
63- owns_task_group = maybe_task_group is None
64-
55+ async with anyio .create_task_group () as tg :
6556 try :
6657 logger .debug (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
6758 async with httpx_client_factory (
@@ -151,8 +142,7 @@ async def post_writer(endpoint_url: str):
151142 try :
152143 yield read_stream , write_stream
153144 finally :
154- if owns_task_group :
155- tg .cancel_scope .cancel ()
145+ tg .cancel_scope .cancel ()
156146 finally :
157147 await read_stream_writer .aclose ()
158148 await write_stream .aclose ()
0 commit comments