Skip to content

Commit 2ed88b8

Browse files
committed
feat: add ClientSessionParameters to enhance ClientSessionGroup.connect_to_server method
1 parent 47bc4ca commit 2ed88b8

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

src/mcp/client/session_group.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import contextlib
1212
import logging
1313
from collections.abc import Callable
14+
from dataclasses import dataclass
1415
from datetime import timedelta
1516
from types import TracebackType
1617
from typing import Any, TypeAlias, overload
@@ -27,6 +28,8 @@
2728
from mcp.shared.exceptions import McpError
2829
from mcp.shared.session import ProgressFnT
2930

31+
from .session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
32+
3033

3134
class SseServerParameters(BaseModel):
3235
"""Parameters for intializing a sse_client."""
@@ -66,6 +69,21 @@ class StreamableHttpParameters(BaseModel):
6669
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
6770

6871

72+
# Use dataclass instead of pydantic BaseModel
73+
# because pydantic BaseModel cannot handle Protocol fields.
74+
@dataclass
75+
class ClientSessionParameters:
76+
"""Parameters for establishing a client session to an MCP server."""
77+
78+
read_timeout_seconds: timedelta | None = None
79+
sampling_callback: SamplingFnT | None = None
80+
elicitation_callback: ElicitationFnT | None = None
81+
list_roots_callback: ListRootsFnT | None = None
82+
logging_callback: LoggingFnT | None = None
83+
message_handler: MessageHandlerFnT | None = None
84+
client_info: types.Implementation | None = None
85+
86+
6987
class ClientSessionGroup:
7088
"""Client for managing connections to multiple MCP servers.
7189
@@ -264,13 +282,16 @@ async def connect_with_session(
264282
async def connect_to_server(
265283
self,
266284
server_params: ServerParameters,
285+
session_params: ClientSessionParameters | None = None,
267286
) -> mcp.ClientSession:
268287
"""Connects to a single MCP server."""
269-
server_info, session = await self._establish_session(server_params)
288+
server_info, session = await self._establish_session(server_params, session_params)
270289
return await self.connect_with_session(server_info, session)
271290

272291
async def _establish_session(
273-
self, server_params: ServerParameters
292+
self,
293+
server_params: ServerParameters,
294+
session_params: ClientSessionParameters | None = None,
274295
) -> tuple[types.Implementation, mcp.ClientSession]:
275296
"""Establish a client session to an MCP server."""
276297

@@ -298,7 +319,23 @@ async def _establish_session(
298319
)
299320
read, write, _ = await session_stack.enter_async_context(client)
300321

301-
session = await session_stack.enter_async_context(mcp.ClientSession(read, write))
322+
if session_params is None:
323+
session = await session_stack.enter_async_context(mcp.ClientSession(read, write))
324+
else:
325+
session = await session_stack.enter_async_context(
326+
mcp.ClientSession(
327+
read,
328+
write,
329+
read_timeout_seconds=session_params.read_timeout_seconds,
330+
sampling_callback=session_params.sampling_callback,
331+
elicitation_callback=session_params.elicitation_callback,
332+
list_roots_callback=session_params.list_roots_callback,
333+
logging_callback=session_params.logging_callback,
334+
message_handler=session_params.message_handler,
335+
client_info=session_params.client_info,
336+
)
337+
)
338+
302339
result = await session.initialize()
303340

304341
# Session successfully initialized.

0 commit comments

Comments
 (0)