Skip to content

Commit 1995794

Browse files
committed
Use shared is_async_callable instead of inspect.iscoroutinefunction
Extract `_is_async_callable` from `tools/base.py` into `mcp.shared._callable_inspection` and replace all raw `inspect.iscoroutinefunction` calls across prompts, resources, and templates. The shared helper also handles `functools.partial` wrappers and callable objects with an async `__call__`.
1 parent d5b9155 commit 1995794

File tree

5 files changed

+49
-22
lines changed

5 files changed

+49
-22
lines changed

src/mcp/server/mcpserver/prompts/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import functools
6-
import inspect
76
from collections.abc import Awaitable, Callable, Sequence
87
from typing import TYPE_CHECKING, Any, Literal
98

@@ -13,6 +12,7 @@
1312

1413
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context
1514
from mcp.server.mcpserver.utilities.func_metadata import func_metadata
15+
from mcp.shared._callable_inspection import is_async_callable
1616
from mcp.types import ContentBlock, Icon, TextContent
1717

1818
if TYPE_CHECKING:
@@ -157,8 +157,9 @@ async def render(
157157
# Add context to arguments if needed
158158
call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg)
159159

160-
if inspect.iscoroutinefunction(self.fn):
161-
result = await self.fn(**call_args)
160+
fn = self.fn
161+
if is_async_callable(fn):
162+
result = await fn(**call_args)
162163
else:
163164
result = await anyio.to_thread.run_sync(functools.partial(self.fn, **call_args))
164165

src/mcp/server/mcpserver/resources/templates.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import functools
6-
import inspect
76
import re
87
from collections.abc import Callable
98
from typing import TYPE_CHECKING, Any
@@ -15,6 +14,7 @@
1514
from mcp.server.mcpserver.resources.types import FunctionResource, Resource
1615
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context
1716
from mcp.server.mcpserver.utilities.func_metadata import func_metadata
17+
from mcp.shared._callable_inspection import is_async_callable
1818
from mcp.types import Annotations, Icon
1919

2020
if TYPE_CHECKING:
@@ -112,8 +112,9 @@ async def create_resource(
112112
# Add context to params if needed
113113
params = inject_context(self.fn, params, context, self.context_kwarg)
114114

115-
if inspect.iscoroutinefunction(self.fn):
116-
result = await self.fn(**params)
115+
fn = self.fn
116+
if is_async_callable(fn):
117+
result = await fn(**params)
117118
else:
118119
result = await anyio.to_thread.run_sync(functools.partial(self.fn, **params))
119120

src/mcp/server/mcpserver/resources/types.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Concrete resource implementations."""
22

3-
import inspect
3+
from __future__ import annotations
4+
45
import json
56
from collections.abc import Callable
67
from pathlib import Path
@@ -14,6 +15,7 @@
1415
from pydantic import Field, ValidationInfo, validate_call
1516

1617
from mcp.server.mcpserver.resources.base import Resource
18+
from mcp.shared._callable_inspection import is_async_callable
1719
from mcp.types import Annotations, Icon
1820

1921

@@ -55,8 +57,9 @@ class FunctionResource(Resource):
5557
async def read(self) -> str | bytes:
5658
"""Read the resource by calling the wrapped function."""
5759
try:
58-
if inspect.iscoroutinefunction(self.fn):
59-
result = await self.fn()
60+
fn = self.fn
61+
if is_async_callable(fn):
62+
result = await fn()
6063
else:
6164
result = await anyio.to_thread.run_sync(self.fn)
6265

@@ -83,7 +86,7 @@ def from_function(
8386
icons: list[Icon] | None = None,
8487
annotations: Annotations | None = None,
8588
meta: dict[str, Any] | None = None,
86-
) -> "FunctionResource":
89+
) -> FunctionResource:
8790
"""Create a FunctionResource from a function."""
8891
func_name = name or fn.__name__
8992
if func_name == "<lambda>": # pragma: no cover

src/mcp/server/mcpserver/tools/base.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
import functools
4-
import inspect
53
from collections.abc import Callable
64
from functools import cached_property
75
from typing import TYPE_CHECKING, Any
@@ -11,6 +9,7 @@
119
from mcp.server.mcpserver.exceptions import ToolError
1210
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
1311
from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata
12+
from mcp.shared._callable_inspection import is_async_callable
1413
from mcp.shared.exceptions import UrlElicitationRequiredError
1514
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
1615
from mcp.types import Icon, ToolAnnotations
@@ -63,7 +62,7 @@ def from_function(
6362
raise ValueError("You must provide a name for lambda functions")
6463

6564
func_doc = description or fn.__doc__ or ""
66-
is_async = _is_async_callable(fn)
65+
is_async = is_async_callable(fn)
6766

6867
if context_kwarg is None: # pragma: no branch
6968
context_kwarg = find_context_parameter(fn)
@@ -118,12 +117,3 @@ async def run(
118117
raise
119118
except Exception as e:
120119
raise ToolError(f"Error executing tool {self.name}: {e}") from e
121-
122-
123-
def _is_async_callable(obj: Any) -> bool:
124-
while isinstance(obj, functools.partial): # pragma: lax no cover
125-
obj = obj.func
126-
127-
return inspect.iscoroutinefunction(obj) or (
128-
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
129-
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Callable inspection utilities.
2+
3+
Adapted from Starlette's `is_async_callable` implementation.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
import functools
9+
import inspect
10+
from collections.abc import Awaitable, Callable
11+
from typing import Any, TypeGuard, TypeVar, overload
12+
13+
T = TypeVar("T")
14+
15+
AwaitableCallable = Callable[..., Awaitable[T]]
16+
17+
18+
@overload
19+
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
20+
21+
22+
@overload
23+
def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ...
24+
25+
26+
def is_async_callable(obj: Any) -> Any:
27+
while isinstance(obj, functools.partial): # pragma: lax no cover
28+
obj = obj.func
29+
30+
return inspect.iscoroutinefunction(obj) or (
31+
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
32+
)

0 commit comments

Comments
 (0)