Skip to content

Commit 35579f3

Browse files
committed
test(client): cover unadvertised capability branches
1 parent 4120cc5 commit 35579f3

2 files changed

Lines changed: 38 additions & 5 deletions

File tree

src/mcp/client/session_group.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Callable
1212
from dataclasses import dataclass
1313
from types import TracebackType
14-
from typing import Any, TypeAlias
14+
from typing import Any, TypeAlias, cast
1515

1616
import anyio
1717
import httpx
@@ -332,7 +332,7 @@ async def _establish_session(
332332
async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
333333
"""Aggregates prompts, resources, and tools from a given session."""
334334

335-
capabilities = session.initialize_result.capabilities if session.initialize_result else None
335+
capabilities = cast(types.InitializeResult, session.initialize_result).capabilities
336336

337337
# Create a reverse index so we can find all prompts, resources, and
338338
# tools belonging to this session. Used for removing components from
@@ -347,7 +347,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session
347347
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
348348

349349
# Query the server for its prompts and aggregate to list.
350-
if capabilities is None or capabilities.prompts is not None:
350+
if capabilities.prompts is not None:
351351
try:
352352
prompts = (await session.list_prompts()).prompts
353353
for prompt in prompts:
@@ -358,7 +358,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session
358358
logging.warning(f"Could not fetch prompts: {err}")
359359

360360
# Query the server for its resources and aggregate to list.
361-
if capabilities is None or capabilities.resources is not None:
361+
if capabilities.resources is not None:
362362
try:
363363
resources = (await session.list_resources()).resources
364364
for resource in resources:
@@ -369,7 +369,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session
369369
logging.warning(f"Could not fetch resources: {err}")
370370

371371
# Query the server for its tools and aggregate to list.
372-
if capabilities is None or capabilities.tools is not None:
372+
if capabilities.tools is not None:
373373
try:
374374
tools = (await session.list_tools()).tools
375375
for tool in tools:

tests/client/test_session_group.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,39 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ
166166
client.session.list_resources = mock.AsyncMock(side_effect=AssertionError("list_resources() was called"))
167167

168168
await group.connect_with_session(client.initialize_result.server_info, client.session)
169+
await group.call_tool("ping")
170+
171+
assert not caplog.records
172+
173+
174+
@pytest.mark.anyio
175+
async def test_client_session_group_skips_unadvertised_tools_and_resources(
176+
caplog: pytest.LogCaptureFixture,
177+
):
178+
from mcp import Client
179+
from mcp.server import Server, ServerRequestContext
180+
181+
async def handle_list_prompts(
182+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
183+
) -> types.ListPromptsResult:
184+
return types.ListPromptsResult(prompts=[types.Prompt(name="hello", description="Hello", arguments=[])])
185+
186+
server = Server(
187+
"prompts-only-server",
188+
on_list_prompts=handle_list_prompts,
189+
)
190+
191+
group = ClientSessionGroup()
192+
193+
with caplog.at_level(logging.WARNING):
194+
async with Client(server) as client:
195+
assert client.initialize_result.capabilities.tools is None
196+
assert client.initialize_result.capabilities.resources is None
197+
198+
client.session.list_tools = mock.AsyncMock(side_effect=AssertionError("list_tools() was called"))
199+
client.session.list_resources = mock.AsyncMock(side_effect=AssertionError("list_resources() was called"))
200+
201+
await group.connect_with_session(client.initialize_result.server_info, client.session)
169202

170203
assert not caplog.records
171204

0 commit comments

Comments
 (0)