@@ -1597,3 +1597,367 @@ async def bad_client():
15971597 assert isinstance (result , InitializeResult )
15981598 tools = await session .list_tools ()
15991599 assert tools .tools
1600+
1601+
1602+ # Extensions Tests
1603+ class TestStreamableHTTPExtensions :
1604+ """Test class for StreamableHTTP extensions functionality."""
1605+
1606+ def test_extensions_initialization_none (self ):
1607+ """Test that extensions are properly initialized when None."""
1608+ from mcp .client .streamable_http import StreamableHTTPTransport
1609+
1610+ transport = StreamableHTTPTransport ("http://test.example.com" )
1611+ assert transport .extensions == {}
1612+
1613+ def test_extensions_initialization_empty_dict (self ):
1614+ """Test that extensions are properly initialized with empty dict."""
1615+ from mcp .client .streamable_http import StreamableHTTPTransport
1616+
1617+ transport = StreamableHTTPTransport ("http://test.example.com" , extensions = {})
1618+ assert transport .extensions == {}
1619+
1620+ def test_extensions_initialization_with_data (self ):
1621+ """Test that extensions are properly initialized with provided data."""
1622+ from mcp .client .streamable_http import StreamableHTTPTransport
1623+
1624+ extensions = {"custom_extension" : "test_value" , "trace_id" : "123456" }
1625+ transport = StreamableHTTPTransport ("http://test.example.com" , extensions = extensions )
1626+ assert transport .extensions == extensions
1627+ # Ensure it's a copy, not the same reference
1628+ assert transport .extensions is not extensions
1629+
1630+ def test_extensions_preparation_none_base (self ):
1631+ """Test that _prepare_request_extensions works with None base extensions."""
1632+ from mcp .client .streamable_http import StreamableHTTPTransport
1633+
1634+ transport = StreamableHTTPTransport ("http://test.example.com" )
1635+ result = transport ._prepare_request_extensions (None )
1636+ assert result == {}
1637+
1638+ def test_extensions_preparation_empty_base (self ):
1639+ """Test that _prepare_request_extensions works with empty base extensions."""
1640+ from mcp .client .streamable_http import StreamableHTTPTransport
1641+
1642+ transport = StreamableHTTPTransport ("http://test.example.com" )
1643+ result = transport ._prepare_request_extensions ({})
1644+ assert result == {}
1645+
1646+ def test_extensions_preparation_with_base (self ):
1647+ """Test that _prepare_request_extensions works with base extensions."""
1648+ from mcp .client .streamable_http import StreamableHTTPTransport
1649+
1650+ transport = StreamableHTTPTransport ("http://test.example.com" )
1651+ base_extensions = {"request_id" : "req_123" , "custom" : "value" }
1652+ result = transport ._prepare_request_extensions (base_extensions )
1653+ assert result == base_extensions
1654+ # Ensure it's a copy, not the same reference
1655+ assert result is not base_extensions
1656+
1657+ def test_extensions_preparation_preserves_original (self ):
1658+ """Test that _prepare_request_extensions doesn't modify the original."""
1659+ from mcp .client .streamable_http import StreamableHTTPTransport
1660+
1661+ transport = StreamableHTTPTransport ("http://test.example.com" )
1662+ base_extensions = {"request_id" : "req_123" }
1663+ original_extensions = base_extensions .copy ()
1664+
1665+ result = transport ._prepare_request_extensions (base_extensions )
1666+
1667+ # Original should be unchanged
1668+ assert base_extensions == original_extensions
1669+ # Result should be a copy
1670+ assert result == base_extensions
1671+ assert result is not base_extensions
1672+
1673+ @pytest .mark .anyio
1674+ async def test_extensions_passed_to_streamablehttp_client (self , basic_server : None , basic_server_url : str ):
1675+ """Test that extensions are properly passed through streamablehttp_client."""
1676+ test_extensions = {
1677+ "test_extension" : "test_value" ,
1678+ "trace_id" : "ext_trace_123" ,
1679+ "custom_metadata" : "custom_data"
1680+ }
1681+
1682+ async with streamablehttp_client (
1683+ f"{ basic_server_url } /mcp" ,
1684+ extensions = test_extensions
1685+ ) as (read_stream , write_stream , _ ):
1686+ async with ClientSession (read_stream , write_stream ) as session :
1687+ # Test initialization with extensions
1688+ result = await session .initialize ()
1689+ assert isinstance (result , InitializeResult )
1690+ assert result .serverInfo .name == SERVER_NAME
1691+
1692+ # Test that session works with extensions
1693+ tools = await session .list_tools ()
1694+ assert len (tools .tools ) == 6
1695+
1696+ @pytest .mark .anyio
1697+ async def test_extensions_with_empty_dict (self , basic_server : None , basic_server_url : str ):
1698+ """Test streamablehttp_client with empty extensions dict."""
1699+ async with streamablehttp_client (
1700+ f"{ basic_server_url } /mcp" ,
1701+ extensions = {}
1702+ ) as (read_stream , write_stream , _ ):
1703+ async with ClientSession (read_stream , write_stream ) as session :
1704+ result = await session .initialize ()
1705+ assert isinstance (result , InitializeResult )
1706+
1707+ @pytest .mark .anyio
1708+ async def test_extensions_with_none (self , basic_server : None , basic_server_url : str ):
1709+ """Test streamablehttp_client with None extensions."""
1710+ async with streamablehttp_client (
1711+ f"{ basic_server_url } /mcp" ,
1712+ extensions = None
1713+ ) as (read_stream , write_stream , _ ):
1714+ async with ClientSession (read_stream , write_stream ) as session :
1715+ result = await session .initialize ()
1716+ assert isinstance (result , InitializeResult )
1717+
1718+ def test_extensions_request_context_creation (self ):
1719+ """Test that RequestContext includes extensions correctly."""
1720+ from mcp .client .streamable_http import RequestContext , StreamableHTTPTransport
1721+ from mcp .shared .message import SessionMessage
1722+ from mcp .types import JSONRPCMessage , JSONRPCRequest
1723+ import httpx
1724+ import anyio
1725+ import asyncio
1726+
1727+ # Create transport with extensions
1728+ test_extensions = {"custom" : "data" , "trace" : "123" }
1729+ transport = StreamableHTTPTransport (
1730+ "http://test.example.com" ,
1731+ extensions = test_extensions
1732+ )
1733+
1734+ async def run_test ():
1735+ # Create mock objects for the context
1736+ client = httpx .AsyncClient ()
1737+ read_stream_writer , read_stream_reader = anyio .create_memory_object_stream [SessionMessage | Exception ](0 )
1738+
1739+ try :
1740+ message = JSONRPCMessage (JSONRPCRequest (
1741+ jsonrpc = "2.0" ,
1742+ method = "test_method" ,
1743+ id = "test_id"
1744+ ))
1745+ session_message = SessionMessage (message )
1746+
1747+ # Create RequestContext
1748+ ctx = RequestContext (
1749+ client = client ,
1750+ headers = {},
1751+ extensions = transport .extensions ,
1752+ session_id = None ,
1753+ session_message = session_message ,
1754+ metadata = None ,
1755+ read_stream_writer = read_stream_writer ,
1756+ sse_read_timeout = 60
1757+ )
1758+
1759+ assert ctx .extensions == test_extensions
1760+ # RequestContext uses the same reference to extensions, which is acceptable
1761+ assert ctx .extensions is transport .extensions
1762+ finally :
1763+ # Clean up resources
1764+ await read_stream_writer .aclose ()
1765+ await read_stream_reader .aclose ()
1766+ await client .aclose ()
1767+
1768+ # Run the async test
1769+ asyncio .run (run_test ())
1770+
1771+ @pytest .mark .anyio
1772+ async def test_extensions_isolation_between_clients (self , basic_server : None , basic_server_url : str ):
1773+ """Test that extensions are isolated between different client instances."""
1774+ extensions_1 = {"client" : "1" , "session" : "session_1" }
1775+ extensions_2 = {"client" : "2" , "session" : "session_2" }
1776+
1777+ # Create two clients with different extensions
1778+ results : list [tuple [str , str ]] = []
1779+
1780+ async with streamablehttp_client (
1781+ f"{ basic_server_url } /mcp" ,
1782+ extensions = extensions_1
1783+ ) as (read_stream1 , write_stream1 , _ ):
1784+ async with ClientSession (read_stream1 , write_stream1 ) as session1 :
1785+ result1 = await session1 .initialize ()
1786+ results .append (("client1" , result1 .serverInfo .name ))
1787+
1788+ async with streamablehttp_client (
1789+ f"{ basic_server_url } /mcp" ,
1790+ extensions = extensions_2
1791+ ) as (read_stream2 , write_stream2 , _ ):
1792+ async with ClientSession (read_stream2 , write_stream2 ) as session2 :
1793+ result2 = await session2 .initialize ()
1794+ results .append (("client2" , result2 .serverInfo .name ))
1795+
1796+ # Both clients should work independently
1797+ assert len (results ) == 2
1798+ assert all (name == SERVER_NAME for _ , name in results )
1799+
1800+ def test_extensions_immutability (self ):
1801+ """Test that modifying extensions after transport creation doesn't affect the transport."""
1802+ from mcp .client .streamable_http import StreamableHTTPTransport
1803+
1804+ original_extensions = {"mutable" : "original" }
1805+ transport = StreamableHTTPTransport (
1806+ "http://test.example.com" ,
1807+ extensions = original_extensions
1808+ )
1809+
1810+ # Modify the original extensions dict
1811+ original_extensions ["mutable" ] = "modified"
1812+ original_extensions ["new_key" ] = "new_value"
1813+
1814+ # Transport should still have the original values
1815+ assert transport .extensions == {"mutable" : "original" }
1816+ assert "new_key" not in transport .extensions
1817+
1818+ @pytest .mark .anyio
1819+ async def test_extensions_passed_to_httpx_requests (self , basic_server : None , basic_server_url : str ):
1820+ """Test that extensions are actually passed to httpx client requests."""
1821+ import httpx
1822+ from contextlib import asynccontextmanager
1823+ from typing import Any
1824+
1825+ test_extensions = {
1826+ "test_key" : "test_value" ,
1827+ "trace_id" : "httpx_trace_123"
1828+ }
1829+
1830+ captured_extensions : list [dict [str , str ]] = []
1831+
1832+ # Create a mock httpx client that captures extensions
1833+ class ExtensionCapturingClient (httpx .AsyncClient ):
1834+ def __init__ (self , * args : Any , ** kwargs : Any ):
1835+ super ().__init__ (* args , ** kwargs )
1836+
1837+ @asynccontextmanager
1838+ async def stream (self , * args : Any , ** kwargs : Any ):
1839+ # Capture extensions when stream is called
1840+ if 'extensions' in kwargs :
1841+ captured_extensions .append (kwargs ['extensions' ])
1842+ # Call the real stream method
1843+ async with super ().stream (* args , ** kwargs ) as response :
1844+ yield response
1845+
1846+ # Custom client factory that returns our capturing client
1847+ def custom_client_factory (
1848+ headers : dict [str , str ] | None = None ,
1849+ timeout : httpx .Timeout | None = None ,
1850+ auth : httpx .Auth | None = None
1851+ ) -> httpx .AsyncClient :
1852+ return ExtensionCapturingClient (
1853+ headers = headers ,
1854+ timeout = timeout ,
1855+ auth = auth ,
1856+ )
1857+
1858+ async with streamablehttp_client (
1859+ f"{ basic_server_url } /mcp/" ,
1860+ extensions = test_extensions ,
1861+ httpx_client_factory = custom_client_factory
1862+ ) as (read_stream , write_stream , _ ):
1863+ async with ClientSession (read_stream , write_stream ) as session :
1864+ # Initialize - this should make a POST request with extensions
1865+ await session .initialize ()
1866+
1867+ # Make another request to capture more extensions usage
1868+ await session .list_tools ()
1869+
1870+ # Verify extensions were captured in requests
1871+ assert len (captured_extensions ) > 0
1872+
1873+ # Check that our test extensions were included
1874+ for captured in captured_extensions :
1875+ assert "test_key" in captured
1876+ assert captured ["test_key" ] == "test_value"
1877+ assert "trace_id" in captured
1878+ assert captured ["trace_id" ] == "httpx_trace_123"
1879+
1880+ @pytest .mark .anyio
1881+ async def test_extensions_with_json_and_sse_responses (self , basic_server : None , basic_server_url : str ):
1882+ """Test that extensions work with both JSON and SSE response types."""
1883+ test_extensions = {
1884+ "response_test" : "json_sse_test" ,
1885+ "format" : "both"
1886+ }
1887+
1888+ # Test with regular SSE response (default behavior)
1889+ async with streamablehttp_client (
1890+ f"{ basic_server_url } /mcp" ,
1891+ extensions = test_extensions
1892+ ) as (read_stream , write_stream , _ ):
1893+ async with ClientSession (read_stream , write_stream ) as session :
1894+ result = await session .initialize ()
1895+ assert isinstance (result , InitializeResult )
1896+
1897+ # Call tool which should work with SSE
1898+ tool_result = await session .call_tool ("test_tool" , {})
1899+ assert len (tool_result .content ) == 1
1900+ content = tool_result .content [0 ]
1901+ assert content .type == "text"
1902+ from mcp .types import TextContent
1903+ assert isinstance (content , TextContent )
1904+ assert content .text == "Called test_tool"
1905+
1906+ @pytest .mark .anyio
1907+ async def test_extensions_with_json_response_server (self , json_response_server : None , json_server_url : str ):
1908+ """Test extensions work with JSON response mode."""
1909+ test_extensions = {
1910+ "response_mode" : "json_only" ,
1911+ "test_id" : "json_test_123"
1912+ }
1913+
1914+ async with streamablehttp_client (
1915+ f"{ json_server_url } /mcp" ,
1916+ extensions = test_extensions
1917+ ) as (read_stream , write_stream , _ ):
1918+ async with ClientSession (read_stream , write_stream ) as session :
1919+ result = await session .initialize ()
1920+ assert isinstance (result , InitializeResult )
1921+
1922+ tools = await session .list_tools ()
1923+ assert len (tools .tools ) == 6
1924+
1925+ def test_extensions_type_validation (self ):
1926+ """Test that extensions parameter accepts proper types."""
1927+ from mcp .client .streamable_http import StreamableHTTPTransport
1928+
1929+ # Test with valid dict[str, str]
1930+ valid_extensions = {"key1" : "value1" , "key2" : "value2" }
1931+ transport = StreamableHTTPTransport ("http://test.com" , extensions = valid_extensions )
1932+ assert transport .extensions == valid_extensions
1933+
1934+ # Test with None (should default to empty dict)
1935+ transport_none = StreamableHTTPTransport ("http://test.com" , extensions = None )
1936+ assert transport_none .extensions == {}
1937+
1938+ # Test with empty dict
1939+ transport_empty = StreamableHTTPTransport ("http://test.com" , extensions = {})
1940+ assert transport_empty .extensions == {}
1941+
1942+ @pytest .mark .anyio
1943+ async def test_extensions_with_special_characters (self , basic_server : None , basic_server_url : str ):
1944+ """Test that extensions work with special characters in values."""
1945+ test_extensions = {
1946+ "special_chars" : "test-value_with.special@chars#123!" ,
1947+ "unicode" : "test_测试_🔧" ,
1948+ "json_like" : '{"nested": "value"}' ,
1949+ "url_like" : "https://example.com/path?param=value" ,
1950+ }
1951+
1952+ async with streamablehttp_client (
1953+ f"{ basic_server_url } /mcp" ,
1954+ extensions = test_extensions
1955+ ) as (read_stream , write_stream , _ ):
1956+ async with ClientSession (read_stream , write_stream ) as session :
1957+ # Should not throw any errors with special characters
1958+ result = await session .initialize ()
1959+ assert isinstance (result , InitializeResult )
1960+
1961+ # Should work normally with tools
1962+ tools = await session .list_tools ()
1963+ assert len (tools .tools ) == 6
0 commit comments