Skip to content

Commit dcca6e2

Browse files
add on_session_created callback option
1 parent c92bb2f commit dcca6e2

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

src/mcp/client/sse.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import logging
2+
from collections.abc import Callable
23
from contextlib import asynccontextmanager
34
from typing import Any
4-
from urllib.parse import urljoin, urlparse
5+
from urllib.parse import parse_qs, urljoin, urlparse
56

67
import anyio
78
import httpx
@@ -29,6 +30,7 @@ async def sse_client(
2930
sse_read_timeout: float = 60 * 5,
3031
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
3132
auth: httpx.Auth | None = None,
33+
on_session_created: Callable[[str], None] | None = None,
3234
):
3335
"""
3436
Client transport for SSE.
@@ -42,6 +44,7 @@ async def sse_client(
4244
timeout: HTTP timeout for regular operations.
4345
sse_read_timeout: Timeout for SSE read operations.
4446
auth: Optional HTTPX authentication handler.
47+
on_session_created: Optional callback invoked with the session ID when received.
4548
"""
4649
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
4750
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
@@ -89,6 +92,13 @@ async def sse_reader(
8992
logger.error(error_msg) # pragma: no cover
9093
raise ValueError(error_msg) # pragma: no cover
9194

95+
if on_session_created:
96+
session_id = parse_qs(endpoint_parsed.query).get(
97+
"sessionId", [None]
98+
)[0]
99+
if session_id:
100+
on_session_created(session_id)
101+
92102
task_status.started(endpoint_url)
93103

94104
case "message":

tests/shared/test_sse.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,25 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
184184
assert isinstance(ping_result, EmptyResult)
185185

186186

187+
@pytest.mark.anyio
188+
async def test_sse_client_on_session_created(server: None, server_url: str) -> None:
189+
captured_session_id: str | None = None
190+
191+
def on_session_created(session_id: str) -> None:
192+
nonlocal captured_session_id
193+
captured_session_id = session_id
194+
195+
async with sse_client(
196+
server_url + "/sse", on_session_created=on_session_created
197+
) as streams:
198+
async with ClientSession(*streams) as session:
199+
result = await session.initialize()
200+
assert isinstance(result, InitializeResult)
201+
202+
assert captured_session_id is not None
203+
assert len(captured_session_id) > 0
204+
205+
187206
@pytest.fixture
188207
async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]:
189208
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:

0 commit comments

Comments
 (0)