Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
from mcp.server.fastmcp.utilities.dependencies import DependencyResolver, find_dependencies
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
from mcp.types import Icon, ToolAnnotations
from mcp.types import Depends, Icon, ToolAnnotations

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
Expand All @@ -32,6 +33,7 @@ class Tool(BaseModel):
)
is_async: bool = Field(description="Whether the tool is async")
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
dependencies: dict[str, Depends] | None = Field(None, description="Tool dependencies")
annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool")

Expand All @@ -47,6 +49,7 @@ def from_function(
title: str | None = None,
description: str | None = None,
context_kwarg: str | None = None,
dependencies: dict[str, Depends] | None = None,
annotations: ToolAnnotations | None = None,
icons: list[Icon] | None = None,
structured_output: bool | None = None,
Expand All @@ -63,9 +66,16 @@ def from_function(
if context_kwarg is None:
context_kwarg = find_context_parameter(fn)

if dependencies is None:
dependencies = find_dependencies(fn)

skip_names = [context_kwarg] if context_kwarg is not None else []
if dependencies:
skip_names.extend(dependencies.keys())

func_arg_metadata = func_metadata(
fn,
skip_names=[context_kwarg] if context_kwarg is not None else [],
skip_names=skip_names,
structured_output=structured_output,
)
parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)
Expand All @@ -79,6 +89,7 @@ def from_function(
fn_metadata=func_arg_metadata,
is_async=is_async,
context_kwarg=context_kwarg,
dependencies=dependencies,
annotations=annotations,
icons=icons,
)
Expand All @@ -90,12 +101,20 @@ async def run(
convert_result: bool = False,
) -> Any:
"""Run the tool with arguments."""
dependency_resolver = DependencyResolver()
try:
# Resolve dependencies
resolved_dependencies = await dependency_resolver.resolve_dependencies(self.dependencies or {})

# Prepare arguments to pass directly to the function
arguments_to_pass_directly = {self.context_kwarg: context} if self.context_kwarg is not None else {}
arguments_to_pass_directly.update(resolved_dependencies)

result = await self.fn_metadata.call_fn_with_arg_validation(
self.fn,
self.is_async,
arguments,
{self.context_kwarg: context} if self.context_kwarg is not None else None,
arguments_to_pass_directly,
)

if convert_result:
Expand Down
125 changes: 125 additions & 0 deletions src/mcp/server/fastmcp/utilities/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import inspect
from collections.abc import AsyncGenerator, Callable, Generator
from typing import Annotated, Any, get_args, get_origin, get_type_hints

from mcp.types import Depends


def find_dependencies(fn: Callable[..., Any]) -> dict[str, Depends]:
"""Find all dependencies in a function's parameters."""
# Get type hints to properly resolve string annotations
try:
hints = get_type_hints(fn, include_extras=True)
except Exception:
# If we can't resolve type hints, we can't find dependencies
hints = {}

dependencies: dict[str, Depends] = {}

# Get function signature to access parameter defaults
sig = inspect.signature(fn)

# Check each parameter's type hint and default value
for param_name, param in sig.parameters.items():
# Check if it's in Annotated form
if param_name in hints:
annotation = hints[param_name]
if get_origin(annotation) is Annotated:
_, *extras = get_args(annotation)
dep = next((x for x in extras if isinstance(x, Depends)), None)
if dep is not None:
dependencies[param_name] = dep
continue

# Check if default value is a Depends instance
if param.default is not inspect.Parameter.empty and isinstance(param.default, Depends):
dependencies[param_name] = param.default

return dependencies


def _is_async_callable(obj: Any) -> bool:
"""Check if a callable is async."""
return inspect.iscoroutinefunction(obj) or (
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
)


def _is_generator_function(obj: Any) -> bool:
"""Check if a callable is a generator function."""
return inspect.isgeneratorfunction(obj)


def _is_async_generator_function(obj: Any) -> bool:
"""Check if a callable is an async generator function."""
return inspect.isasyncgenfunction(obj)


class DependencyResolver:
"""Resolve dependencies and clean up properly when errors occur."""

def __init__(self):
self._generators: list[Generator[Any, None, None]] = []
self._async_generators: list[AsyncGenerator[Any, None]] = []

async def resolve_dependencies(self, dependencies: dict[str, Depends]) -> dict[str, Any]:
"""Resolve all dependencies and return their values."""
if not dependencies:
return {}

resolved: dict[str, Any] = {}

for param_name, depends in dependencies.items():
try:
resolved[param_name] = await self._resolve_single_dependency(depends)
except Exception as e:
# Cleanup any generators and async generators that were already created
await self.cleanup()
raise RuntimeError(f"Failed to resolve dependency '{param_name}': {e}") from e

return resolved

async def _resolve_single_dependency(self, depends: Depends) -> Any:
"""Resolve a single dependency."""
dependency_fn = depends.dependency

if _is_async_generator_function(dependency_fn):
gen = dependency_fn()
self._async_generators.append(gen)
try:
value = await gen.__anext__()
return value
except StopAsyncIteration:
raise RuntimeError(f"Async generator dependency {dependency_fn.__name__} didn't yield a value")

elif _is_generator_function(dependency_fn):
gen = dependency_fn()
self._generators.append(gen)
try:
value = next(gen)
return value
except StopIteration:
raise RuntimeError(f"Generator dependency {dependency_fn.__name__} didn't yield a value")

elif _is_async_callable(dependency_fn):
return await dependency_fn()

else:
return dependency_fn()

async def cleanup(self):
"""Cleanup all generator dependencies."""
for gen in self._async_generators:
try:
await gen.aclose()
except Exception:
pass

for gen in self._generators:
try:
gen.close()
except Exception:
pass

self._generators.clear()
self._async_generators.clear()
6 changes: 6 additions & 0 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@
AnyFunction: TypeAlias = Callable[..., Any]


class Depends(BaseModel):
"""Dependency injection for tool parameters."""

dependency: Callable[..., Any]


class RequestParams(BaseModel):
class Meta(BaseModel):
progressToken: ProgressToken | None = None
Expand Down
138 changes: 134 additions & 4 deletions tests/server/fastmcp/test_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import base64
from collections.abc import AsyncGenerator, Generator
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Annotated, Any
from unittest.mock import patch

import pytest
Expand All @@ -10,16 +12,16 @@
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.fastmcp.prompts.base import Message, UserMessage
from mcp.server.fastmcp.resources import FileResource, FunctionResource
from mcp.server.fastmcp.utilities.dependencies import DependencyResolver
from mcp.server.fastmcp.utilities.types import Audio, Image
from mcp.server.session import ServerSession
from mcp.shared.exceptions import McpError
from mcp.shared.memory import (
create_connected_server_and_client_session as client_session,
)
from mcp.shared.memory import create_connected_server_and_client_session as client_session
from mcp.types import (
AudioContent,
BlobResourceContents,
ContentBlock,
Depends,
EmbeddedResource,
ImageContent,
TextContent,
Expand Down Expand Up @@ -906,6 +908,134 @@ def get_csv(user: str) -> str:
assert result.contents[0].text == "csv for bob"


class TestDependenciesInjection:
"""Test dependency injection functionality."""

@pytest.mark.anyio
async def test_tool_with_regular_dependency(self):
"""Test tool with regular function dependency."""
mcp = FastMCP()

def load_resource() -> int:
return 42

@mcp.tool()
def add_numbers(a: int, b: int, resource: Annotated[int, Depends(dependency=load_resource)]) -> int:
return a + b + resource

async with client_session(mcp._mcp_server) as client:
result = await client.call_tool("add_numbers", {"a": 1, "b": 2})
assert len(result.content) == 1
content = result.content[0]
assert isinstance(content, TextContent)
assert content.text == "45" # 1 + 2 + 42

@pytest.mark.anyio
async def test_tool_with_async_dependency(self):
"""Test tool with async function dependency."""
mcp = FastMCP()

async def load_async_resource() -> str:
await asyncio.sleep(0.01)
return "async_data"

@mcp.tool()
async def process_text(text: str, data: Annotated[str, Depends(dependency=load_async_resource)]) -> str:
return f"Processed '{text}' with {data}"

async with client_session(mcp._mcp_server) as client:
result = await client.call_tool("process_text", {"text": "hello"})
assert len(result.content) == 1
content = result.content[0]
assert isinstance(content, TextContent)
assert content.text == "Processed 'hello' with async_data"

@pytest.mark.anyio
async def test_tool_with_generator_dependency_cleanup(self):
"""Test tool with generator dependency and proper cleanup."""
mcp = FastMCP()
cleanup_called = False

def database_connection() -> Generator[str, None, None]:
nonlocal cleanup_called
try:
yield "db_conn_123"
finally:
cleanup_called = True

@mcp.tool()
def query_database(query: str, db_conn: Annotated[str, Depends(dependency=database_connection)]) -> str:
return f"Executed '{query}' on {db_conn}"

async with client_session(mcp._mcp_server) as client:
result = await client.call_tool("query_database", {"query": "SELECT * FROM users"})
assert len(result.content) == 1
content = result.content[0]
assert isinstance(content, TextContent)
assert content.text == "Executed 'SELECT * FROM users' on db_conn_123"

# Cleanup should have been called after tool execution
assert cleanup_called

@pytest.mark.anyio
async def test_tool_with_async_generator_dependency_cleanup(self):
"""Test tool with async generator dependency and proper cleanup."""
mcp = FastMCP()
cleanup_called = False

async def async_file_handler() -> AsyncGenerator[str, None]:
nonlocal cleanup_called
try:
yield "file_123"
finally:
cleanup_called = True

@mcp.tool()
async def process_file(
content: str, file_handler: Annotated[str, Depends(dependency=async_file_handler)]
) -> str:
await asyncio.sleep(0.01)
return f"Processed '{content}' with {file_handler}"

async with client_session(mcp._mcp_server) as client:
result = await client.call_tool("process_file", {"content": "data"})
assert len(result.content) == 1
content = result.content[0]
assert isinstance(content, TextContent)
assert content.text == "Processed 'data' with file_123"

# Cleanup should have been called after tool execution
assert cleanup_called

@pytest.mark.anyio
async def test_generator_no_yield_error(self):
"""Test error when generator doesn't yield a value."""

def empty_generator() -> Generator[str, None, None]:
return
yield # This line is never reached

resolver = DependencyResolver()
dependencies = {"dep": Depends(dependency=empty_generator)}

with pytest.raises(RuntimeError):
await resolver.resolve_dependencies(dependencies)

@pytest.mark.anyio
async def test_async_generator_no_yield_error(self):
"""Test error when async generator doesn't yield a value."""

async def empty_async_generator() -> AsyncGenerator[str, None]:
return
yield # This line is never reached

resolver = DependencyResolver()
dependencies = {"dep": Depends(dependency=empty_async_generator)}

with pytest.raises(RuntimeError):
await resolver.resolve_dependencies(dependencies)


class TestContextInjection:
"""Test context injection in tools, resources, and prompts."""

Expand Down
1 change: 1 addition & 0 deletions tests/server/fastmcp/test_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class AddArguments(ArgModelBase):
is_async=False,
parameters=AddArguments.model_json_schema(),
context_kwarg=None,
dependencies=None,
annotations=None,
)
manager = ToolManager(tools=[original_tool])
Expand Down
Loading