Skip to content

Commit f4c256f

Browse files
committed
fix: update tests to use new Server constructor kwargs pattern
Migrate test files from the old decorator-based handler registration to the new on_* constructor kwargs pattern. Key changes: - Replace @server.list_tools(), @server.call_tool(), etc. decorators with on_list_tools, on_call_tool, etc. kwargs on Server() - Replace server.request_context access with ctx parameter (first argument to all handlers) - Handlers now receive (ctx, params) and return full result types (e.g. ListToolsResult instead of list[Tool]) - Convert experimental task decorators to enable_tasks() kwargs - Add LifespanContextT default to ServerRequestContext - Widen on_call_tool return type to include CreateTaskResult - Delete redundant tests/shared/test_memory.py - Simplify tests to use Client where possible
1 parent ca8fd2e commit f4c256f

24 files changed

+904
-1268
lines changed

src/mcp/server/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from mcp.shared._context import RequestContext
1111
from mcp.shared.message import CloseSSEStreamCallback
1212

13-
LifespanContextT = TypeVar("LifespanContextT")
13+
LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any])
1414
RequestT = TypeVar("RequestT", default=Any)
1515

1616

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(
124124
| None = None,
125125
on_call_tool: Callable[
126126
[ServerRequestContext[LifespanResultT], types.CallToolRequestParams],
127-
Awaitable[types.CallToolResult],
127+
Awaitable[types.CallToolResult | types.CreateTaskResult],
128128
]
129129
| None = None,
130130
on_list_resources: Callable[
@@ -643,5 +643,3 @@ def streamable_http_app(
643643
middleware=middleware,
644644
lifespan=lambda app: session_manager.run(),
645645
)
646-
647-

tests/client/test_client.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from mcp import types
1212
from mcp.client._memory import InMemoryTransport
1313
from mcp.client.client import Client
14-
from mcp.server import Server
14+
from mcp.server import Server, ServerRequestContext
1515
from mcp.server.mcpserver import MCPServer
1616
from mcp.types import (
1717
CallToolResult,
@@ -41,33 +41,36 @@
4141
@pytest.fixture
4242
def simple_server() -> Server:
4343
"""Create a simple MCP server for testing."""
44-
server = Server(name="test_server")
4544

46-
@server.list_resources()
47-
async def handle_list_resources():
48-
return [Resource(uri="memory://test", name="Test Resource", description="A test resource")]
45+
async def handle_list_resources(
46+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
47+
) -> ListResourcesResult:
48+
return ListResourcesResult(
49+
resources=[Resource(uri="memory://test", name="Test Resource", description="A test resource")]
50+
)
4951

50-
@server.subscribe_resource()
51-
async def handle_subscribe_resource(uri: str):
52-
pass
52+
async def handle_subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult:
53+
return EmptyResult()
5354

54-
@server.unsubscribe_resource()
55-
async def handle_unsubscribe_resource(uri: str):
56-
pass
55+
async def handle_unsubscribe_resource(
56+
ctx: ServerRequestContext, params: types.UnsubscribeRequestParams
57+
) -> EmptyResult:
58+
return EmptyResult()
5759

58-
@server.set_logging_level()
59-
async def handle_set_logging_level(level: str):
60-
pass
60+
async def handle_set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult:
61+
return EmptyResult()
6162

62-
@server.completion()
63-
async def handle_completion(
64-
ref: types.PromptReference | types.ResourceTemplateReference,
65-
argument: types.CompletionArgument,
66-
context: types.CompletionContext | None,
67-
) -> types.Completion | None:
68-
return types.Completion(values=[])
63+
async def handle_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult:
64+
return types.CompleteResult(completion=types.Completion(values=[]))
6965

70-
return server
66+
return Server(
67+
name="test_server",
68+
on_list_resources=handle_list_resources,
69+
on_subscribe_resource=handle_subscribe_resource,
70+
on_unsubscribe_resource=handle_unsubscribe_resource,
71+
on_set_logging_level=handle_set_logging_level,
72+
on_completion=handle_completion,
73+
)
7174

7275

7376
@pytest.fixture
@@ -202,19 +205,14 @@ async def test_client_send_progress_notification():
202205
"""Test sending progress notification."""
203206
received_from_client = None
204207
event = anyio.Event()
205-
server = Server(name="test_server")
206-
207-
@server.progress_notification()
208-
async def handle_progress_notification(
209-
progress_token: str | int,
210-
progress: float = 0.0,
211-
total: float | None = None,
212-
message: str | None = None,
213-
) -> None:
208+
209+
async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None:
214210
nonlocal received_from_client
215-
received_from_client = {"progress_token": progress_token, "progress": progress}
211+
received_from_client = {"progress_token": params.progress_token, "progress": params.progress}
216212
event.set()
217213

214+
server = Server(name="test_server", on_progress=handle_progress)
215+
218216
async with Client(server) as client:
219217
await client.send_progress_notification(progress_token="token123", progress=50.0)
220218
await event.wait()

tests/client/test_http_unicode.py

Lines changed: 56 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import socket
99
from collections.abc import AsyncGenerator, Generator
1010
from contextlib import asynccontextmanager
11-
from typing import Any
1211

1312
import pytest
1413
from starlette.applications import Starlette
@@ -17,7 +16,7 @@
1716
from mcp import types
1817
from mcp.client.session import ClientSession
1918
from mcp.client.streamable_http import streamable_http_client
20-
from mcp.server import Server
19+
from mcp.server import Server, ServerRequestContext
2120
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
2221
from mcp.types import TextContent, Tool
2322
from tests.test_helpers import wait_for_server
@@ -47,54 +46,56 @@ def run_unicode_server(port: int) -> None: # pragma: no cover
4746
import uvicorn
4847

4948
# Need to recreate the server setup in this process
50-
server = Server(name="unicode_test_server")
51-
52-
@server.list_tools()
53-
async def list_tools() -> list[Tool]:
54-
"""List tools with Unicode descriptions."""
55-
return [
56-
Tool(
57-
name="echo_unicode",
58-
description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨",
59-
input_schema={
60-
"type": "object",
61-
"properties": {
62-
"text": {"type": "string", "description": "Text to echo back"},
49+
async def handle_list_tools(
50+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
51+
) -> types.ListToolsResult:
52+
return types.ListToolsResult(
53+
tools=[
54+
Tool(
55+
name="echo_unicode",
56+
description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨",
57+
input_schema={
58+
"type": "object",
59+
"properties": {
60+
"text": {"type": "string", "description": "Text to echo back"},
61+
},
62+
"required": ["text"],
6363
},
64-
"required": ["text"],
65-
},
66-
),
67-
]
68-
69-
@server.call_tool()
70-
async def call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]:
71-
"""Handle tool calls with Unicode content."""
72-
if name == "echo_unicode":
73-
text = arguments.get("text", "") if arguments else ""
74-
return [
75-
TextContent(
76-
type="text",
77-
text=f"Echo: {text}",
78-
)
64+
),
7965
]
80-
else:
81-
raise ValueError(f"Unknown tool: {name}")
82-
83-
@server.list_prompts()
84-
async def list_prompts() -> list[types.Prompt]:
85-
"""List prompts with Unicode names and descriptions."""
86-
return [
87-
types.Prompt(
88-
name="unicode_prompt",
89-
description="Unicode prompt - Слой хранилища, где располагаются",
90-
arguments=[],
66+
)
67+
68+
async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
69+
if params.name == "echo_unicode":
70+
text = params.arguments.get("text", "") if params.arguments else ""
71+
return types.CallToolResult(
72+
content=[
73+
TextContent(
74+
type="text",
75+
text=f"Echo: {text}",
76+
)
77+
]
9178
)
92-
]
79+
else:
80+
raise ValueError(f"Unknown tool: {params.name}")
81+
82+
async def handle_list_prompts(
83+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
84+
) -> types.ListPromptsResult:
85+
return types.ListPromptsResult(
86+
prompts=[
87+
types.Prompt(
88+
name="unicode_prompt",
89+
description="Unicode prompt - Слой хранилища, где располагаются",
90+
arguments=[],
91+
)
92+
]
93+
)
9394

94-
@server.get_prompt()
95-
async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPromptResult:
96-
"""Get a prompt with Unicode content."""
97-
if name == "unicode_prompt":
95+
async def handle_get_prompt(
96+
ctx: ServerRequestContext, params: types.GetPromptRequestParams
97+
) -> types.GetPromptResult:
98+
if params.name == "unicode_prompt":
9899
return types.GetPromptResult(
99100
messages=[
100101
types.PromptMessage(
@@ -106,7 +107,15 @@ async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPr
106107
)
107108
]
108109
)
109-
raise ValueError(f"Unknown prompt: {name}")
110+
raise ValueError(f"Unknown prompt: {params.name}")
111+
112+
server = Server(
113+
name="unicode_test_server",
114+
on_list_tools=handle_list_tools,
115+
on_call_tool=handle_call_tool,
116+
on_list_prompts=handle_list_prompts,
117+
on_get_prompt=handle_get_prompt,
118+
)
110119

111120
# Create the session manager
112121
session_manager = StreamableHTTPSessionManager(

tests/client/test_list_methods_cursor.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import pytest
44

55
from mcp import Client, types
6-
from mcp.server import Server
6+
from mcp.server import Server, ServerRequestContext
77
from mcp.server.mcpserver import MCPServer
8-
from mcp.types import ListToolsRequest, ListToolsResult
8+
from mcp.types import ListToolsResult
99

1010
from .conftest import StreamSpyCollection
1111

@@ -105,14 +105,16 @@ async def test_list_tools_with_strict_server_validation(
105105

106106
async def test_list_tools_with_lowlevel_server():
107107
"""Test that list_tools works with a lowlevel Server using params."""
108-
server = Server("test-lowlevel")
109108

110-
@server.list_tools()
111-
async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult:
109+
async def handle_list_tools(
110+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
111+
) -> ListToolsResult:
112112
# Echo back what cursor we received in the tool description
113-
cursor = request.params.cursor if request.params else None
113+
cursor = params.cursor if params else None
114114
return ListToolsResult(tools=[types.Tool(name="test_tool", description=f"cursor={cursor}", input_schema={})])
115115

116+
server = Server("test-lowlevel", on_list_tools=handle_list_tools)
117+
116118
async with Client(server) as client:
117119
result = await client.list_tools()
118120
assert result.tools[0].description == "cursor=None"

tests/client/transports/test_memory.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,31 @@
22

33
import pytest
44

5-
from mcp import Client
5+
from mcp import Client, types
66
from mcp.client._memory import InMemoryTransport
7-
from mcp.server import Server
7+
from mcp.server import Server, ServerRequestContext
88
from mcp.server.mcpserver import MCPServer
9-
from mcp.types import Resource
9+
from mcp.types import ListResourcesResult, Resource
1010

1111

1212
@pytest.fixture
1313
def simple_server() -> Server:
1414
"""Create a simple MCP server for testing."""
15-
server = Server(name="test_server")
16-
17-
# pragma: no cover - handler exists only to register a resource capability.
18-
# Transport tests verify stream creation, not handler invocation.
19-
@server.list_resources()
20-
async def handle_list_resources(): # pragma: no cover
21-
return [
22-
Resource(
23-
uri="memory://test",
24-
name="Test Resource",
25-
description="A test resource",
26-
)
27-
]
2815

29-
return server
16+
async def handle_list_resources(
17+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
18+
) -> ListResourcesResult: # pragma: no cover
19+
return ListResourcesResult(
20+
resources=[
21+
Resource(
22+
uri="memory://test",
23+
name="Test Resource",
24+
description="A test resource",
25+
)
26+
]
27+
)
28+
29+
return Server(name="test_server", on_list_resources=handle_list_resources)
3030

3131

3232
@pytest.fixture

0 commit comments

Comments
 (0)