Skip to content

Commit e1745f8

Browse files
committed
remove taskgroup fixture due to premature closing
1 parent 5cc10fa commit e1745f8

File tree

1 file changed

+71
-74
lines changed

1 file changed

+71
-74
lines changed

tests/shared/test_sse.py

Lines changed: 71 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -116,28 +116,24 @@ def server_app() -> Starlette:
116116

117117

118118
@pytest.fixture()
119-
async def tg() -> AsyncGenerator[TaskGroup, None]:
120-
async with anyio.create_task_group() as tg:
121-
yield tg
122-
123-
124-
@pytest.fixture()
125-
async def http_client(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]:
119+
async def http_client(server_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]:
126120
"""Create test client using StreamingASGITransport"""
127-
transport = StreamingASGITransport(app=server_app, task_group=tg)
128-
async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client:
129-
yield client
121+
async with anyio.create_task_group() as tg:
122+
transport = StreamingASGITransport(app=server_app, task_group=tg)
123+
async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client:
124+
yield client
130125

131126

132127
@pytest.fixture()
133-
async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[ClientSession, None]:
134-
asgi_client_factory = create_asgi_client_factory(server_app, tg)
128+
async def sse_client_session(server_app: Starlette) -> AsyncGenerator[ClientSession, None]:
129+
async with anyio.create_task_group() as tg:
130+
asgi_client_factory = create_asgi_client_factory(server_app, tg)
135131

136-
async with sse_client(
137-
f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory
138-
) as streams:
139-
async with ClientSession(*streams) as session:
140-
yield session
132+
async with sse_client(
133+
f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory
134+
) as streams:
135+
async with ClientSession(*streams) as session:
136+
yield session
141137

142138

143139
# Tests
@@ -232,16 +228,15 @@ async def mounted_server_app(server_app: Starlette) -> Starlette:
232228

233229

234230
@pytest.fixture()
235-
async def sse_client_mounted_server_app_session(
236-
tg: TaskGroup, mounted_server_app: Starlette
237-
) -> AsyncGenerator[ClientSession, None]:
238-
asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg)
231+
async def sse_client_mounted_server_app_session(mounted_server_app: Starlette) -> AsyncGenerator[ClientSession, None]:
232+
async with anyio.create_task_group() as tg:
233+
asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg)
239234

240-
async with sse_client(
241-
f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory
242-
) as streams:
243-
async with ClientSession(*streams) as session:
244-
yield session
235+
async with sse_client(
236+
f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory
237+
) as streams:
238+
async with ClientSession(*streams) as session:
239+
yield session
245240

246241

247242
@pytest.mark.anyio
@@ -308,7 +303,7 @@ async def context_server_app() -> Starlette:
308303

309304

310305
@pytest.mark.anyio
311-
async def test_request_context_propagation(tg: TaskGroup, context_server_app: Starlette) -> None:
306+
async def test_request_context_propagation(context_server_app: Starlette) -> None:
312307
"""Test that request context is properly propagated through SSE transport."""
313308
# Test with custom headers
314309
custom_headers = {
@@ -317,61 +312,63 @@ async def test_request_context_propagation(tg: TaskGroup, context_server_app: St
317312
"X-Trace-Id": "trace-123",
318313
}
319314

320-
asgi_client_factory = create_asgi_client_factory(context_server_app, tg)
321-
322-
async with sse_client(
323-
f"{TEST_SERVER_BASE_URL}/sse",
324-
headers=custom_headers,
325-
httpx_client_factory=asgi_client_factory,
326-
sse_read_timeout=0.5,
327-
) as streams:
328-
async with ClientSession(*streams) as session:
329-
# Initialize the session
330-
result = await session.initialize()
331-
assert isinstance(result, InitializeResult)
332-
333-
# Call the tool that echoes headers back
334-
tool_result = await session.call_tool("echo_headers", {})
315+
async with anyio.create_task_group() as tg:
316+
asgi_client_factory = create_asgi_client_factory(context_server_app, tg)
335317

336-
# Parse the JSON response
337-
assert len(tool_result.content) == 1
338-
content_item = tool_result.content[0]
339-
headers_data = json.loads(content_item.text if content_item.type == "text" else "{}")
318+
async with sse_client(
319+
f"{TEST_SERVER_BASE_URL}/sse",
320+
headers=custom_headers,
321+
httpx_client_factory=asgi_client_factory,
322+
sse_read_timeout=0.5,
323+
) as streams:
324+
async with ClientSession(*streams) as session:
325+
# Initialize the session
326+
result = await session.initialize()
327+
assert isinstance(result, InitializeResult)
328+
329+
# Call the tool that echoes headers back
330+
tool_result = await session.call_tool("echo_headers", {})
331+
332+
# Parse the JSON response
333+
assert len(tool_result.content) == 1
334+
content_item = tool_result.content[0]
335+
headers_data = json.loads(content_item.text if content_item.type == "text" else "{}")
340336

341-
# Verify headers were propagated
342-
assert headers_data.get("authorization") == "Bearer test-token"
343-
assert headers_data.get("x-custom-header") == "test-value"
344-
assert headers_data.get("x-trace-id") == "trace-123"
337+
# Verify headers were propagated
338+
assert headers_data.get("authorization") == "Bearer test-token"
339+
assert headers_data.get("x-custom-header") == "test-value"
340+
assert headers_data.get("x-trace-id") == "trace-123"
345341

346342

347343
@pytest.mark.anyio
348-
async def test_request_context_isolation(tg: TaskGroup, context_server_app: Starlette) -> None:
344+
async def test_request_context_isolation(context_server_app: Starlette) -> None:
349345
"""Test that request contexts are isolated between different SSE clients."""
350346
contexts: list[dict[str, Any]] = []
351347

352-
asgi_client_factory = create_asgi_client_factory(context_server_app, tg)
353-
354-
# Create multiple clients with different headers
355-
for i in range(3):
356-
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
357-
358-
async with sse_client(
359-
f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory
360-
) as (
361-
read_stream,
362-
write_stream,
363-
):
364-
async with ClientSession(read_stream, write_stream) as session:
365-
await session.initialize()
366-
367-
# Call the tool that echoes context
368-
tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
369-
370-
assert len(tool_result.content) == 1
371-
context_data = json.loads(
372-
tool_result.content[0].text if tool_result.content[0].type == "text" else "{}"
373-
)
374-
contexts.append(context_data)
348+
async with anyio.create_task_group() as tg:
349+
asgi_client_factory = create_asgi_client_factory(context_server_app, tg)
350+
351+
# Create multiple clients with different headers
352+
for i in range(3):
353+
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
354+
355+
async with sse_client(
356+
f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory
357+
) as (
358+
read_stream,
359+
write_stream,
360+
):
361+
async with ClientSession(read_stream, write_stream) as session:
362+
await session.initialize()
363+
364+
# Call the tool that echoes context
365+
tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
366+
367+
assert len(tool_result.content) == 1
368+
context_data = json.loads(
369+
tool_result.content[0].text if tool_result.content[0].type == "text" else "{}"
370+
)
371+
contexts.append(context_data)
375372

376373
# Verify each request had its own context
377374
assert len(contexts) == 3

0 commit comments

Comments
 (0)