Skip to content

Commit 3d56170

Browse files
committed
Add unit tests
1 parent 45632d2 commit 3d56170

File tree

3 files changed

+366
-11
lines changed

3 files changed

+366
-11
lines changed

examples/clients/simple-streamable-private-gateway/README.md

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,4 @@ mcp> quit
8585
## Configuration
8686

8787
- `MCP_SERVER_PORT` - Server port (default: 8000)
88-
- `MCP_SERVER_HOSTNAME` - Server hostname (default: 8000)
89-
90-
## Compatible Servers
91-
92-
This client works with any MCP server that doesn't require authentication, including:
93-
94-
- `examples/servers/simple-tool` - Basic tool server
95-
- `examples/servers/simple-resource` - Resource server
96-
- `examples/servers/simple-prompt` - Prompt server
97-
- Any custom MCP server without auth requirements
88+
- `MCP_SERVER_HOSTNAME` - Server hostname (default: localhost)

src/mcp/client/streamable_http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(
9696
"""
9797
self.url = url
9898
self.headers = headers or {}
99-
self.extensions = extensions or {}
99+
self.extensions = extensions.copy() if extensions else {}
100100
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
101101
self.sse_read_timeout = (
102102
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout

tests/shared/test_streamable_http.py

Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,3 +1597,367 @@ async def bad_client():
15971597
assert isinstance(result, InitializeResult)
15981598
tools = await session.list_tools()
15991599
assert tools.tools
1600+
1601+
1602+
# Extensions Tests
1603+
class TestStreamableHTTPExtensions:
1604+
"""Test class for StreamableHTTP extensions functionality."""
1605+
1606+
def test_extensions_initialization_none(self):
1607+
"""Test that extensions are properly initialized when None."""
1608+
from mcp.client.streamable_http import StreamableHTTPTransport
1609+
1610+
transport = StreamableHTTPTransport("http://test.example.com")
1611+
assert transport.extensions == {}
1612+
1613+
def test_extensions_initialization_empty_dict(self):
1614+
"""Test that extensions are properly initialized with empty dict."""
1615+
from mcp.client.streamable_http import StreamableHTTPTransport
1616+
1617+
transport = StreamableHTTPTransport("http://test.example.com", extensions={})
1618+
assert transport.extensions == {}
1619+
1620+
def test_extensions_initialization_with_data(self):
1621+
"""Test that extensions are properly initialized with provided data."""
1622+
from mcp.client.streamable_http import StreamableHTTPTransport
1623+
1624+
extensions = {"custom_extension": "test_value", "trace_id": "123456"}
1625+
transport = StreamableHTTPTransport("http://test.example.com", extensions=extensions)
1626+
assert transport.extensions == extensions
1627+
# Ensure it's a copy, not the same reference
1628+
assert transport.extensions is not extensions
1629+
1630+
def test_extensions_preparation_none_base(self):
1631+
"""Test that _prepare_request_extensions works with None base extensions."""
1632+
from mcp.client.streamable_http import StreamableHTTPTransport
1633+
1634+
transport = StreamableHTTPTransport("http://test.example.com")
1635+
result = transport._prepare_request_extensions(None)
1636+
assert result == {}
1637+
1638+
def test_extensions_preparation_empty_base(self):
1639+
"""Test that _prepare_request_extensions works with empty base extensions."""
1640+
from mcp.client.streamable_http import StreamableHTTPTransport
1641+
1642+
transport = StreamableHTTPTransport("http://test.example.com")
1643+
result = transport._prepare_request_extensions({})
1644+
assert result == {}
1645+
1646+
def test_extensions_preparation_with_base(self):
1647+
"""Test that _prepare_request_extensions works with base extensions."""
1648+
from mcp.client.streamable_http import StreamableHTTPTransport
1649+
1650+
transport = StreamableHTTPTransport("http://test.example.com")
1651+
base_extensions = {"request_id": "req_123", "custom": "value"}
1652+
result = transport._prepare_request_extensions(base_extensions)
1653+
assert result == base_extensions
1654+
# Ensure it's a copy, not the same reference
1655+
assert result is not base_extensions
1656+
1657+
def test_extensions_preparation_preserves_original(self):
1658+
"""Test that _prepare_request_extensions doesn't modify the original."""
1659+
from mcp.client.streamable_http import StreamableHTTPTransport
1660+
1661+
transport = StreamableHTTPTransport("http://test.example.com")
1662+
base_extensions = {"request_id": "req_123"}
1663+
original_extensions = base_extensions.copy()
1664+
1665+
result = transport._prepare_request_extensions(base_extensions)
1666+
1667+
# Original should be unchanged
1668+
assert base_extensions == original_extensions
1669+
# Result should be a copy
1670+
assert result == base_extensions
1671+
assert result is not base_extensions
1672+
1673+
@pytest.mark.anyio
1674+
async def test_extensions_passed_to_streamablehttp_client(self, basic_server: None, basic_server_url: str):
1675+
"""Test that extensions are properly passed through streamablehttp_client."""
1676+
test_extensions = {
1677+
"test_extension": "test_value",
1678+
"trace_id": "ext_trace_123",
1679+
"custom_metadata": "custom_data"
1680+
}
1681+
1682+
async with streamablehttp_client(
1683+
f"{basic_server_url}/mcp",
1684+
extensions=test_extensions
1685+
) as (read_stream, write_stream, _):
1686+
async with ClientSession(read_stream, write_stream) as session:
1687+
# Test initialization with extensions
1688+
result = await session.initialize()
1689+
assert isinstance(result, InitializeResult)
1690+
assert result.serverInfo.name == SERVER_NAME
1691+
1692+
# Test that session works with extensions
1693+
tools = await session.list_tools()
1694+
assert len(tools.tools) == 6
1695+
1696+
@pytest.mark.anyio
1697+
async def test_extensions_with_empty_dict(self, basic_server: None, basic_server_url: str):
1698+
"""Test streamablehttp_client with empty extensions dict."""
1699+
async with streamablehttp_client(
1700+
f"{basic_server_url}/mcp",
1701+
extensions={}
1702+
) as (read_stream, write_stream, _):
1703+
async with ClientSession(read_stream, write_stream) as session:
1704+
result = await session.initialize()
1705+
assert isinstance(result, InitializeResult)
1706+
1707+
@pytest.mark.anyio
1708+
async def test_extensions_with_none(self, basic_server: None, basic_server_url: str):
1709+
"""Test streamablehttp_client with None extensions."""
1710+
async with streamablehttp_client(
1711+
f"{basic_server_url}/mcp",
1712+
extensions=None
1713+
) as (read_stream, write_stream, _):
1714+
async with ClientSession(read_stream, write_stream) as session:
1715+
result = await session.initialize()
1716+
assert isinstance(result, InitializeResult)
1717+
1718+
def test_extensions_request_context_creation(self):
1719+
"""Test that RequestContext includes extensions correctly."""
1720+
from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport
1721+
from mcp.shared.message import SessionMessage
1722+
from mcp.types import JSONRPCMessage, JSONRPCRequest
1723+
import httpx
1724+
import anyio
1725+
import asyncio
1726+
1727+
# Create transport with extensions
1728+
test_extensions = {"custom": "data", "trace": "123"}
1729+
transport = StreamableHTTPTransport(
1730+
"http://test.example.com",
1731+
extensions=test_extensions
1732+
)
1733+
1734+
async def run_test():
1735+
# Create mock objects for the context
1736+
client = httpx.AsyncClient()
1737+
read_stream_writer, read_stream_reader = anyio.create_memory_object_stream[SessionMessage | Exception](0)
1738+
1739+
try:
1740+
message = JSONRPCMessage(JSONRPCRequest(
1741+
jsonrpc="2.0",
1742+
method="test_method",
1743+
id="test_id"
1744+
))
1745+
session_message = SessionMessage(message)
1746+
1747+
# Create RequestContext
1748+
ctx = RequestContext(
1749+
client=client,
1750+
headers={},
1751+
extensions=transport.extensions,
1752+
session_id=None,
1753+
session_message=session_message,
1754+
metadata=None,
1755+
read_stream_writer=read_stream_writer,
1756+
sse_read_timeout=60
1757+
)
1758+
1759+
assert ctx.extensions == test_extensions
1760+
# RequestContext uses the same reference to extensions, which is acceptable
1761+
assert ctx.extensions is transport.extensions
1762+
finally:
1763+
# Clean up resources
1764+
await read_stream_writer.aclose()
1765+
await read_stream_reader.aclose()
1766+
await client.aclose()
1767+
1768+
# Run the async test
1769+
asyncio.run(run_test())
1770+
1771+
@pytest.mark.anyio
1772+
async def test_extensions_isolation_between_clients(self, basic_server: None, basic_server_url: str):
1773+
"""Test that extensions are isolated between different client instances."""
1774+
extensions_1 = {"client": "1", "session": "session_1"}
1775+
extensions_2 = {"client": "2", "session": "session_2"}
1776+
1777+
# Create two clients with different extensions
1778+
results: list[tuple[str, str]] = []
1779+
1780+
async with streamablehttp_client(
1781+
f"{basic_server_url}/mcp",
1782+
extensions=extensions_1
1783+
) as (read_stream1, write_stream1, _):
1784+
async with ClientSession(read_stream1, write_stream1) as session1:
1785+
result1 = await session1.initialize()
1786+
results.append(("client1", result1.serverInfo.name))
1787+
1788+
async with streamablehttp_client(
1789+
f"{basic_server_url}/mcp",
1790+
extensions=extensions_2
1791+
) as (read_stream2, write_stream2, _):
1792+
async with ClientSession(read_stream2, write_stream2) as session2:
1793+
result2 = await session2.initialize()
1794+
results.append(("client2", result2.serverInfo.name))
1795+
1796+
# Both clients should work independently
1797+
assert len(results) == 2
1798+
assert all(name == SERVER_NAME for _, name in results)
1799+
1800+
def test_extensions_immutability(self):
1801+
"""Test that modifying extensions after transport creation doesn't affect the transport."""
1802+
from mcp.client.streamable_http import StreamableHTTPTransport
1803+
1804+
original_extensions = {"mutable": "original"}
1805+
transport = StreamableHTTPTransport(
1806+
"http://test.example.com",
1807+
extensions=original_extensions
1808+
)
1809+
1810+
# Modify the original extensions dict
1811+
original_extensions["mutable"] = "modified"
1812+
original_extensions["new_key"] = "new_value"
1813+
1814+
# Transport should still have the original values
1815+
assert transport.extensions == {"mutable": "original"}
1816+
assert "new_key" not in transport.extensions
1817+
1818+
@pytest.mark.anyio
1819+
async def test_extensions_passed_to_httpx_requests(self, basic_server: None, basic_server_url: str):
1820+
"""Test that extensions are actually passed to httpx client requests."""
1821+
import httpx
1822+
from contextlib import asynccontextmanager
1823+
from typing import Any
1824+
1825+
test_extensions = {
1826+
"test_key": "test_value",
1827+
"trace_id": "httpx_trace_123"
1828+
}
1829+
1830+
captured_extensions: list[dict[str, str]] = []
1831+
1832+
# Create a mock httpx client that captures extensions
1833+
class ExtensionCapturingClient(httpx.AsyncClient):
1834+
def __init__(self, *args: Any, **kwargs: Any):
1835+
super().__init__(*args, **kwargs)
1836+
1837+
@asynccontextmanager
1838+
async def stream(self, *args: Any, **kwargs: Any):
1839+
# Capture extensions when stream is called
1840+
if 'extensions' in kwargs:
1841+
captured_extensions.append(kwargs['extensions'])
1842+
# Call the real stream method
1843+
async with super().stream(*args, **kwargs) as response:
1844+
yield response
1845+
1846+
# Custom client factory that returns our capturing client
1847+
def custom_client_factory(
1848+
headers: dict[str, str] | None = None,
1849+
timeout: httpx.Timeout | None = None,
1850+
auth: httpx.Auth | None = None
1851+
) -> httpx.AsyncClient:
1852+
return ExtensionCapturingClient(
1853+
headers=headers,
1854+
timeout=timeout,
1855+
auth=auth,
1856+
)
1857+
1858+
async with streamablehttp_client(
1859+
f"{basic_server_url}/mcp/",
1860+
extensions=test_extensions,
1861+
httpx_client_factory=custom_client_factory
1862+
) as (read_stream, write_stream, _):
1863+
async with ClientSession(read_stream, write_stream) as session:
1864+
# Initialize - this should make a POST request with extensions
1865+
await session.initialize()
1866+
1867+
# Make another request to capture more extensions usage
1868+
await session.list_tools()
1869+
1870+
# Verify extensions were captured in requests
1871+
assert len(captured_extensions) > 0
1872+
1873+
# Check that our test extensions were included
1874+
for captured in captured_extensions:
1875+
assert "test_key" in captured
1876+
assert captured["test_key"] == "test_value"
1877+
assert "trace_id" in captured
1878+
assert captured["trace_id"] == "httpx_trace_123"
1879+
1880+
@pytest.mark.anyio
1881+
async def test_extensions_with_json_and_sse_responses(self, basic_server: None, basic_server_url: str):
1882+
"""Test that extensions work with both JSON and SSE response types."""
1883+
test_extensions = {
1884+
"response_test": "json_sse_test",
1885+
"format": "both"
1886+
}
1887+
1888+
# Test with regular SSE response (default behavior)
1889+
async with streamablehttp_client(
1890+
f"{basic_server_url}/mcp",
1891+
extensions=test_extensions
1892+
) as (read_stream, write_stream, _):
1893+
async with ClientSession(read_stream, write_stream) as session:
1894+
result = await session.initialize()
1895+
assert isinstance(result, InitializeResult)
1896+
1897+
# Call tool which should work with SSE
1898+
tool_result = await session.call_tool("test_tool", {})
1899+
assert len(tool_result.content) == 1
1900+
content = tool_result.content[0]
1901+
assert content.type == "text"
1902+
from mcp.types import TextContent
1903+
assert isinstance(content, TextContent)
1904+
assert content.text == "Called test_tool"
1905+
1906+
@pytest.mark.anyio
1907+
async def test_extensions_with_json_response_server(self, json_response_server: None, json_server_url: str):
1908+
"""Test extensions work with JSON response mode."""
1909+
test_extensions = {
1910+
"response_mode": "json_only",
1911+
"test_id": "json_test_123"
1912+
}
1913+
1914+
async with streamablehttp_client(
1915+
f"{json_server_url}/mcp",
1916+
extensions=test_extensions
1917+
) as (read_stream, write_stream, _):
1918+
async with ClientSession(read_stream, write_stream) as session:
1919+
result = await session.initialize()
1920+
assert isinstance(result, InitializeResult)
1921+
1922+
tools = await session.list_tools()
1923+
assert len(tools.tools) == 6
1924+
1925+
def test_extensions_type_validation(self):
1926+
"""Test that extensions parameter accepts proper types."""
1927+
from mcp.client.streamable_http import StreamableHTTPTransport
1928+
1929+
# Test with valid dict[str, str]
1930+
valid_extensions = {"key1": "value1", "key2": "value2"}
1931+
transport = StreamableHTTPTransport("http://test.com", extensions=valid_extensions)
1932+
assert transport.extensions == valid_extensions
1933+
1934+
# Test with None (should default to empty dict)
1935+
transport_none = StreamableHTTPTransport("http://test.com", extensions=None)
1936+
assert transport_none.extensions == {}
1937+
1938+
# Test with empty dict
1939+
transport_empty = StreamableHTTPTransport("http://test.com", extensions={})
1940+
assert transport_empty.extensions == {}
1941+
1942+
@pytest.mark.anyio
1943+
async def test_extensions_with_special_characters(self, basic_server: None, basic_server_url: str):
1944+
"""Test that extensions work with special characters in values."""
1945+
test_extensions = {
1946+
"special_chars": "test-value_with.special@chars#123!",
1947+
"unicode": "test_测试_🔧",
1948+
"json_like": '{"nested": "value"}',
1949+
"url_like": "https://example.com/path?param=value",
1950+
}
1951+
1952+
async with streamablehttp_client(
1953+
f"{basic_server_url}/mcp",
1954+
extensions=test_extensions
1955+
) as (read_stream, write_stream, _):
1956+
async with ClientSession(read_stream, write_stream) as session:
1957+
# Should not throw any errors with special characters
1958+
result = await session.initialize()
1959+
assert isinstance(result, InitializeResult)
1960+
1961+
# Should work normally with tools
1962+
tools = await session.list_tools()
1963+
assert len(tools.tools) == 6

0 commit comments

Comments
 (0)