Skip to content

Commit 7f09e87

Browse files
committed
Refactor code
1 parent 8714c53 commit 7f09e87

File tree

2 files changed

+124
-81
lines changed

2 files changed

+124
-81
lines changed

src/mcp/client/streamable_http.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,6 @@ def get_session_id(self) -> str | None:
464464
@asynccontextmanager
465465
async def streamable_http_client(
466466
url: str,
467-
extensions: dict[str, str] | None = None,
468467
*,
469468
http_client: httpx.AsyncClient | None = None,
470469
terminate_on_close: bool = True,
@@ -481,10 +480,11 @@ async def streamable_http_client(
481480
482481
Args:
483482
url: The MCP server endpoint URL.
484-
extensions: Optional extensions to include in requests.
485483
http_client: Optional pre-configured httpx.AsyncClient. If None, a default
486484
client with recommended MCP timeouts will be created. To configure headers,
487485
authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here.
486+
To include custom extensions in requests, set a `custom_extensions` attribute on the
487+
client: `client.custom_extensions = {"key": "value"}`.
488488
terminate_on_close: If True, send a DELETE request to terminate the session
489489
when the context exits.
490490
@@ -515,9 +515,12 @@ async def streamable_http_client(
515515
client.timeout.read if (client.timeout and client.timeout.read is not None) else MCP_DEFAULT_SSE_READ_TIMEOUT
516516
)
517517
auth = client.auth
518+
519+
# Extract custom extensions from the client if available
520+
custom_extensions = getattr(client, "custom_extensions", None)
518521

519522
# Create transport with extracted configuration
520-
transport = StreamableHTTPTransport(url, headers_dict, extensions, timeout, sse_read_timeout, auth)
523+
transport = StreamableHTTPTransport(url, headers_dict, custom_extensions, timeout, sse_read_timeout, auth)
521524

522525
async with anyio.create_task_group() as tg:
523526
try:

tests/shared/test_streamable_http.py

Lines changed: 118 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,36 +1799,50 @@ async def test_extensions_passed_to_streamablehttp_client(self, basic_server: No
17991799
"custom_metadata": "custom_data",
18001800
}
18011801

1802-
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=test_extensions) as (
1803-
read_stream,
1804-
write_stream,
1805-
_,
1806-
):
1807-
async with ClientSession(read_stream, write_stream) as session:
1808-
# Test initialization with extensions
1809-
result = await session.initialize()
1810-
assert isinstance(result, InitializeResult)
1811-
assert result.serverInfo.name == SERVER_NAME
1802+
# Create httpx client with extensions
1803+
custom_client = create_mcp_http_client()
1804+
setattr(custom_client, "custom_extensions", test_extensions)
1805+
1806+
async with custom_client:
1807+
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client) as (
1808+
read_stream,
1809+
write_stream,
1810+
_,
1811+
):
1812+
async with ClientSession(read_stream, write_stream) as session:
1813+
# Test initialization with extensions
1814+
result = await session.initialize()
1815+
assert isinstance(result, InitializeResult)
1816+
assert result.serverInfo.name == SERVER_NAME
18121817

1813-
# Test that session works with extensions
1814-
tools = await session.list_tools()
1815-
assert len(tools.tools) == 6
1818+
# Test that session works with extensions
1819+
tools = await session.list_tools()
1820+
assert len(tools.tools) == 6
18161821

18171822
@pytest.mark.anyio
18181823
async def test_extensions_with_empty_dict(self, basic_server: None, basic_server_url: str):
18191824
"""Test streamablehttp_client with empty extensions dict."""
1820-
async with streamable_http_client(f"{basic_server_url}/mcp", extensions={}) as (read_stream, write_stream, _):
1821-
async with ClientSession(read_stream, write_stream) as session:
1822-
result = await session.initialize()
1823-
assert isinstance(result, InitializeResult)
1825+
# Create httpx client with empty extensions
1826+
custom_client = create_mcp_http_client()
1827+
setattr(custom_client, "custom_extensions", {})
1828+
1829+
async with custom_client:
1830+
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client) as (read_stream, write_stream, _):
1831+
async with ClientSession(read_stream, write_stream) as session:
1832+
result = await session.initialize()
1833+
assert isinstance(result, InitializeResult)
18241834

18251835
@pytest.mark.anyio
18261836
async def test_extensions_with_none(self, basic_server: None, basic_server_url: str):
1827-
"""Test streamablehttp_client with None extensions."""
1828-
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=None) as (read_stream, write_stream, _):
1829-
async with ClientSession(read_stream, write_stream) as session:
1830-
result = await session.initialize()
1831-
assert isinstance(result, InitializeResult)
1837+
"""Test streamablehttp_client with None extensions (no custom_extensions attribute)."""
1838+
# Create httpx client without setting custom_extensions
1839+
custom_client = create_mcp_http_client()
1840+
1841+
async with custom_client:
1842+
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client) as (read_stream, write_stream, _):
1843+
async with ClientSession(read_stream, write_stream) as session:
1844+
result = await session.initialize()
1845+
assert isinstance(result, InitializeResult)
18321846

18331847
def test_extensions_request_context_creation(self):
18341848
"""Test that RequestContext includes extensions correctly."""
@@ -1887,23 +1901,33 @@ async def test_extensions_isolation_between_clients(self, basic_server: None, ba
18871901
# Create two clients with different extensions
18881902
results: list[tuple[str, str]] = []
18891903

1890-
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=extensions_1) as (
1891-
read_stream1,
1892-
write_stream1,
1893-
_,
1894-
):
1895-
async with ClientSession(read_stream1, write_stream1) as session1:
1896-
result1 = await session1.initialize()
1897-
results.append(("client1", result1.serverInfo.name))
1898-
1899-
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=extensions_2) as (
1900-
read_stream2,
1901-
write_stream2,
1902-
_,
1903-
):
1904-
async with ClientSession(read_stream2, write_stream2) as session2:
1905-
result2 = await session2.initialize()
1906-
results.append(("client2", result2.serverInfo.name))
1904+
# First client with extensions_1
1905+
custom_client1 = create_mcp_http_client()
1906+
setattr(custom_client1, "custom_extensions", extensions_1)
1907+
1908+
async with custom_client1:
1909+
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client1) as (
1910+
read_stream1,
1911+
write_stream1,
1912+
_,
1913+
):
1914+
async with ClientSession(read_stream1, write_stream1) as session1:
1915+
result1 = await session1.initialize()
1916+
results.append(("client1", result1.serverInfo.name))
1917+
1918+
# Second client with extensions_2
1919+
custom_client2 = create_mcp_http_client()
1920+
setattr(custom_client2, "custom_extensions", extensions_2)
1921+
1922+
async with custom_client2:
1923+
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client2) as (
1924+
read_stream2,
1925+
write_stream2,
1926+
_,
1927+
):
1928+
async with ClientSession(read_stream2, write_stream2) as session2:
1929+
result2 = await session2.initialize()
1930+
results.append(("client2", result2.serverInfo.name))
19071931

19081932
# Both clients should work independently
19091933
assert len(results) == 2
@@ -1952,9 +1976,10 @@ async def stream(self, *args: Any, **kwargs: Any):
19521976

19531977
# Create the custom client that will capture extensions
19541978
custom_client = ExtensionCapturingClient()
1979+
setattr(custom_client, "custom_extensions", test_extensions)
19551980

19561981
async with streamable_http_client(
1957-
f"{basic_server_url}/mcp/", extensions=test_extensions, http_client=custom_client
1982+
f"{basic_server_url}/mcp/", http_client=custom_client
19581983
) as (read_stream, write_stream, _):
19591984
async with ClientSession(read_stream, write_stream) as session:
19601985
# Initialize - this should make a POST request with extensions
@@ -1981,42 +2006,52 @@ async def test_extensions_with_json_and_sse_responses(self, basic_server: None,
19812006
"""Test that extensions work with both JSON and SSE response types."""
19822007
test_extensions = {"response_test": "json_sse_test", "format": "both"}
19832008

2009+
# Create httpx client with extensions
2010+
custom_client = create_mcp_http_client()
2011+
setattr(custom_client, "custom_extensions", test_extensions)
2012+
19842013
# Test with regular SSE response (default behavior)
1985-
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=test_extensions) as (
1986-
read_stream,
1987-
write_stream,
1988-
_,
1989-
):
1990-
async with ClientSession(read_stream, write_stream) as session:
1991-
result = await session.initialize()
1992-
assert isinstance(result, InitializeResult)
2014+
async with custom_client:
2015+
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client) as (
2016+
read_stream,
2017+
write_stream,
2018+
_,
2019+
):
2020+
async with ClientSession(read_stream, write_stream) as session:
2021+
result = await session.initialize()
2022+
assert isinstance(result, InitializeResult)
19932023

1994-
# Call tool which should work with SSE
1995-
tool_result = await session.call_tool("test_tool", {})
1996-
assert len(tool_result.content) == 1
1997-
content = tool_result.content[0]
1998-
assert content.type == "text"
1999-
from mcp.types import TextContent
2024+
# Call tool which should work with SSE
2025+
tool_result = await session.call_tool("test_tool", {})
2026+
assert len(tool_result.content) == 1
2027+
content = tool_result.content[0]
2028+
assert content.type == "text"
2029+
from mcp.types import TextContent
20002030

2001-
assert isinstance(content, TextContent)
2002-
assert content.text == "Called test_tool"
2031+
assert isinstance(content, TextContent)
2032+
assert content.text == "Called test_tool"
20032033

20042034
@pytest.mark.anyio
20052035
async def test_extensions_with_json_response_server(self, json_response_server: None, json_server_url: str):
20062036
"""Test extensions work with JSON response mode."""
20072037
test_extensions = {"response_mode": "json_only", "test_id": "json_test_123"}
20082038

2009-
async with streamable_http_client(f"{json_server_url}/mcp", extensions=test_extensions) as (
2010-
read_stream,
2011-
write_stream,
2012-
_,
2013-
):
2014-
async with ClientSession(read_stream, write_stream) as session:
2015-
result = await session.initialize()
2016-
assert isinstance(result, InitializeResult)
2039+
# Create httpx client with extensions
2040+
custom_client = create_mcp_http_client()
2041+
setattr(custom_client, "custom_extensions", test_extensions)
2042+
2043+
async with custom_client:
2044+
async with streamable_http_client(f"{json_server_url}/mcp", http_client=custom_client) as (
2045+
read_stream,
2046+
write_stream,
2047+
_,
2048+
):
2049+
async with ClientSession(read_stream, write_stream) as session:
2050+
result = await session.initialize()
2051+
assert isinstance(result, InitializeResult)
20172052

2018-
tools = await session.list_tools()
2019-
assert len(tools.tools) == 6
2053+
tools = await session.list_tools()
2054+
assert len(tools.tools) == 6
20202055

20212056
def test_extensions_type_validation(self):
20222057
"""Test that extensions parameter accepts proper types."""
@@ -2045,16 +2080,21 @@ async def test_extensions_with_special_characters(self, basic_server: None, basi
20452080
"url_like": "https://example.com/path?param=value",
20462081
}
20472082

2048-
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=test_extensions) as (
2049-
read_stream,
2050-
write_stream,
2051-
_,
2052-
):
2053-
async with ClientSession(read_stream, write_stream) as session:
2054-
# Should not throw any errors with special characters
2055-
result = await session.initialize()
2056-
assert isinstance(result, InitializeResult)
2083+
# Create httpx client with extensions
2084+
custom_client = create_mcp_http_client()
2085+
setattr(custom_client, "custom_extensions", test_extensions)
2086+
2087+
async with custom_client:
2088+
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client) as (
2089+
read_stream,
2090+
write_stream,
2091+
_,
2092+
):
2093+
async with ClientSession(read_stream, write_stream) as session:
2094+
# Should not throw any errors with special characters
2095+
result = await session.initialize()
2096+
assert isinstance(result, InitializeResult)
20572097

2058-
# Should work normally with tools
2059-
tools = await session.list_tools()
2060-
assert len(tools.tools) == 6
2098+
# Should work normally with tools
2099+
tools = await session.list_tools()
2100+
assert len(tools.tools) == 6

0 commit comments

Comments
 (0)