Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations
from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ListToolsRequest, ToolAnnotations
from mcp.types import Prompt as MCPPrompt
from mcp.types import PromptArgument as MCPPromptArgument
from mcp.types import Resource as MCPResource
Expand Down Expand Up @@ -298,9 +298,19 @@ def _setup_handlers(self) -> None:
self._mcp_server.get_prompt()(self.get_prompt)
self._mcp_server.list_resource_templates()(self.list_resource_templates)

async def list_tools(self) -> list[MCPTool]:
"""List all available tools."""
tools = self._tool_manager.list_tools()
async def list_tools(
self,
request: ListToolsRequest | None = None,
) -> list[MCPTool]:
"""List all available tools, optionally filtered by include/exclude parameters."""
if request and request.params:
tools = self._tool_manager.list_tools(
include=request.params.include,
exclude=request.params.exclude,
)
else:
tools = self._tool_manager.list_tools()

return [
MCPTool(
name=info.name,
Expand Down
36 changes: 34 additions & 2 deletions src/mcp/server/fastmcp/tools/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,40 @@ def get_tool(self, name: str) -> Tool | None:
"""Get tool by name."""
return self._tools.get(name)

def list_tools(self) -> list[Tool]:
"""List all registered tools."""
def _include_tools(self, tools: dict[str, Tool], include: list[str]) -> list[Tool]:
"""Filter tools to include only the specified tool names."""
filtered_tools: list[Tool] = []
for tool_name in include:
tool = tools.get(tool_name)
if tool is None:
raise ValueError(f"Tool '{tool_name}' not found in available tools, cannot be included.")
filtered_tools.append(tool)
return filtered_tools

def _exclude_tools(self, tools: dict[str, Tool], exclude: list[str]) -> list[Tool]:
"""Filter tools to exclude the specified tool names."""
exclude_set = set(exclude)

for tool_name in exclude:
if tool_name not in tools:
raise ValueError(f"Tool '{tool_name}' not found in available tools, cannot be excluded.")

return [tool for name, tool in tools.items() if name not in exclude_set]

def list_tools(
self,
*,
include: list[str] | None = None,
exclude: list[str] | None = None,
) -> list[Tool]:
"""List all registered tools, optionally filtered by include or exclude parameters."""
if include is not None and exclude is not None:
raise ValueError("Cannot specify both 'include' and 'exclude' parameters")
elif include is not None:
return self._include_tools(self._tools, include)
elif exclude is not None:
return self._exclude_tools(self._tools, exclude)

return list(self._tools.values())

def add_tool(
Expand Down
50 changes: 35 additions & 15 deletions src/mcp/server/lowlevel/func_inspection.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,48 @@
import inspect
from collections.abc import Callable
from typing import Any, TypeVar, get_type_hints
from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints

T = TypeVar("T")
R = TypeVar("R")


def _type_matches_request(param_type: Any, request_type: type[T]) -> bool:
"""
Check if a parameter type matches the request type.

This handles direct matches, Union types (e.g., RequestType | None),
and Optional types (e.g., Optional[RequestType]).
"""
if param_type == request_type:
return True

origin = get_origin(param_type)
args = get_args(param_type)

# Handle typing.Union and Python 3.10+ | syntax
if origin is Union:
return request_type in args

# Handle types.UnionType from Python 3.10+ | syntax
if hasattr(param_type, "__args__") and args:
return request_type in args

return False


def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callable[[T], R]:
"""
Create a wrapper function that knows how to call func with the request object.

Returns a wrapper function that takes the request and calls func appropriately.

The wrapper handles three calling patterns:
1. Positional-only parameter typed as request_type (no default): func(req)
2. Positional/keyword parameter typed as request_type (no default): func(**{param_name: req})
3. No request parameter or parameter with default: func()
1. Positional-only parameter typed as request_type or Union containing request_type: func(req)
2. Positional/keyword parameter typed as request_type or Union containing request_type: func(**{param_name: req})
3. No matching request parameter: func()

Union types like `RequestType | None` and `Optional[RequestType]` are supported,
allowing for optional request parameters with default values.
"""
try:
sig = inspect.signature(func)
Expand All @@ -27,23 +54,16 @@ def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callab
for param_name, param in sig.parameters.items():
if param.kind == inspect.Parameter.POSITIONAL_ONLY:
param_type = type_hints.get(param_name)
if param_type == request_type:
# Check if it has a default - if so, treat as old style
if param.default is not inspect.Parameter.empty:
return lambda _: func()
# Found positional-only parameter with correct type and no default
if _type_matches_request(param_type, request_type):
# Found positional-only parameter with correct type
return lambda req: func(req)

# Check for any positional/keyword parameter typed as request_type
for param_name, param in sig.parameters.items():
if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY):
param_type = type_hints.get(param_name)
if param_type == request_type:
# Check if it has a default - if so, treat as old style
if param.default is not inspect.Parameter.empty:
return lambda _: func()

# Found keyword parameter with correct type and no default
if _type_matches_request(param_type, request_type):
# Found keyword parameter with correct type
# Need to capture param_name in closure properly
def make_keyword_wrapper(name: str) -> Callable[[Any], Any]:
return lambda req: func(**{name: req})
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ async def handler(req: types.UnsubscribeRequest):

def list_tools(self):
def decorator(
func: Callable[[], Awaitable[list[types.Tool]]]
func: Callable[..., Awaitable[list[types.Tool]]]
| Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]],
):
logger.debug("Registering handler for ListToolsRequest")
Expand Down
8 changes: 4 additions & 4 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class Meta(BaseModel):


class PaginatedRequestParams(RequestParams):
"""Request parameters for paginated operations with optional filtering."""

cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
include: list[str] | None = None
exclude: list[str] | None = None


class NotificationParams(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion tests/issues/test_100_tool_listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def dummy_tool_func():
globals()[f"dummy_tool_{i}"] = dummy_tool_func # Keep reference to avoid garbage collection

# Get all tools
tools = await mcp.list_tools()
tools = await mcp.list_tools(request=None)

# Verify we get all tools
assert len(tools) == num_tools, f"Expected {num_tools} tools, but got {len(tools)}"
Expand Down
6 changes: 3 additions & 3 deletions tests/issues/test_1338_icons_and_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_resource_template(city: str) -> str:
assert mcp.icons[0].sizes == test_icon.sizes

# Test tool includes icon
tools = await mcp.list_tools()
tools = await mcp.list_tools(request=None)
assert len(tools) == 1
tool = tools[0]
assert tool.name == "test_tool"
Expand Down Expand Up @@ -109,7 +109,7 @@ def multi_icon_tool() -> str:
return "success"

# Test tool has all icons
tools = await mcp.list_tools()
tools = await mcp.list_tools(request=None)
assert len(tools) == 1
tool = tools[0]
assert tool.icons is not None
Expand All @@ -135,7 +135,7 @@ def basic_tool() -> str:
assert mcp.icons is None

# Test tool has no icons
tools = await mcp.list_tools()
tools = await mcp.list_tools(request=None)
assert len(tools) == 1
tool = tools[0]
assert tool.name == "basic_tool"
Expand Down
2 changes: 1 addition & 1 deletion tests/server/fastmcp/test_parameter_descriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def greet(
"""A greeting tool"""
return f"Hello {title} {name}"

tools = await mcp.list_tools()
tools = await mcp.list_tools(request=None)
assert len(tools) == 1
tool = tools[0]

Expand Down
Loading
Loading