11import logging
2+ from collections .abc import Callable
23from contextlib import asynccontextmanager
34from typing import Any
4- from urllib .parse import urljoin , urlparse
5+ from urllib .parse import parse_qs , urljoin , urlparse
56
67import anyio
78import httpx
@@ -29,6 +30,7 @@ async def sse_client(
2930 sse_read_timeout : float = 60 * 5 ,
3031 httpx_client_factory : McpHttpClientFactory = create_mcp_http_client ,
3132 auth : httpx .Auth | None = None ,
33+ on_session_created : Callable [[str ], None ] | None = None ,
3234):
3335 """
3436 Client transport for SSE.
@@ -42,6 +44,7 @@ async def sse_client(
4244 timeout: HTTP timeout for regular operations.
4345 sse_read_timeout: Timeout for SSE read operations.
4446 auth: Optional HTTPX authentication handler.
47+ on_session_created: Optional callback invoked with the session ID when received.
4548 """
4649 read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ]
4750 read_stream_writer : MemoryObjectSendStream [SessionMessage | Exception ]
@@ -89,6 +92,13 @@ async def sse_reader(
8992 logger .error (error_msg ) # pragma: no cover
9093 raise ValueError (error_msg ) # pragma: no cover
9194
95+ if on_session_created :
96+ session_id = parse_qs (endpoint_parsed .query ).get (
97+ "sessionId" , [None ]
98+ )[0 ]
99+ if session_id :
100+ on_session_created (session_id )
101+
92102 task_status .started (endpoint_url )
93103
94104 case "message" :
0 commit comments