Skip to content

Commit 34ad089

Browse files
committed
Refactor task architecture: separate pure state from server integration
This refactor eliminates circular imports and simplifies the task API: Architecture changes: - Pure task state (TaskContext, TaskStore, helpers) stays in shared/experimental/tasks/ - Server integration (ServerTaskContext, TaskResultHandler, TaskSupport, Experimental) moves to server/experimental/ - Empty __init__.py files with absolute imports only New simplified API: - server.experimental.enable_tasks() - one-line setup, auto-registers handlers - ctx.experimental.run_task(work) - spawns work, auto-completes/fails - ServerTaskContext.elicit()/create_message() - queues requests properly Key improvements: - No TYPE_CHECKING hacks or circular import workarounds - ServerTaskContext reuses session._build_*_request() helpers (no duplication) - TaskSupport manages task_group lifecycle - run_task() handles task creation, spawning, and completion automatically Test changes: - Removed tests for old internals (test_response_routing, test_elicitation_flow, etc.) - Added test_run_task_flow.py for new user flow - Fixed remaining tests to use new API (removed notify= params, updated imports)
1 parent 3c4f262 commit 34ad089

29 files changed

+1391
-3210
lines changed

examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py

Lines changed: 63 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,28 @@
1-
"""Simple interactive task server demonstrating elicitation and sampling."""
1+
"""Simple interactive task server demonstrating elicitation and sampling.
2+
3+
This example shows the simplified task API where:
4+
- server.experimental.enable_tasks() sets up all infrastructure
5+
- ctx.experimental.run_task() handles task lifecycle automatically
6+
- ServerTaskContext.elicit() and ServerTaskContext.create_message() queue requests properly
7+
"""
28

39
from collections.abc import AsyncIterator
410
from contextlib import asynccontextmanager
5-
from dataclasses import dataclass
611
from typing import Any
712

8-
import anyio
913
import click
1014
import mcp.types as types
1115
import uvicorn
12-
from anyio.abc import TaskGroup
16+
from mcp.server.experimental.task_context import ServerTaskContext
1317
from mcp.server.lowlevel import Server
14-
from mcp.server.session import ServerSession
1518
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
16-
from mcp.shared.experimental.tasks import (
17-
InMemoryTaskMessageQueue,
18-
InMemoryTaskStore,
19-
TaskResultHandler,
20-
TaskSession,
21-
task_execution,
22-
)
2319
from starlette.applications import Starlette
2420
from starlette.routing import Mount
2521

22+
server = Server("simple-task-interactive")
2623

27-
@dataclass
28-
class AppContext:
29-
task_group: TaskGroup
30-
store: InMemoryTaskStore
31-
queue: InMemoryTaskMessageQueue
32-
handler: TaskResultHandler
33-
# Track sessions that have been configured (session ID -> bool)
34-
configured_sessions: dict[int, bool]
35-
36-
37-
@asynccontextmanager
38-
async def lifespan(server: Server[AppContext, Any]) -> AsyncIterator[AppContext]:
39-
store = InMemoryTaskStore()
40-
queue = InMemoryTaskMessageQueue()
41-
handler = TaskResultHandler(store, queue)
42-
async with anyio.create_task_group() as tg:
43-
yield AppContext(
44-
task_group=tg,
45-
store=store,
46-
queue=queue,
47-
handler=handler,
48-
configured_sessions={},
49-
)
50-
store.cleanup()
51-
queue.cleanup()
52-
53-
54-
server: Server[AppContext, Any] = Server("simple-task-interactive", lifespan=lifespan)
55-
56-
57-
def ensure_handler_configured(session: ServerSession, app: AppContext) -> None:
58-
"""Ensure the task result handler is configured for this session (once)."""
59-
session_id = id(session)
60-
if session_id not in app.configured_sessions:
61-
session.add_response_router(app.handler)
62-
app.configured_sessions[session_id] = True
24+
# Enable task support - this auto-registers all handlers
25+
server.experimental.enable_tasks()
6326

6427

6528
@server.list_tools()
@@ -84,129 +47,73 @@ async def list_tools() -> list[types.Tool]:
8447

8548

8649
@server.call_tool()
87-
async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[types.TextContent] | types.CreateTaskResult:
50+
async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult:
8851
ctx = server.request_context
89-
app = ctx.lifespan_context
9052

91-
# Validate task mode
53+
# Validate task mode - this tool requires task augmentation
9254
ctx.experimental.validate_task_mode(types.TASK_REQUIRED)
9355

94-
# Ensure handler is configured for response routing
95-
ensure_handler_configured(ctx.session, app)
96-
97-
# Create task
98-
metadata = ctx.experimental.task_metadata
99-
assert metadata is not None
100-
task = await app.store.create_task(metadata)
101-
10256
if name == "confirm_delete":
10357
filename = arguments.get("filename", "unknown.txt")
10458
print(f"\n[Server] confirm_delete called for '{filename}'")
105-
print(f"[Server] Task created: {task.taskId}")
106-
107-
async def do_confirm() -> None:
108-
async with task_execution(task.taskId, app.store) as task_ctx:
109-
task_session = TaskSession(
110-
session=ctx.session,
111-
task_id=task.taskId,
112-
store=app.store,
113-
queue=app.queue,
114-
)
115-
116-
print("[Server] Sending elicitation request to client...")
117-
result = await task_session.elicit(
118-
message=f"Are you sure you want to delete '{filename}'?",
119-
requestedSchema={
120-
"type": "object",
121-
"properties": {"confirm": {"type": "boolean"}},
122-
"required": ["confirm"],
123-
},
124-
)
125-
126-
print(f"[Server] Received elicitation response: action={result.action}, content={result.content}")
127-
if result.action == "accept" and result.content:
128-
confirmed = result.content.get("confirm", False)
129-
text = f"Deleted '{filename}'" if confirmed else "Deletion cancelled"
130-
else:
131-
text = "Deletion cancelled"
132-
133-
print(f"[Server] Completing task with result: {text}")
134-
await task_ctx.complete(
135-
types.CallToolResult(content=[types.TextContent(type="text", text=text)]),
136-
notify=True,
137-
)
138-
139-
app.task_group.start_soon(do_confirm)
59+
60+
async def do_confirm(task: ServerTaskContext) -> types.CallToolResult:
61+
print(f"[Server] Task {task.task_id} starting elicitation...")
62+
63+
result = await task.elicit(
64+
message=f"Are you sure you want to delete '{filename}'?",
65+
requestedSchema={
66+
"type": "object",
67+
"properties": {"confirm": {"type": "boolean"}},
68+
"required": ["confirm"],
69+
},
70+
)
71+
72+
print(f"[Server] Received elicitation response: action={result.action}, content={result.content}")
73+
74+
if result.action == "accept" and result.content:
75+
confirmed = result.content.get("confirm", False)
76+
text = f"Deleted '{filename}'" if confirmed else "Deletion cancelled"
77+
else:
78+
text = "Deletion cancelled"
79+
80+
print(f"[Server] Completing task with result: {text}")
81+
return types.CallToolResult(content=[types.TextContent(type="text", text=text)])
82+
83+
# run_task creates the task, spawns work, returns CreateTaskResult immediately
84+
return await ctx.experimental.run_task(do_confirm)
14085

14186
elif name == "write_haiku":
14287
topic = arguments.get("topic", "nature")
14388
print(f"\n[Server] write_haiku called for topic '{topic}'")
144-
print(f"[Server] Task created: {task.taskId}")
145-
146-
async def do_haiku() -> None:
147-
async with task_execution(task.taskId, app.store) as task_ctx:
148-
task_session = TaskSession(
149-
session=ctx.session,
150-
task_id=task.taskId,
151-
store=app.store,
152-
queue=app.queue,
153-
)
154-
155-
print("[Server] Sending sampling request to client...")
156-
result = await task_session.create_message(
157-
messages=[
158-
types.SamplingMessage(
159-
role="user",
160-
content=types.TextContent(type="text", text=f"Write a haiku about {topic}"),
161-
)
162-
],
163-
max_tokens=50,
164-
)
165-
166-
haiku = "No response"
167-
if isinstance(result.content, types.TextContent):
168-
haiku = result.content.text
169-
170-
print(f"[Server] Received sampling response: {haiku[:50]}...")
171-
print("[Server] Completing task with haiku")
172-
await task_ctx.complete(
173-
types.CallToolResult(content=[types.TextContent(type="text", text=f"Haiku:\n{haiku}")]),
174-
notify=True,
175-
)
176-
177-
app.task_group.start_soon(do_haiku)
178-
179-
return types.CreateTaskResult(task=task)
180-
181-
182-
@server.experimental.get_task()
183-
async def handle_get_task(request: types.GetTaskRequest) -> types.GetTaskResult:
184-
app = server.request_context.lifespan_context
185-
task = await app.store.get_task(request.params.taskId)
186-
if task is None:
187-
raise ValueError(f"Task {request.params.taskId} not found")
188-
return types.GetTaskResult(
189-
taskId=task.taskId,
190-
status=task.status,
191-
statusMessage=task.statusMessage,
192-
createdAt=task.createdAt,
193-
lastUpdatedAt=task.lastUpdatedAt,
194-
ttl=task.ttl,
195-
pollInterval=task.pollInterval,
196-
)
19789

90+
async def do_haiku(task: ServerTaskContext) -> types.CallToolResult:
91+
print(f"[Server] Task {task.task_id} starting sampling...")
19892

199-
@server.experimental.get_task_result()
200-
async def handle_get_task_result(
201-
request: types.GetTaskPayloadRequest,
202-
) -> types.GetTaskPayloadResult:
203-
ctx = server.request_context
204-
app = ctx.lifespan_context
93+
result = await task.create_message(
94+
messages=[
95+
types.SamplingMessage(
96+
role="user",
97+
content=types.TextContent(type="text", text=f"Write a haiku about {topic}"),
98+
)
99+
],
100+
max_tokens=50,
101+
)
102+
103+
haiku = "No response"
104+
if isinstance(result.content, types.TextContent):
105+
haiku = result.content.text
205106

206-
# Ensure handler is configured for this session
207-
ensure_handler_configured(ctx.session, app)
107+
print(f"[Server] Received sampling response: {haiku[:50]}...")
108+
return types.CallToolResult(content=[types.TextContent(type="text", text=f"Haiku:\n{haiku}")])
208109

209-
return await app.handler.handle(request, ctx.session, ctx.request_id)
110+
return await ctx.experimental.run_task(do_haiku)
111+
112+
else:
113+
return types.CallToolResult(
114+
content=[types.TextContent(type="text", text=f"Unknown tool: {name}")],
115+
isError=True,
116+
)
210117

211118

212119
def create_app(session_manager: StreamableHTTPSessionManager) -> Starlette:

examples/servers/simple-task/mcp_simple_task/server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from anyio.abc import TaskGroup
1313
from mcp.server.lowlevel import Server
1414
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
15-
from mcp.shared.experimental.tasks import InMemoryTaskStore, task_execution
15+
from mcp.shared.experimental.tasks.helpers import task_execution
16+
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
1617
from starlette.applications import Starlette
1718
from starlette.routing import Mount
1819

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""
2+
Server-side experimental features.
3+
4+
WARNING: These APIs are experimental and may change without notice.
5+
6+
Import directly from submodules:
7+
- mcp.server.experimental.task_context.ServerTaskContext
8+
- mcp.server.experimental.task_support.TaskSupport
9+
- mcp.server.experimental.task_result_handler.TaskResultHandler
10+
- mcp.server.experimental.request_context.Experimental
11+
"""

0 commit comments

Comments
 (0)