Skip to content

Commit c9f31d5

Browse files
committed
Add support to custom extensions
1 parent 61399b3 commit c9f31d5

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

src/mcp/client/streamable_http.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class RequestContext:
6464

6565
client: httpx.AsyncClient
6666
headers: dict[str, str]
67+
extensions: dict[str, str] | None
6768
session_id: str | None
6869
session_message: SessionMessage
6970
metadata: ClientMessageMetadata | None
@@ -78,6 +79,7 @@ def __init__(
7879
self,
7980
url: str,
8081
headers: dict[str, str] | None = None,
82+
extensions: dict[str, str] | None = None,
8183
timeout: float | timedelta = 30,
8284
sse_read_timeout: float | timedelta = 60 * 5,
8385
auth: httpx.Auth | None = None,
@@ -87,12 +89,14 @@ def __init__(
8789
Args:
8890
url: The endpoint URL.
8991
headers: Optional headers to include in requests.
92+
extensions: Optional extensions to include in requests.
9093
timeout: HTTP timeout for regular operations.
9194
sse_read_timeout: Timeout for SSE read operations.
9295
auth: Optional HTTPX authentication handler.
9396
"""
9497
self.url = url
9598
self.headers = headers or {}
99+
self.extensions = extensions or {}
96100
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
97101
self.sse_read_timeout = (
98102
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
@@ -115,6 +119,12 @@ def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, st
115119
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
116120
return headers
117121

122+
def _prepare_request_extensions(self, base_extensions: dict[str, str] | None) -> dict[str, str]:
123+
"""Update extensions with session-specific data if available."""
124+
extensions = base_extensions.copy() if base_extensions else {}
125+
# Add any session-specific extensions here if needed
126+
return extensions
127+
118128
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
119129
"""Check if the message is an initialization request."""
120130
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
@@ -254,6 +264,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
254264
async def _handle_post_request(self, ctx: RequestContext) -> None:
255265
"""Handle a POST request with response processing."""
256266
headers = self._prepare_request_headers(ctx.headers)
267+
extensions = self._prepare_request_extensions(ctx.extensions)
257268
message = ctx.session_message.message
258269
is_initialization = self._is_initialization_request(message)
259270

@@ -262,6 +273,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
262273
self.url,
263274
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
264275
headers=headers,
276+
extensions=extensions,
265277
) as response:
266278
if response.status_code == 202:
267279
logger.debug("Received 202 Accepted")
@@ -395,6 +407,7 @@ async def post_writer(
395407
ctx = RequestContext(
396408
client=client,
397409
headers=self.request_headers,
410+
extensions=self.extensions,
398411
session_id=self.session_id,
399412
session_message=session_message,
400413
metadata=metadata,
@@ -445,6 +458,7 @@ def get_session_id(self) -> str | None:
445458
async def streamablehttp_client(
446459
url: str,
447460
headers: dict[str, str] | None = None,
461+
extensions: dict[str, str] | None = None,
448462
timeout: float | timedelta = 30,
449463
sse_read_timeout: float | timedelta = 60 * 5,
450464
terminate_on_close: bool = True,
@@ -470,7 +484,14 @@ async def streamablehttp_client(
470484
- write_stream: Stream for sending messages to the server
471485
- get_session_id_callback: Function to retrieve the current session ID
472486
"""
473-
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth)
487+
transport = StreamableHTTPTransport(
488+
url=url,
489+
headers=headers,
490+
extensions=extensions,
491+
timeout=timeout,
492+
sse_read_timeout=sse_read_timeout,
493+
auth=auth,
494+
)
474495

475496
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
476497
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

0 commit comments

Comments
 (0)