@@ -893,6 +893,19 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session:
893893 assert result .content [0 ].text == "Called test_tool"
894894
895895
896+ @pytest .mark .anyio
897+ async def test_streamablehttp_client_tool_invocation_with_extra_headers (initialized_client_session : ClientSession ):
898+ """Test HTTP POST request with extra headers."""
899+ result = await initialized_client_session .call_tool (
900+ "test_tool" ,
901+ {},
902+ extra_headers = {"X-Custom-Header" : "test-value" },
903+ )
904+ assert len (result .content ) == 1
905+ assert result .content [0 ].type == "text"
906+ assert result .content [0 ].text == "Called test_tool"
907+
908+
896909@pytest .mark .anyio
897910async def test_streamablehttp_client_error_handling (initialized_client_session : ClientSession ):
898911 """Test error handling in client."""
@@ -1106,12 +1119,14 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
11061119 await session .list_tools ()
11071120
11081121
1109- @ pytest . mark . anyio
1110- async def test_streamablehttp_client_resumption ( event_server : tuple [ SimpleEventStore , str ]):
1111- """Test client session resumption using sync primitives for reliable coordination."""
1112- _ , server_url = event_server
1122+ async def _setup_resumption_test (
1123+ server_url : str ,
1124+ ) -> tuple [ str | None , str | None , str | int | None , list [ types . ServerNotification ]]:
1125+ """Helper function to set up a resumption test by starting a session and capturing resumption state.
11131126
1114- # Variables to track the state
1127+ Returns:
1128+ Tuple of (session_id, resumption_token, protocol_version, notifications)
1129+ """
11151130 captured_session_id = None
11161131 captured_resumption_token = None
11171132 captured_notifications : list [types .ServerNotification ] = []
@@ -1123,7 +1138,6 @@ async def message_handler( # pragma: no branch
11231138 ) -> None :
11241139 if isinstance (message , types .ServerNotification ): # pragma: no branch
11251140 captured_notifications .append (message )
1126- # Look for our first notification
11271141 if isinstance (message .root , types .LoggingMessageNotification ): # pragma: no branch
11281142 if message .root .params .data == "First notification before lock" :
11291143 nonlocal first_notification_received
@@ -1181,8 +1195,90 @@ async def run_tool():
11811195 assert isinstance (captured_notifications [0 ].root , types .LoggingMessageNotification ) # pragma: no cover
11821196 assert captured_notifications [0 ].root .params .data == "First notification before lock" # pragma: no cover
11831197
1184- # Clear notifications for the second phase
1185- captured_notifications = [] # pragma: no cover
1198+ return captured_session_id , captured_resumption_token , captured_protocol_version , captured_notifications
1199+
1200+
1201+ @pytest .mark .anyio
1202+ async def test_streamablehttp_client_resumption (event_server : tuple [SimpleEventStore , str ]):
1203+ """Test client session resumption using sync primitives for reliable coordination."""
1204+ _ , server_url = event_server
1205+
1206+ # Set up the initial session and capture resumption state
1207+ captured_session_id , captured_resumption_token , captured_protocol_version , _ = await _setup_resumption_test (
1208+ server_url
1209+ )
1210+
1211+ # Track notifications for the resumed session
1212+ captured_notifications : list [types .ServerNotification ] = []
1213+
1214+ async def message_handler ( # pragma: no branch
1215+ message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
1216+ ) -> None :
1217+ if isinstance (message , types .ServerNotification ): # pragma: no branch
1218+ captured_notifications .append (message )
1219+
1220+ # Now resume the session with the same mcp-session-id and protocol version
1221+ headers : dict [str , Any ] = {} # pragma: no cover
1222+ if captured_session_id : # pragma: no cover
1223+ headers [MCP_SESSION_ID_HEADER ] = captured_session_id
1224+ if captured_protocol_version : # pragma: no cover
1225+ headers [MCP_PROTOCOL_VERSION_HEADER ] = captured_protocol_version
1226+ async with streamablehttp_client (f"{ server_url } /mcp" , headers = headers ) as (
1227+ read_stream ,
1228+ write_stream ,
1229+ _ ,
1230+ ):
1231+ async with ClientSession (read_stream , write_stream , message_handler = message_handler ) as session :
1232+ result = await session .send_request (
1233+ types .ClientRequest (
1234+ types .CallToolRequest (
1235+ params = types .CallToolRequestParams (name = "release_lock" , arguments = {}),
1236+ )
1237+ ),
1238+ types .CallToolResult ,
1239+ )
1240+ metadata = ClientMessageMetadata (
1241+ resumption_token = captured_resumption_token ,
1242+ )
1243+
1244+ result = await session .send_request (
1245+ types .ClientRequest (
1246+ types .CallToolRequest (
1247+ params = types .CallToolRequestParams (name = "wait_for_lock_with_notification" , arguments = {}),
1248+ )
1249+ ),
1250+ types .CallToolResult ,
1251+ metadata = metadata ,
1252+ )
1253+ assert len (result .content ) == 1
1254+ assert result .content [0 ].type == "text"
1255+ assert result .content [0 ].text == "Completed"
1256+
1257+ # We should have received the remaining notifications
1258+ assert len (captured_notifications ) == 1
1259+
1260+ assert isinstance (captured_notifications [0 ].root , types .LoggingMessageNotification )
1261+ assert captured_notifications [0 ].root .params .data == "Second notification after lock"
1262+
1263+
1264+ @pytest .mark .anyio
1265+ async def test_streamablehttp_client_resumption_with_extra_headers (event_server : tuple [SimpleEventStore , str ]):
1266+ """Test client session resumption with extra headers."""
1267+ _ , server_url = event_server
1268+
1269+ # Set up the initial session and capture resumption state
1270+ captured_session_id , captured_resumption_token , captured_protocol_version , _ = await _setup_resumption_test (
1271+ server_url
1272+ )
1273+
1274+ # Track notifications for the resumed session
1275+ captured_notifications : list [types .ServerNotification ] = []
1276+
1277+ async def message_handler ( # pragma: no branch
1278+ message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
1279+ ) -> None :
1280+ if isinstance (message , types .ServerNotification ): # pragma: no branch
1281+ captured_notifications .append (message )
11861282
11871283 # Now resume the session with the same mcp-session-id and protocol version
11881284 headers : dict [str , Any ] = {} # pragma: no cover
@@ -1204,8 +1300,10 @@ async def run_tool():
12041300 ),
12051301 types .CallToolResult ,
12061302 )
1303+ # Test resumption WITH extra_headers
12071304 metadata = ClientMessageMetadata (
12081305 resumption_token = captured_resumption_token ,
1306+ extra_headers = {"X-Resumption-Test" : "test-value" },
12091307 )
12101308
12111309 result = await session .send_request (
0 commit comments