Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import base64
import contextvars
import json
import logging
import re
Expand All @@ -22,6 +23,7 @@
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamable_http_client
from mcp.client.websocket import websocket_client
from mcp.shared._httpx_utils import MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
Expand Down Expand Up @@ -60,6 +62,7 @@ class MCPSpecificApproval(TypedDict, total=False):
logger = logging.getLogger(__name__)
_MCP_REMOTE_NAME_KEY = "_mcp_remote_name"
_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name"
_mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("_mcp_call_headers")

# region: Helpers

Expand Down Expand Up @@ -1334,6 +1337,7 @@ def __init__(
client: SupportsChatGetResponse | None = None,
additional_properties: dict[str, Any] | None = None,
http_client: httpx.AsyncClient | None = None,
header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP streamable HTTP tool.
Expand Down Expand Up @@ -1381,6 +1385,11 @@ def __init__(
``streamable_http_client`` API will create and manage a default client.
To configure headers, timeouts, or other HTTP client settings, create
and pass your own ``httpx.AsyncClient`` instance.
header_provider: Optional callable that receives the runtime keyword arguments
(from ``FunctionInvocationContext.kwargs``) and returns a ``dict[str, str]``
of HTTP headers to inject into every outbound request to the MCP server.
Use this to forward per-request context (e.g. authentication tokens set in
agent middleware) without creating a separate ``httpx.AsyncClient``.
kwargs: Additional keyword arguments (accepted for backward compatibility but not used).
"""
super().__init__(
Expand All @@ -1401,20 +1410,65 @@ def __init__(
self.url = url
self.terminate_on_close = terminate_on_close
self._httpx_client: httpx.AsyncClient | None = http_client
self._header_provider = header_provider

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
"""Get an MCP streamable HTTP client.

Returns:
An async context manager for the streamable HTTP client transport.
"""
# Pass the http_client (which may be None) to streamable_http_client
http_client = self._httpx_client
if self._header_provider is not None:
if http_client is None:
http_client = httpx.AsyncClient(
follow_redirects=True,
timeout=httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT),
)
self._httpx_client = http_client

if not hasattr(self, "_inject_headers_hook"):

async def _inject_headers(request: httpx.Request) -> None: # noqa: RUF029
headers = _mcp_call_headers.get({})
for key, value in headers.items():
request.headers[key] = value

self._inject_headers_hook = _inject_headers # type: ignore[attr-defined]
http_client.event_hooks["request"].append(self._inject_headers_hook) # type: ignore[attr-defined]

return streamable_http_client(
url=self.url,
http_client=self._httpx_client,
http_client=http_client,
terminate_on_close=self.terminate_on_close if self.terminate_on_close is not None else True,
)

async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
"""Call a tool, injecting headers from the header_provider if configured.

When a ``header_provider`` was supplied at construction time, the runtime
*kwargs* (originating from ``FunctionInvocationContext.kwargs``) are passed
to the provider. The returned headers are attached to every HTTP request
made during this tool call via a ``contextvars.ContextVar``.

Args:
tool_name: The name of the tool to call.

Keyword Args:
kwargs: Arguments to pass to the tool.

Returns:
A list of Content items representing the tool output.
"""
if self._header_provider is not None:
headers = self._header_provider(kwargs)
token = _mcp_call_headers.set(headers)
try:
return await super().call_tool(tool_name, **kwargs)
finally:
_mcp_call_headers.reset(token)
return await super().call_tool(tool_name, **kwargs)


class MCPWebsocketTool(MCPTool):
"""MCP tool for connecting to WebSocket-based MCP servers.
Expand Down
Loading
Loading