Skip to content

Commit 4120cc5

Browse files
committed
fix(client): respect negotiated capabilities in ClientSessionGroup
1 parent 3eb5799 commit 4120cc5

2 files changed

Lines changed: 75 additions & 25 deletions

File tree

src/mcp/client/session_group.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,8 @@ async def _establish_session(
332332
async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
333333
"""Aggregates prompts, resources, and tools from a given session."""
334334

335+
capabilities = session.initialize_result.capabilities if session.initialize_result else None
336+
335337
# Create a reverse index so we can find all prompts, resources, and
336338
# tools belonging to this session. Used for removing components from
337339
# the session group via self.disconnect_from_server.
@@ -345,35 +347,38 @@ async def _aggregate_components(self, server_info: types.Implementation, session
345347
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
346348

347349
# Query the server for its prompts and aggregate to list.
348-
try:
349-
prompts = (await session.list_prompts()).prompts
350-
for prompt in prompts:
351-
name = self._component_name(prompt.name, server_info)
352-
prompts_temp[name] = prompt
353-
component_names.prompts.add(name)
354-
except MCPError as err: # pragma: no cover
355-
logging.warning(f"Could not fetch prompts: {err}")
350+
if capabilities is None or capabilities.prompts is not None:
351+
try:
352+
prompts = (await session.list_prompts()).prompts
353+
for prompt in prompts:
354+
name = self._component_name(prompt.name, server_info)
355+
prompts_temp[name] = prompt
356+
component_names.prompts.add(name)
357+
except MCPError as err: # pragma: no cover
358+
logging.warning(f"Could not fetch prompts: {err}")
356359

357360
# Query the server for its resources and aggregate to list.
358-
try:
359-
resources = (await session.list_resources()).resources
360-
for resource in resources:
361-
name = self._component_name(resource.name, server_info)
362-
resources_temp[name] = resource
363-
component_names.resources.add(name)
364-
except MCPError as err: # pragma: no cover
365-
logging.warning(f"Could not fetch resources: {err}")
361+
if capabilities is None or capabilities.resources is not None:
362+
try:
363+
resources = (await session.list_resources()).resources
364+
for resource in resources:
365+
name = self._component_name(resource.name, server_info)
366+
resources_temp[name] = resource
367+
component_names.resources.add(name)
368+
except MCPError as err: # pragma: no cover
369+
logging.warning(f"Could not fetch resources: {err}")
366370

367371
# Query the server for its tools and aggregate to list.
368-
try:
369-
tools = (await session.list_tools()).tools
370-
for tool in tools:
371-
name = self._component_name(tool.name, server_info)
372-
tools_temp[name] = tool
373-
tool_to_session_temp[name] = session
374-
component_names.tools.add(name)
375-
except MCPError as err: # pragma: no cover
376-
logging.warning(f"Could not fetch tools: {err}")
372+
if capabilities is None or capabilities.tools is not None:
373+
try:
374+
tools = (await session.list_tools()).tools
375+
for tool in tools:
376+
name = self._component_name(tool.name, server_info)
377+
tools_temp[name] = tool
378+
tool_to_session_temp[name] = session
379+
component_names.tools.add(name)
380+
except MCPError as err: # pragma: no cover
381+
logging.warning(f"Could not fetch tools: {err}")
377382

378383
# Clean up exit stack for session if we couldn't retrieve anything
379384
# from the server.

tests/client/test_session_group.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import logging
23
from unittest import mock
34

45
import httpx
@@ -125,6 +126,50 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli
125126
mock_session.list_prompts.assert_awaited_once()
126127

127128

129+
@pytest.mark.anyio
130+
async def test_client_session_group_connect_with_session_respects_negotiated_capabilities(
131+
caplog: pytest.LogCaptureFixture,
132+
):
133+
from mcp import Client
134+
from mcp.server import Server, ServerRequestContext
135+
136+
async def handle_list_tools(
137+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
138+
) -> types.ListToolsResult:
139+
return types.ListToolsResult(
140+
tools=[
141+
types.Tool(
142+
name="ping",
143+
description="Ping",
144+
input_schema={"type": "object", "properties": {}},
145+
)
146+
]
147+
)
148+
149+
async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
150+
return types.CallToolResult(content=[types.TextContent(type="text", text="pong")])
151+
152+
server = Server(
153+
"tools-only-server",
154+
on_list_tools=handle_list_tools,
155+
on_call_tool=handle_call_tool,
156+
)
157+
158+
group = ClientSessionGroup()
159+
160+
with caplog.at_level(logging.WARNING):
161+
async with Client(server) as client:
162+
assert client.initialize_result.capabilities.prompts is None
163+
assert client.initialize_result.capabilities.resources is None
164+
165+
client.session.list_prompts = mock.AsyncMock(side_effect=AssertionError("list_prompts() was called"))
166+
client.session.list_resources = mock.AsyncMock(side_effect=AssertionError("list_resources() was called"))
167+
168+
await group.connect_with_session(client.initialize_result.server_info, client.session)
169+
170+
assert not caplog.records
171+
172+
128173
@pytest.mark.anyio
129174
async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack):
130175
"""Test connecting with a component name hook."""

0 commit comments

Comments
 (0)