@@ -116,28 +116,24 @@ def server_app() -> Starlette:
116116
117117
118118@pytest .fixture ()
119- async def tg () -> AsyncGenerator [TaskGroup , None ]:
120- async with anyio .create_task_group () as tg :
121- yield tg
122-
123-
124- @pytest .fixture ()
125- async def http_client (tg : TaskGroup , server_app : Starlette ) -> AsyncGenerator [httpx .AsyncClient , None ]:
119+ async def http_client (server_app : Starlette ) -> AsyncGenerator [httpx .AsyncClient , None ]:
126120 """Create test client using StreamingASGITransport"""
127- transport = StreamingASGITransport (app = server_app , task_group = tg )
128- async with httpx .AsyncClient (transport = transport , base_url = TEST_SERVER_BASE_URL ) as client :
129- yield client
121+ async with anyio .create_task_group () as tg :
122+ transport = StreamingASGITransport (app = server_app , task_group = tg )
123+ async with httpx .AsyncClient (transport = transport , base_url = TEST_SERVER_BASE_URL ) as client :
124+ yield client
130125
131126
132127@pytest .fixture ()
133- async def sse_client_session (tg : TaskGroup , server_app : Starlette ) -> AsyncGenerator [ClientSession , None ]:
134- asgi_client_factory = create_asgi_client_factory (server_app , tg )
128+ async def sse_client_session (server_app : Starlette ) -> AsyncGenerator [ClientSession , None ]:
129+ async with anyio .create_task_group () as tg :
130+ asgi_client_factory = create_asgi_client_factory (server_app , tg )
135131
136- async with sse_client (
137- f"{ TEST_SERVER_BASE_URL } /sse" , sse_read_timeout = 0.5 , httpx_client_factory = asgi_client_factory
138- ) as streams :
139- async with ClientSession (* streams ) as session :
140- yield session
132+ async with sse_client (
133+ f"{ TEST_SERVER_BASE_URL } /sse" , sse_read_timeout = 0.5 , httpx_client_factory = asgi_client_factory
134+ ) as streams :
135+ async with ClientSession (* streams ) as session :
136+ yield session
141137
142138
143139# Tests
@@ -232,16 +228,15 @@ async def mounted_server_app(server_app: Starlette) -> Starlette:
232228
233229
234230@pytest .fixture ()
235- async def sse_client_mounted_server_app_session (
236- tg : TaskGroup , mounted_server_app : Starlette
237- ) -> AsyncGenerator [ClientSession , None ]:
238- asgi_client_factory = create_asgi_client_factory (mounted_server_app , tg )
231+ async def sse_client_mounted_server_app_session (mounted_server_app : Starlette ) -> AsyncGenerator [ClientSession , None ]:
232+ async with anyio .create_task_group () as tg :
233+ asgi_client_factory = create_asgi_client_factory (mounted_server_app , tg )
239234
240- async with sse_client (
241- f"{ TEST_SERVER_BASE_URL } /mounted_app/sse" , sse_read_timeout = 0.5 , httpx_client_factory = asgi_client_factory
242- ) as streams :
243- async with ClientSession (* streams ) as session :
244- yield session
235+ async with sse_client (
236+ f"{ TEST_SERVER_BASE_URL } /mounted_app/sse" , sse_read_timeout = 0.5 , httpx_client_factory = asgi_client_factory
237+ ) as streams :
238+ async with ClientSession (* streams ) as session :
239+ yield session
245240
246241
247242@pytest .mark .anyio
@@ -308,7 +303,7 @@ async def context_server_app() -> Starlette:
308303
309304
310305@pytest .mark .anyio
311- async def test_request_context_propagation (tg : TaskGroup , context_server_app : Starlette ) -> None :
306+ async def test_request_context_propagation (context_server_app : Starlette ) -> None :
312307 """Test that request context is properly propagated through SSE transport."""
313308 # Test with custom headers
314309 custom_headers = {
@@ -317,61 +312,63 @@ async def test_request_context_propagation(tg: TaskGroup, context_server_app: St
317312 "X-Trace-Id" : "trace-123" ,
318313 }
319314
320- asgi_client_factory = create_asgi_client_factory (context_server_app , tg )
321-
322- async with sse_client (
323- f"{ TEST_SERVER_BASE_URL } /sse" ,
324- headers = custom_headers ,
325- httpx_client_factory = asgi_client_factory ,
326- sse_read_timeout = 0.5 ,
327- ) as streams :
328- async with ClientSession (* streams ) as session :
329- # Initialize the session
330- result = await session .initialize ()
331- assert isinstance (result , InitializeResult )
332-
333- # Call the tool that echoes headers back
334- tool_result = await session .call_tool ("echo_headers" , {})
315+ async with anyio .create_task_group () as tg :
316+ asgi_client_factory = create_asgi_client_factory (context_server_app , tg )
335317
336- # Parse the JSON response
337- assert len (tool_result .content ) == 1
338- content_item = tool_result .content [0 ]
339- headers_data = json .loads (content_item .text if content_item .type == "text" else "{}" )
318+ async with sse_client (
319+ f"{ TEST_SERVER_BASE_URL } /sse" ,
320+ headers = custom_headers ,
321+ httpx_client_factory = asgi_client_factory ,
322+ sse_read_timeout = 0.5 ,
323+ ) as streams :
324+ async with ClientSession (* streams ) as session :
325+ # Initialize the session
326+ result = await session .initialize ()
327+ assert isinstance (result , InitializeResult )
328+
329+ # Call the tool that echoes headers back
330+ tool_result = await session .call_tool ("echo_headers" , {})
331+
332+ # Parse the JSON response
333+ assert len (tool_result .content ) == 1
334+ content_item = tool_result .content [0 ]
335+ headers_data = json .loads (content_item .text if content_item .type == "text" else "{}" )
340336
341- # Verify headers were propagated
342- assert headers_data .get ("authorization" ) == "Bearer test-token"
343- assert headers_data .get ("x-custom-header" ) == "test-value"
344- assert headers_data .get ("x-trace-id" ) == "trace-123"
337+ # Verify headers were propagated
338+ assert headers_data .get ("authorization" ) == "Bearer test-token"
339+ assert headers_data .get ("x-custom-header" ) == "test-value"
340+ assert headers_data .get ("x-trace-id" ) == "trace-123"
345341
346342
347343@pytest .mark .anyio
348- async def test_request_context_isolation (tg : TaskGroup , context_server_app : Starlette ) -> None :
344+ async def test_request_context_isolation (context_server_app : Starlette ) -> None :
349345 """Test that request contexts are isolated between different SSE clients."""
350346 contexts : list [dict [str , Any ]] = []
351347
352- asgi_client_factory = create_asgi_client_factory (context_server_app , tg )
353-
354- # Create multiple clients with different headers
355- for i in range (3 ):
356- headers = {"X-Request-Id" : f"request-{ i } " , "X-Custom-Value" : f"value-{ i } " }
357-
358- async with sse_client (
359- f"{ TEST_SERVER_BASE_URL } /sse" , headers = headers , httpx_client_factory = asgi_client_factory
360- ) as (
361- read_stream ,
362- write_stream ,
363- ):
364- async with ClientSession (read_stream , write_stream ) as session :
365- await session .initialize ()
366-
367- # Call the tool that echoes context
368- tool_result = await session .call_tool ("echo_context" , {"request_id" : f"request-{ i } " })
369-
370- assert len (tool_result .content ) == 1
371- context_data = json .loads (
372- tool_result .content [0 ].text if tool_result .content [0 ].type == "text" else "{}"
373- )
374- contexts .append (context_data )
348+ async with anyio .create_task_group () as tg :
349+ asgi_client_factory = create_asgi_client_factory (context_server_app , tg )
350+
351+ # Create multiple clients with different headers
352+ for i in range (3 ):
353+ headers = {"X-Request-Id" : f"request-{ i } " , "X-Custom-Value" : f"value-{ i } " }
354+
355+ async with sse_client (
356+ f"{ TEST_SERVER_BASE_URL } /sse" , headers = headers , httpx_client_factory = asgi_client_factory
357+ ) as (
358+ read_stream ,
359+ write_stream ,
360+ ):
361+ async with ClientSession (read_stream , write_stream ) as session :
362+ await session .initialize ()
363+
364+ # Call the tool that echoes context
365+ tool_result = await session .call_tool ("echo_context" , {"request_id" : f"request-{ i } " })
366+
367+ assert len (tool_result .content ) == 1
368+ context_data = json .loads (
369+ tool_result .content [0 ].text if tool_result .content [0 ].type == "text" else "{}"
370+ )
371+ contexts .append (context_data )
375372
376373 # Verify each request had its own context
377374 assert len (contexts ) == 3
0 commit comments