Skip to content

Commit 86f793b

Browse files
committed
feat: Dependency injection for MCP Tools
1 parent dcc68ce commit 86f793b

File tree

4 files changed

+287
-7
lines changed

4 files changed

+287
-7
lines changed

src/mcp/server/fastmcp/tools/base.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010

1111
from mcp.server.fastmcp.exceptions import ToolError
1212
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
13+
from mcp.server.fastmcp.utilities.dependencies import DependencyResolver, find_dependencies
1314
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
14-
from mcp.types import Icon, ToolAnnotations
15+
from mcp.types import Depends, Icon, ToolAnnotations
1516

1617
if TYPE_CHECKING:
1718
from mcp.server.fastmcp.server import Context
@@ -32,6 +33,7 @@ class Tool(BaseModel):
3233
)
3334
is_async: bool = Field(description="Whether the tool is async")
3435
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
36+
dependencies: dict[str, Depends] | None = Field(None, description="Tool dependencies")
3537
annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")
3638
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool")
3739

@@ -47,6 +49,7 @@ def from_function(
4749
title: str | None = None,
4850
description: str | None = None,
4951
context_kwarg: str | None = None,
52+
dependencies: dict[str, Depends] | None = None,
5053
annotations: ToolAnnotations | None = None,
5154
icons: list[Icon] | None = None,
5255
structured_output: bool | None = None,
@@ -63,9 +66,16 @@ def from_function(
6366
if context_kwarg is None:
6467
context_kwarg = find_context_parameter(fn)
6568

69+
if dependencies is None:
70+
dependencies = find_dependencies(fn)
71+
72+
skip_names = [context_kwarg] if context_kwarg is not None else []
73+
if dependencies:
74+
skip_names.extend(dependencies.keys())
75+
6676
func_arg_metadata = func_metadata(
6777
fn,
68-
skip_names=[context_kwarg] if context_kwarg is not None else [],
78+
skip_names=skip_names,
6979
structured_output=structured_output,
7080
)
7181
parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)
@@ -79,6 +89,7 @@ def from_function(
7989
fn_metadata=func_arg_metadata,
8090
is_async=is_async,
8191
context_kwarg=context_kwarg,
92+
dependencies=dependencies,
8293
annotations=annotations,
8394
icons=icons,
8495
)
@@ -90,12 +101,20 @@ async def run(
90101
convert_result: bool = False,
91102
) -> Any:
92103
"""Run the tool with arguments."""
104+
dependency_resolver = DependencyResolver()
93105
try:
106+
# Resolve dependencies
107+
resolved_dependencies = await dependency_resolver.resolve_dependencies(self.dependencies or {})
108+
109+
# Prepare arguments to pass directly to the function
110+
arguments_to_pass_directly = {self.context_kwarg: context} if self.context_kwarg is not None else {}
111+
arguments_to_pass_directly.update(resolved_dependencies)
112+
94113
result = await self.fn_metadata.call_fn_with_arg_validation(
95114
self.fn,
96115
self.is_async,
97116
arguments,
98-
{self.context_kwarg: context} if self.context_kwarg is not None else None,
117+
arguments_to_pass_directly,
99118
)
100119

101120
if convert_result:
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import inspect
2+
from collections.abc import AsyncGenerator, Callable, Generator
3+
from typing import Annotated, Any, get_args, get_origin, get_type_hints
4+
5+
from mcp.types import Depends
6+
7+
8+
def find_dependencies(fn: Callable[..., Any]) -> dict[str, Depends]:
9+
"""Find all dependencies in a function's parameters."""
10+
# Get type hints to properly resolve string annotations
11+
try:
12+
hints = get_type_hints(fn, include_extras=True)
13+
except Exception:
14+
# If we can't resolve type hints, we can't find dependencies
15+
hints = {}
16+
17+
dependencies: dict[str, Depends] = {}
18+
19+
# Get function signature to access parameter defaults
20+
sig = inspect.signature(fn)
21+
22+
# Check each parameter's type hint and default value
23+
for param_name, param in sig.parameters.items():
24+
# Check if it's in Annotated form
25+
if param_name in hints:
26+
annotation = hints[param_name]
27+
if get_origin(annotation) is Annotated:
28+
_, *extras = get_args(annotation)
29+
dep = next((x for x in extras if isinstance(x, Depends)), None)
30+
if dep is not None:
31+
dependencies[param_name] = dep
32+
continue
33+
34+
# Check if default value is a Depends instance
35+
if param.default is not inspect.Parameter.empty and isinstance(param.default, Depends):
36+
dependencies[param_name] = param.default
37+
38+
return dependencies
39+
40+
41+
def _is_async_callable(obj: Any) -> bool:
42+
"""Check if a callable is async."""
43+
return inspect.iscoroutinefunction(obj) or (
44+
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
45+
)
46+
47+
48+
def _is_generator_function(obj: Any) -> bool:
49+
"""Check if a callable is a generator function."""
50+
return inspect.isgeneratorfunction(obj)
51+
52+
53+
def _is_async_generator_function(obj: Any) -> bool:
54+
"""Check if a callable is an async generator function."""
55+
return inspect.isasyncgenfunction(obj)
56+
57+
58+
class DependencyResolver:
59+
"""Resolve dependencies and clean up properly when errors occur."""
60+
61+
def __init__(self):
62+
self._generators: list[Generator[Any, None, None]] = []
63+
self._async_generators: list[AsyncGenerator[Any, None]] = []
64+
65+
async def resolve_dependencies(self, dependencies: dict[str, Depends]) -> dict[str, Any]:
66+
"""Resolve all dependencies and return their values."""
67+
if not dependencies:
68+
return {}
69+
70+
resolved: dict[str, Any] = {}
71+
72+
for param_name, depends in dependencies.items():
73+
try:
74+
resolved[param_name] = await self._resolve_single_dependency(depends)
75+
except Exception as e:
76+
# Cleanup any generators and async generators that were already created
77+
await self.cleanup()
78+
raise RuntimeError(f"Failed to resolve dependency '{param_name}': {e}") from e
79+
80+
return resolved
81+
82+
async def _resolve_single_dependency(self, depends: Depends) -> Any:
83+
"""Resolve a single dependency."""
84+
dependency_fn = depends.dependency
85+
86+
if _is_async_generator_function(dependency_fn):
87+
gen = dependency_fn()
88+
self._async_generators.append(gen)
89+
try:
90+
value = await gen.__anext__()
91+
return value
92+
except StopAsyncIteration:
93+
raise RuntimeError(f"Async generator dependency {dependency_fn.__name__} didn't yield a value")
94+
95+
elif _is_generator_function(dependency_fn):
96+
gen = dependency_fn()
97+
self._generators.append(gen)
98+
try:
99+
value = next(gen)
100+
return value
101+
except StopIteration:
102+
raise RuntimeError(f"Generator dependency {dependency_fn.__name__} didn't yield a value")
103+
104+
elif _is_async_callable(dependency_fn):
105+
return await dependency_fn()
106+
107+
else:
108+
return dependency_fn()
109+
110+
async def cleanup(self):
111+
"""Cleanup all generator dependencies."""
112+
for gen in self._async_generators:
113+
try:
114+
await gen.aclose()
115+
except Exception:
116+
pass
117+
118+
for gen in self._generators:
119+
try:
120+
gen.close()
121+
except Exception:
122+
pass
123+
124+
self._generators.clear()
125+
self._async_generators.clear()

src/mcp/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@
4040
AnyFunction: TypeAlias = Callable[..., Any]
4141

4242

43+
class Depends(BaseModel):
44+
"""Dependency injection for tool parameters."""
45+
46+
dependency: Callable[..., Any]
47+
48+
4349
class RequestParams(BaseModel):
4450
class Meta(BaseModel):
4551
progressToken: ProgressToken | None = None

tests/server/fastmcp/test_server.py

Lines changed: 134 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
12
import base64
3+
from collections.abc import AsyncGenerator, Generator
24
from pathlib import Path
3-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING, Annotated, Any
46
from unittest.mock import patch
57

68
import pytest
@@ -10,16 +12,16 @@
1012
from mcp.server.fastmcp import Context, FastMCP
1113
from mcp.server.fastmcp.prompts.base import Message, UserMessage
1214
from mcp.server.fastmcp.resources import FileResource, FunctionResource
15+
from mcp.server.fastmcp.utilities.dependencies import DependencyResolver
1316
from mcp.server.fastmcp.utilities.types import Audio, Image
1417
from mcp.server.session import ServerSession
1518
from mcp.shared.exceptions import McpError
16-
from mcp.shared.memory import (
17-
create_connected_server_and_client_session as client_session,
18-
)
19+
from mcp.shared.memory import create_connected_server_and_client_session as client_session
1920
from mcp.types import (
2021
AudioContent,
2122
BlobResourceContents,
2223
ContentBlock,
24+
Depends,
2325
EmbeddedResource,
2426
ImageContent,
2527
TextContent,
@@ -906,6 +908,134 @@ def get_csv(user: str) -> str:
906908
assert result.contents[0].text == "csv for bob"
907909

908910

911+
class TestDependenciesInjection:
912+
"""Test dependency injection functionality."""
913+
914+
@pytest.mark.anyio
915+
async def test_tool_with_regular_dependency(self):
916+
"""Test tool with regular function dependency."""
917+
mcp = FastMCP()
918+
919+
def load_resource() -> int:
920+
return 42
921+
922+
@mcp.tool()
923+
def add_numbers(a: int, b: int, resource: Annotated[int, Depends(dependency=load_resource)]) -> int:
924+
return a + b + resource
925+
926+
async with client_session(mcp._mcp_server) as client:
927+
result = await client.call_tool("add_numbers", {"a": 1, "b": 2})
928+
assert len(result.content) == 1
929+
content = result.content[0]
930+
assert isinstance(content, TextContent)
931+
assert content.text == "45" # 1 + 2 + 42
932+
933+
@pytest.mark.anyio
934+
async def test_tool_with_async_dependency(self):
935+
"""Test tool with async function dependency."""
936+
mcp = FastMCP()
937+
938+
async def load_async_resource() -> str:
939+
await asyncio.sleep(0.01)
940+
return "async_data"
941+
942+
@mcp.tool()
943+
async def process_text(text: str, data: Annotated[str, Depends(dependency=load_async_resource)]) -> str:
944+
return f"Processed '{text}' with {data}"
945+
946+
async with client_session(mcp._mcp_server) as client:
947+
result = await client.call_tool("process_text", {"text": "hello"})
948+
assert len(result.content) == 1
949+
content = result.content[0]
950+
assert isinstance(content, TextContent)
951+
assert content.text == "Processed 'hello' with async_data"
952+
953+
@pytest.mark.anyio
954+
async def test_tool_with_generator_dependency_cleanup(self):
955+
"""Test tool with generator dependency and proper cleanup."""
956+
mcp = FastMCP()
957+
cleanup_called = False
958+
959+
def database_connection() -> Generator[str, None, None]:
960+
nonlocal cleanup_called
961+
try:
962+
yield "db_conn_123"
963+
finally:
964+
cleanup_called = True
965+
966+
@mcp.tool()
967+
def query_database(query: str, db_conn: Annotated[str, Depends(dependency=database_connection)]) -> str:
968+
return f"Executed '{query}' on {db_conn}"
969+
970+
async with client_session(mcp._mcp_server) as client:
971+
result = await client.call_tool("query_database", {"query": "SELECT * FROM users"})
972+
assert len(result.content) == 1
973+
content = result.content[0]
974+
assert isinstance(content, TextContent)
975+
assert content.text == "Executed 'SELECT * FROM users' on db_conn_123"
976+
977+
# Cleanup should have been called after tool execution
978+
assert cleanup_called
979+
980+
@pytest.mark.anyio
981+
async def test_tool_with_async_generator_dependency_cleanup(self):
982+
"""Test tool with async generator dependency and proper cleanup."""
983+
mcp = FastMCP()
984+
cleanup_called = False
985+
986+
async def async_file_handler() -> AsyncGenerator[str, None]:
987+
nonlocal cleanup_called
988+
try:
989+
yield "file_123"
990+
finally:
991+
cleanup_called = True
992+
993+
@mcp.tool()
994+
async def process_file(
995+
content: str, file_handler: Annotated[str, Depends(dependency=async_file_handler)]
996+
) -> str:
997+
await asyncio.sleep(0.01)
998+
return f"Processed '{content}' with {file_handler}"
999+
1000+
async with client_session(mcp._mcp_server) as client:
1001+
result = await client.call_tool("process_file", {"content": "data"})
1002+
assert len(result.content) == 1
1003+
content = result.content[0]
1004+
assert isinstance(content, TextContent)
1005+
assert content.text == "Processed 'data' with file_123"
1006+
1007+
# Cleanup should have been called after tool execution
1008+
assert cleanup_called
1009+
1010+
@pytest.mark.anyio
1011+
async def test_generator_no_yield_error(self):
1012+
"""Test error when generator doesn't yield a value."""
1013+
1014+
def empty_generator() -> Generator[str, None, None]:
1015+
return
1016+
yield # This line is never reached
1017+
1018+
resolver = DependencyResolver()
1019+
dependencies = {"dep": Depends(dependency=empty_generator)}
1020+
1021+
with pytest.raises(RuntimeError):
1022+
await resolver.resolve_dependencies(dependencies)
1023+
1024+
@pytest.mark.anyio
1025+
async def test_async_generator_no_yield_error(self):
1026+
"""Test error when async generator doesn't yield a value."""
1027+
1028+
async def empty_async_generator() -> AsyncGenerator[str, None]:
1029+
return
1030+
yield # This line is never reached
1031+
1032+
resolver = DependencyResolver()
1033+
dependencies = {"dep": Depends(dependency=empty_async_generator)}
1034+
1035+
with pytest.raises(RuntimeError):
1036+
await resolver.resolve_dependencies(dependencies)
1037+
1038+
9091039
class TestContextInjection:
9101040
"""Test context injection in tools, resources, and prompts."""
9111041

0 commit comments

Comments
 (0)