|
| 1 | +import asyncio |
1 | 2 | import base64 |
| 3 | +from collections.abc import AsyncGenerator, Generator |
2 | 4 | from pathlib import Path |
3 | | -from typing import TYPE_CHECKING, Any |
| 5 | +from typing import TYPE_CHECKING, Annotated, Any |
4 | 6 | from unittest.mock import patch |
5 | 7 |
|
6 | 8 | import pytest |
|
10 | 12 | from mcp.server.fastmcp import Context, FastMCP |
11 | 13 | from mcp.server.fastmcp.prompts.base import Message, UserMessage |
12 | 14 | from mcp.server.fastmcp.resources import FileResource, FunctionResource |
| 15 | +from mcp.server.fastmcp.utilities.dependencies import DependencyResolver |
13 | 16 | from mcp.server.fastmcp.utilities.types import Audio, Image |
14 | 17 | from mcp.server.session import ServerSession |
15 | 18 | from mcp.shared.exceptions import McpError |
16 | | -from mcp.shared.memory import ( |
17 | | - create_connected_server_and_client_session as client_session, |
18 | | -) |
| 19 | +from mcp.shared.memory import create_connected_server_and_client_session as client_session |
19 | 20 | from mcp.types import ( |
20 | 21 | AudioContent, |
21 | 22 | BlobResourceContents, |
22 | 23 | ContentBlock, |
| 24 | + Depends, |
23 | 25 | EmbeddedResource, |
24 | 26 | ImageContent, |
25 | 27 | TextContent, |
@@ -906,6 +908,134 @@ def get_csv(user: str) -> str: |
906 | 908 | assert result.contents[0].text == "csv for bob" |
907 | 909 |
|
908 | 910 |
|
| 911 | +class TestDependenciesInjection: |
| 912 | + """Test dependency injection functionality.""" |
| 913 | + |
| 914 | + @pytest.mark.anyio |
| 915 | + async def test_tool_with_regular_dependency(self): |
| 916 | + """Test tool with regular function dependency.""" |
| 917 | + mcp = FastMCP() |
| 918 | + |
| 919 | + def load_resource() -> int: |
| 920 | + return 42 |
| 921 | + |
| 922 | + @mcp.tool() |
| 923 | + def add_numbers(a: int, b: int, resource: Annotated[int, Depends(dependency=load_resource)]) -> int: |
| 924 | + return a + b + resource |
| 925 | + |
| 926 | + async with client_session(mcp._mcp_server) as client: |
| 927 | + result = await client.call_tool("add_numbers", {"a": 1, "b": 2}) |
| 928 | + assert len(result.content) == 1 |
| 929 | + content = result.content[0] |
| 930 | + assert isinstance(content, TextContent) |
| 931 | + assert content.text == "45" # 1 + 2 + 42 |
| 932 | + |
| 933 | + @pytest.mark.anyio |
| 934 | + async def test_tool_with_async_dependency(self): |
| 935 | + """Test tool with async function dependency.""" |
| 936 | + mcp = FastMCP() |
| 937 | + |
| 938 | + async def load_async_resource() -> str: |
| 939 | + await asyncio.sleep(0.01) |
| 940 | + return "async_data" |
| 941 | + |
| 942 | + @mcp.tool() |
| 943 | + async def process_text(text: str, data: Annotated[str, Depends(dependency=load_async_resource)]) -> str: |
| 944 | + return f"Processed '{text}' with {data}" |
| 945 | + |
| 946 | + async with client_session(mcp._mcp_server) as client: |
| 947 | + result = await client.call_tool("process_text", {"text": "hello"}) |
| 948 | + assert len(result.content) == 1 |
| 949 | + content = result.content[0] |
| 950 | + assert isinstance(content, TextContent) |
| 951 | + assert content.text == "Processed 'hello' with async_data" |
| 952 | + |
| 953 | + @pytest.mark.anyio |
| 954 | + async def test_tool_with_generator_dependency_cleanup(self): |
| 955 | + """Test tool with generator dependency and proper cleanup.""" |
| 956 | + mcp = FastMCP() |
| 957 | + cleanup_called = False |
| 958 | + |
| 959 | + def database_connection() -> Generator[str, None, None]: |
| 960 | + nonlocal cleanup_called |
| 961 | + try: |
| 962 | + yield "db_conn_123" |
| 963 | + finally: |
| 964 | + cleanup_called = True |
| 965 | + |
| 966 | + @mcp.tool() |
| 967 | + def query_database(query: str, db_conn: Annotated[str, Depends(dependency=database_connection)]) -> str: |
| 968 | + return f"Executed '{query}' on {db_conn}" |
| 969 | + |
| 970 | + async with client_session(mcp._mcp_server) as client: |
| 971 | + result = await client.call_tool("query_database", {"query": "SELECT * FROM users"}) |
| 972 | + assert len(result.content) == 1 |
| 973 | + content = result.content[0] |
| 974 | + assert isinstance(content, TextContent) |
| 975 | + assert content.text == "Executed 'SELECT * FROM users' on db_conn_123" |
| 976 | + |
| 977 | + # Cleanup should have been called after tool execution |
| 978 | + assert cleanup_called |
| 979 | + |
| 980 | + @pytest.mark.anyio |
| 981 | + async def test_tool_with_async_generator_dependency_cleanup(self): |
| 982 | + """Test tool with async generator dependency and proper cleanup.""" |
| 983 | + mcp = FastMCP() |
| 984 | + cleanup_called = False |
| 985 | + |
| 986 | + async def async_file_handler() -> AsyncGenerator[str, None]: |
| 987 | + nonlocal cleanup_called |
| 988 | + try: |
| 989 | + yield "file_123" |
| 990 | + finally: |
| 991 | + cleanup_called = True |
| 992 | + |
| 993 | + @mcp.tool() |
| 994 | + async def process_file( |
| 995 | + content: str, file_handler: Annotated[str, Depends(dependency=async_file_handler)] |
| 996 | + ) -> str: |
| 997 | + await asyncio.sleep(0.01) |
| 998 | + return f"Processed '{content}' with {file_handler}" |
| 999 | + |
| 1000 | + async with client_session(mcp._mcp_server) as client: |
| 1001 | + result = await client.call_tool("process_file", {"content": "data"}) |
| 1002 | + assert len(result.content) == 1 |
| 1003 | + content = result.content[0] |
| 1004 | + assert isinstance(content, TextContent) |
| 1005 | + assert content.text == "Processed 'data' with file_123" |
| 1006 | + |
| 1007 | + # Cleanup should have been called after tool execution |
| 1008 | + assert cleanup_called |
| 1009 | + |
| 1010 | + @pytest.mark.anyio |
| 1011 | + async def test_generator_no_yield_error(self): |
| 1012 | + """Test error when generator doesn't yield a value.""" |
| 1013 | + |
| 1014 | + def empty_generator() -> Generator[str, None, None]: |
| 1015 | + return |
| 1016 | + yield # This line is never reached |
| 1017 | + |
| 1018 | + resolver = DependencyResolver() |
| 1019 | + dependencies = {"dep": Depends(dependency=empty_generator)} |
| 1020 | + |
| 1021 | + with pytest.raises(RuntimeError): |
| 1022 | + await resolver.resolve_dependencies(dependencies) |
| 1023 | + |
| 1024 | + @pytest.mark.anyio |
| 1025 | + async def test_async_generator_no_yield_error(self): |
| 1026 | + """Test error when async generator doesn't yield a value.""" |
| 1027 | + |
| 1028 | + async def empty_async_generator() -> AsyncGenerator[str, None]: |
| 1029 | + return |
| 1030 | + yield # This line is never reached |
| 1031 | + |
| 1032 | + resolver = DependencyResolver() |
| 1033 | + dependencies = {"dep": Depends(dependency=empty_async_generator)} |
| 1034 | + |
| 1035 | + with pytest.raises(RuntimeError): |
| 1036 | + await resolver.resolve_dependencies(dependencies) |
| 1037 | + |
| 1038 | + |
909 | 1039 | class TestContextInjection: |
910 | 1040 | """Test context injection in tools, resources, and prompts.""" |
911 | 1041 |
|
|
0 commit comments