Skip to content

Commit 7068536

Browse files
committed
test: add coverage for extra_headers in HTTP transport
- Add test_streamablehttp_client_tool_invocation_with_extra_headers for POST requests - Add test_streamablehttp_client_resumption_with_extra_headers for resumption with extra headers - Refactor common resumption setup code into _setup_resumption_test helper - Achieve 100% coverage for streamable_http.py
1 parent 876e13a commit 7068536

File tree

1 file changed

+106
-8
lines changed

1 file changed

+106
-8
lines changed

tests/shared/test_streamable_http.py

Lines changed: 106 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,19 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session:
893893
assert result.content[0].text == "Called test_tool"
894894

895895

896+
@pytest.mark.anyio
897+
async def test_streamablehttp_client_tool_invocation_with_extra_headers(initialized_client_session: ClientSession):
898+
"""Test HTTP POST request with extra headers."""
899+
result = await initialized_client_session.call_tool(
900+
"test_tool",
901+
{},
902+
extra_headers={"X-Custom-Header": "test-value"},
903+
)
904+
assert len(result.content) == 1
905+
assert result.content[0].type == "text"
906+
assert result.content[0].text == "Called test_tool"
907+
908+
896909
@pytest.mark.anyio
897910
async def test_streamablehttp_client_error_handling(initialized_client_session: ClientSession):
898911
"""Test error handling in client."""
@@ -1106,12 +1119,14 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
11061119
await session.list_tools()
11071120

11081121

1109-
@pytest.mark.anyio
1110-
async def test_streamablehttp_client_resumption(event_server: tuple[SimpleEventStore, str]):
1111-
"""Test client session resumption using sync primitives for reliable coordination."""
1112-
_, server_url = event_server
1122+
async def _setup_resumption_test(
1123+
server_url: str,
1124+
) -> tuple[str | None, str | None, str | int | None, list[types.ServerNotification]]:
1125+
"""Helper function to set up a resumption test by starting a session and capturing resumption state.
11131126
1114-
# Variables to track the state
1127+
Returns:
1128+
Tuple of (session_id, resumption_token, protocol_version, notifications)
1129+
"""
11151130
captured_session_id = None
11161131
captured_resumption_token = None
11171132
captured_notifications: list[types.ServerNotification] = []
@@ -1123,7 +1138,6 @@ async def message_handler( # pragma: no branch
11231138
) -> None:
11241139
if isinstance(message, types.ServerNotification): # pragma: no branch
11251140
captured_notifications.append(message)
1126-
# Look for our first notification
11271141
if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch
11281142
if message.root.params.data == "First notification before lock":
11291143
nonlocal first_notification_received
@@ -1181,8 +1195,90 @@ async def run_tool():
11811195
assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover
11821196
assert captured_notifications[0].root.params.data == "First notification before lock" # pragma: no cover
11831197

1184-
# Clear notifications for the second phase
1185-
captured_notifications = [] # pragma: no cover
1198+
return captured_session_id, captured_resumption_token, captured_protocol_version, captured_notifications
1199+
1200+
1201+
@pytest.mark.anyio
1202+
async def test_streamablehttp_client_resumption(event_server: tuple[SimpleEventStore, str]):
1203+
"""Test client session resumption using sync primitives for reliable coordination."""
1204+
_, server_url = event_server
1205+
1206+
# Set up the initial session and capture resumption state
1207+
captured_session_id, captured_resumption_token, captured_protocol_version, _ = await _setup_resumption_test(
1208+
server_url
1209+
)
1210+
1211+
# Track notifications for the resumed session
1212+
captured_notifications: list[types.ServerNotification] = []
1213+
1214+
async def message_handler( # pragma: no branch
1215+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
1216+
) -> None:
1217+
if isinstance(message, types.ServerNotification): # pragma: no branch
1218+
captured_notifications.append(message)
1219+
1220+
# Now resume the session with the same mcp-session-id and protocol version
1221+
headers: dict[str, Any] = {} # pragma: no cover
1222+
if captured_session_id: # pragma: no cover
1223+
headers[MCP_SESSION_ID_HEADER] = captured_session_id
1224+
if captured_protocol_version: # pragma: no cover
1225+
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version
1226+
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
1227+
read_stream,
1228+
write_stream,
1229+
_,
1230+
):
1231+
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
1232+
result = await session.send_request(
1233+
types.ClientRequest(
1234+
types.CallToolRequest(
1235+
params=types.CallToolRequestParams(name="release_lock", arguments={}),
1236+
)
1237+
),
1238+
types.CallToolResult,
1239+
)
1240+
metadata = ClientMessageMetadata(
1241+
resumption_token=captured_resumption_token,
1242+
)
1243+
1244+
result = await session.send_request(
1245+
types.ClientRequest(
1246+
types.CallToolRequest(
1247+
params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}),
1248+
)
1249+
),
1250+
types.CallToolResult,
1251+
metadata=metadata,
1252+
)
1253+
assert len(result.content) == 1
1254+
assert result.content[0].type == "text"
1255+
assert result.content[0].text == "Completed"
1256+
1257+
# We should have received the remaining notifications
1258+
assert len(captured_notifications) == 1
1259+
1260+
assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification)
1261+
assert captured_notifications[0].root.params.data == "Second notification after lock"
1262+
1263+
1264+
@pytest.mark.anyio
1265+
async def test_streamablehttp_client_resumption_with_extra_headers(event_server: tuple[SimpleEventStore, str]):
1266+
"""Test client session resumption with extra headers."""
1267+
_, server_url = event_server
1268+
1269+
# Set up the initial session and capture resumption state
1270+
captured_session_id, captured_resumption_token, captured_protocol_version, _ = await _setup_resumption_test(
1271+
server_url
1272+
)
1273+
1274+
# Track notifications for the resumed session
1275+
captured_notifications: list[types.ServerNotification] = []
1276+
1277+
async def message_handler( # pragma: no branch
1278+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
1279+
) -> None:
1280+
if isinstance(message, types.ServerNotification): # pragma: no branch
1281+
captured_notifications.append(message)
11861282

11871283
# Now resume the session with the same mcp-session-id and protocol version
11881284
headers: dict[str, Any] = {} # pragma: no cover
@@ -1204,8 +1300,10 @@ async def run_tool():
12041300
),
12051301
types.CallToolResult,
12061302
)
1303+
# Test resumption WITH extra_headers
12071304
metadata = ClientMessageMetadata(
12081305
resumption_token=captured_resumption_token,
1306+
extra_headers={"X-Resumption-Test": "test-value"},
12091307
)
12101308

12111309
result = await session.send_request(

0 commit comments

Comments
 (0)