@@ -64,6 +64,7 @@ class RequestContext:
6464
6565 client : httpx .AsyncClient
6666 headers : dict [str , str ]
67+ extensions : dict [str , str ] | None
6768 session_id : str | None
6869 session_message : SessionMessage
6970 metadata : ClientMessageMetadata | None
@@ -78,6 +79,7 @@ def __init__(
7879 self ,
7980 url : str ,
8081 headers : dict [str , str ] | None = None ,
82+ extensions : dict [str , str ] | None = None ,
8183 timeout : float | timedelta = 30 ,
8284 sse_read_timeout : float | timedelta = 60 * 5 ,
8385 auth : httpx .Auth | None = None ,
@@ -87,12 +89,14 @@ def __init__(
8789 Args:
8890 url: The endpoint URL.
8991 headers: Optional headers to include in requests.
92+ extensions: Optional extensions to include in requests.
9093 timeout: HTTP timeout for regular operations.
9194 sse_read_timeout: Timeout for SSE read operations.
9295 auth: Optional HTTPX authentication handler.
9396 """
9497 self .url = url
9598 self .headers = headers or {}
99+ self .extensions = extensions or {}
96100 self .timeout = timeout .total_seconds () if isinstance (timeout , timedelta ) else timeout
97101 self .sse_read_timeout = (
98102 sse_read_timeout .total_seconds () if isinstance (sse_read_timeout , timedelta ) else sse_read_timeout
@@ -115,6 +119,12 @@ def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, st
115119 headers [MCP_PROTOCOL_VERSION ] = self .protocol_version
116120 return headers
117121
122+ def _prepare_request_extensions (self , base_extensions : dict [str , str ] | None ) -> dict [str , str ]:
123+ """Update extensions with session-specific data if available."""
124+ extensions = base_extensions .copy () if base_extensions else {}
125+ # Add any session-specific extensions here if needed
126+ return extensions
127+
118128 def _is_initialization_request (self , message : JSONRPCMessage ) -> bool :
119129 """Check if the message is an initialization request."""
120130 return isinstance (message .root , JSONRPCRequest ) and message .root .method == "initialize"
@@ -254,6 +264,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
254264 async def _handle_post_request (self , ctx : RequestContext ) -> None :
255265 """Handle a POST request with response processing."""
256266 headers = self ._prepare_request_headers (ctx .headers )
267+ extensions = self ._prepare_request_extensions (ctx .extensions )
257268 message = ctx .session_message .message
258269 is_initialization = self ._is_initialization_request (message )
259270
@@ -262,6 +273,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
262273 self .url ,
263274 json = message .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
264275 headers = headers ,
276+ extensions = extensions ,
265277 ) as response :
266278 if response .status_code == 202 :
267279 logger .debug ("Received 202 Accepted" )
@@ -395,6 +407,7 @@ async def post_writer(
395407 ctx = RequestContext (
396408 client = client ,
397409 headers = self .request_headers ,
410+ extensions = self .extensions ,
398411 session_id = self .session_id ,
399412 session_message = session_message ,
400413 metadata = metadata ,
@@ -445,6 +458,7 @@ def get_session_id(self) -> str | None:
445458async def streamablehttp_client (
446459 url : str ,
447460 headers : dict [str , str ] | None = None ,
461+ extensions : dict [str , str ] | None = None ,
448462 timeout : float | timedelta = 30 ,
449463 sse_read_timeout : float | timedelta = 60 * 5 ,
450464 terminate_on_close : bool = True ,
@@ -470,7 +484,14 @@ async def streamablehttp_client(
470484 - write_stream: Stream for sending messages to the server
471485 - get_session_id_callback: Function to retrieve the current session ID
472486 """
473- transport = StreamableHTTPTransport (url , headers , timeout , sse_read_timeout , auth )
487+ transport = StreamableHTTPTransport (
488+ url = url ,
489+ headers = headers ,
490+ extensions = extensions ,
491+ timeout = timeout ,
492+ sse_read_timeout = sse_read_timeout ,
493+ auth = auth ,
494+ )
474495
475496 read_stream_writer , read_stream = anyio .create_memory_object_stream [SessionMessage | Exception ](0 )
476497 write_stream , write_stream_reader = anyio .create_memory_object_stream [SessionMessage ](0 )
0 commit comments