Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ classifiers = [
requires-python = ">=3.12"
dependencies = [
"httpx>=0.28.1",
"mcp>=1.18.0",
"modelsdotdev==0.*",
"pydantic>=2.12.5",
"typing-extensions>=4.15.0",
]

[project.optional-dependencies]
anthropic = ["anthropic>=0.83.0"]
mcp = ["mcp>=1.18.0"]
openai = ["openai>=2.14.0"]

[build-system]
Expand All @@ -54,6 +54,7 @@ bump = true
[dependency-groups]
dev = [
"anthropic>=0.83.0",
"mcp>=1.18.0",
"python-dotenv>=1.2.1",
"pytest>=8.0",
"pytest-asyncio>=0.24",
Expand Down
63 changes: 45 additions & 18 deletions src/ai/agents/mcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
import contextlib
import contextvars
import dataclasses
import importlib
import json
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
from collections.abc import AsyncIterator, Awaitable, Callable

import mcp.client.session
import mcp.types

from ... import types
from ... import errors, types
from ..agent import AgentTool, Tool

__all__ = [
Expand All @@ -39,6 +40,21 @@ class _Connection:
_pool_lock = asyncio.Lock()


def _import_mcp_module(module_name: str) -> Any:
"""Import an MCP module or raise the public optional dependency error."""
try:
return importlib.import_module(module_name)
except ModuleNotFoundError as exc:
root_module = module_name.partition(".")[0]
if exc.name not in {module_name, root_module}:
raise
raise errors.InstallationError(
"could not import `mcp`, which is required to use MCP tools, "
'you can install it with `pip install "ai[mcp]"` or '
'`uv add "ai[mcp]"`'
) from exc


@contextlib.asynccontextmanager
async def ensure_connection_pool() -> AsyncIterator[dict[str, _Connection]]:
pool = orig_pool = _pool.get()
Expand All @@ -60,7 +76,10 @@ async def _get_or_create_connection(
],
) -> mcp.client.session.ClientSession:
"""Get an existing connection or create a new one."""
import mcp.client.session as _mcp_session # noqa: PLC0415
mcp_session = _import_mcp_module("mcp.client.session")
client_session = cast(
"type[mcp.client.session.ClientSession]", mcp_session.ClientSession
)

pool = _pool.get()

Expand All @@ -80,7 +99,7 @@ async def _get_or_create_connection(
streams = await exit_stack.enter_async_context(transport_factory())
read_stream, write_stream = streams[0], streams[1]

client = _mcp_session.ClientSession(
client = client_session(
read_stream=read_stream,
write_stream=write_stream,
)
Expand All @@ -105,7 +124,10 @@ def _make_tool_fn(
"""Create a tool function that manages its own connection."""

async def call_tool(**kwargs: Any) -> Any:
import mcp.types as _mcp_types # noqa: PLC0415
mcp_types = _import_mcp_module("mcp.types")
text_content = cast(
"type[mcp.types.TextContent]", mcp_types.TextContent
)

client = await _get_or_create_connection(
connection_key, transport_factory
Expand All @@ -124,7 +146,7 @@ async def call_tool(**kwargs: Any) -> Any:
error_text = " ".join(
part.text
for part in result.content
if isinstance(part, _mcp_types.TextContent)
if isinstance(part, text_content)
)
raise RuntimeError(
f"MCP tool error: {error_text or 'Unknown error'}"
Expand All @@ -134,7 +156,7 @@ async def call_tool(**kwargs: Any) -> Any:
return result.structuredContent

for part in result.content:
if isinstance(part, _mcp_types.TextContent):
if isinstance(part, text_content):
text = part.text
if text.startswith(("{", "[")):
try:
Expand Down Expand Up @@ -177,18 +199,21 @@ async def get_stdio_tools(
)

"""
import mcp.client.stdio as _mcp_stdio # noqa: PLC0415
mcp_stdio = _import_mcp_module("mcp.client.stdio")

connection_key = f"stdio:{command}:{':'.join(args)}"

def transport_factory() -> contextlib.AbstractAsyncContextManager[Any]:
return _mcp_stdio.stdio_client(
_mcp_stdio.StdioServerParameters(
command=command,
args=list(args),
env=env,
cwd=cwd,
)
return cast(
"contextlib.AbstractAsyncContextManager[Any]",
mcp_stdio.stdio_client(
mcp_stdio.StdioServerParameters(
command=command,
args=list(args),
env=env,
cwd=cwd,
)
),
)

client = await _get_or_create_connection(connection_key, transport_factory)
Expand Down Expand Up @@ -230,14 +255,16 @@ async def get_http_tools(

"""
import httpx as _httpx # noqa: PLC0415
import mcp.client.streamable_http as _mcp_http # noqa: PLC0415

mcp_http = _import_mcp_module("mcp.client.streamable_http")

connection_key = f"http:{url}"

def transport_factory() -> contextlib.AbstractAsyncContextManager[Any]:
http_client = _httpx.AsyncClient(headers=headers) if headers else None
return _mcp_http.streamable_http_client(
url=url, http_client=http_client
return cast(
"contextlib.AbstractAsyncContextManager[Any]",
mcp_http.streamable_http_client(url=url, http_client=http_client),
)

async with ensure_connection_pool():
Expand Down
20 changes: 20 additions & 0 deletions tests/agents/mcp/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import contextlib
import dataclasses
import importlib
from typing import Any

import mcp.types
import pytest

import ai
from ai.agents.mcp.client import _mcp_tool_to_native
Expand Down Expand Up @@ -75,6 +77,24 @@ def test_mcp_tool_to_native_schema_preserved() -> None:
assert _function_args(native).description == "Echo input"


async def test_get_http_tools_raises_installation_error_without_mcp(
monkeypatch: pytest.MonkeyPatch,
) -> None:
real_import_module = importlib.import_module

def fake_import_module(name: str, package: str | None = None) -> Any:
if name == "mcp.client.streamable_http":
raise ModuleNotFoundError("No module named 'mcp'", name="mcp")
return real_import_module(name, package)

monkeypatch.setattr(importlib, "import_module", fake_import_module)

with pytest.raises(ai.InstallationError) as exc_info:
await ai.mcp.get_http_tools("https://mcp.example.com/mcp")

assert "ai[mcp]" in str(exc_info.value)


# -- End-to-end: MCP tool executes through Agent default loop ---------------


Expand Down
16 changes: 10 additions & 6 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading