|
11 | 11 | import contextlib |
12 | 12 | import logging |
13 | 13 | from collections.abc import Callable |
| 14 | +from dataclasses import dataclass |
14 | 15 | from datetime import timedelta |
15 | 16 | from types import TracebackType |
16 | 17 | from typing import Any, TypeAlias, overload |
|
27 | 28 | from mcp.shared.exceptions import McpError |
28 | 29 | from mcp.shared.session import ProgressFnT |
29 | 30 |
|
| 31 | +from .session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT |
| 32 | + |
30 | 33 |
|
31 | 34 | class SseServerParameters(BaseModel): |
32 | 35 | """Parameters for intializing a sse_client.""" |
@@ -66,6 +69,21 @@ class StreamableHttpParameters(BaseModel): |
66 | 69 | ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters |
67 | 70 |
|
68 | 71 |
|
| 72 | +# Use dataclass instead of pydantic BaseModel |
| 73 | +# because pydantic BaseModel cannot handle Protocol fields. |
| 74 | +@dataclass |
| 75 | +class ClientSessionParameters: |
| 76 | + """Parameters for establishing a client session to an MCP server.""" |
| 77 | + |
| 78 | + read_timeout_seconds: timedelta | None = None |
| 79 | + sampling_callback: SamplingFnT | None = None |
| 80 | + elicitation_callback: ElicitationFnT | None = None |
| 81 | + list_roots_callback: ListRootsFnT | None = None |
| 82 | + logging_callback: LoggingFnT | None = None |
| 83 | + message_handler: MessageHandlerFnT | None = None |
| 84 | + client_info: types.Implementation | None = None |
| 85 | + |
| 86 | + |
69 | 87 | class ClientSessionGroup: |
70 | 88 | """Client for managing connections to multiple MCP servers. |
71 | 89 |
|
@@ -264,13 +282,16 @@ async def connect_with_session( |
264 | 282 | async def connect_to_server( |
265 | 283 | self, |
266 | 284 | server_params: ServerParameters, |
| 285 | + session_params: ClientSessionParameters | None = None, |
267 | 286 | ) -> mcp.ClientSession: |
268 | 287 | """Connects to a single MCP server.""" |
269 | | - server_info, session = await self._establish_session(server_params) |
| 288 | + server_info, session = await self._establish_session(server_params, session_params) |
270 | 289 | return await self.connect_with_session(server_info, session) |
271 | 290 |
|
272 | 291 | async def _establish_session( |
273 | | - self, server_params: ServerParameters |
| 292 | + self, |
| 293 | + server_params: ServerParameters, |
| 294 | + session_params: ClientSessionParameters | None = None, |
274 | 295 | ) -> tuple[types.Implementation, mcp.ClientSession]: |
275 | 296 | """Establish a client session to an MCP server.""" |
276 | 297 |
|
@@ -298,7 +319,23 @@ async def _establish_session( |
298 | 319 | ) |
299 | 320 | read, write, _ = await session_stack.enter_async_context(client) |
300 | 321 |
|
301 | | - session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) |
| 322 | + if session_params is None: |
| 323 | + session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) |
| 324 | + else: |
| 325 | + session = await session_stack.enter_async_context( |
| 326 | + mcp.ClientSession( |
| 327 | + read, |
| 328 | + write, |
| 329 | + read_timeout_seconds=session_params.read_timeout_seconds, |
| 330 | + sampling_callback=session_params.sampling_callback, |
| 331 | + elicitation_callback=session_params.elicitation_callback, |
| 332 | + list_roots_callback=session_params.list_roots_callback, |
| 333 | + logging_callback=session_params.logging_callback, |
| 334 | + message_handler=session_params.message_handler, |
| 335 | + client_info=session_params.client_info, |
| 336 | + ) |
| 337 | + ) |
| 338 | + |
302 | 339 | result = await session.initialize() |
303 | 340 |
|
304 | 341 | # Session successfully initialized. |
|
0 commit comments