Skip to content

Commit 8edcf43

Browse files
committed
PR 2460
1 parent 0433ef5 commit 8edcf43

9 files changed

Lines changed: 1230 additions & 0 deletions

File tree

src/mcp/server/_typed_request.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Typed ``send_request`` for server-to-client requests.
2+
3+
`TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over
4+
the host's raw `Outbound.send_raw_request`. Spec server-to-client request types
5+
have their result type inferred via per-type overloads; custom requests pass
6+
``result_type=`` explicitly.
7+
8+
If the spec's request set grows substantially, consider declaring the result
9+
mapping on the request types themselves (a ``__mcp_result__`` ClassVar read via
10+
a structural protocol) so this overload ladder doesn't need maintaining
11+
per-host-class.
12+
"""
13+
14+
from typing import Any, TypeVar, overload
15+
16+
from pydantic import BaseModel
17+
18+
from mcp.shared.dispatcher import CallOptions, Outbound
19+
from mcp.shared.peer import dump_params
20+
from mcp.types import (
21+
CreateMessageRequest,
22+
CreateMessageResult,
23+
ElicitRequest,
24+
ElicitResult,
25+
EmptyResult,
26+
ListRootsRequest,
27+
ListRootsResult,
28+
PingRequest,
29+
Request,
30+
)
31+
32+
__all__ = ["TypedServerRequestMixin"]
33+
34+
ResultT = TypeVar("ResultT", bound=BaseModel)
35+
36+
_RESULT_FOR: dict[type[Request[Any, Any]], type[BaseModel]] = {
37+
CreateMessageRequest: CreateMessageResult,
38+
ElicitRequest: ElicitResult,
39+
ListRootsRequest: ListRootsResult,
40+
PingRequest: EmptyResult,
41+
}
42+
43+
44+
class TypedServerRequestMixin:
45+
"""Typed ``send_request`` for the server-to-client request set.
46+
47+
Mixed into `Connection` and the server `Context`. Each method constrains
48+
``self`` to `Outbound` so any host with ``send_raw_request`` works.
49+
"""
50+
51+
@overload
52+
async def send_request(
53+
self: Outbound, req: CreateMessageRequest, *, opts: CallOptions | None = None
54+
) -> CreateMessageResult: ...
55+
@overload
56+
async def send_request(self: Outbound, req: ElicitRequest, *, opts: CallOptions | None = None) -> ElicitResult: ...
57+
@overload
58+
async def send_request(
59+
self: Outbound, req: ListRootsRequest, *, opts: CallOptions | None = None
60+
) -> ListRootsResult: ...
61+
@overload
62+
async def send_request(self: Outbound, req: PingRequest, *, opts: CallOptions | None = None) -> EmptyResult: ...
63+
@overload
64+
async def send_request(
65+
self: Outbound, req: Request[Any, Any], *, result_type: type[ResultT], opts: CallOptions | None = None
66+
) -> ResultT: ...
67+
async def send_request(
68+
self: Outbound,
69+
req: Request[Any, Any],
70+
*,
71+
result_type: type[BaseModel] | None = None,
72+
opts: CallOptions | None = None,
73+
) -> BaseModel:
74+
"""Send a typed server-to-client request and return its typed result.
75+
76+
For spec request types the result type is inferred. For custom requests
77+
pass ``result_type=`` explicitly.
78+
79+
Raises:
80+
MCPError: The peer responded with an error.
81+
NoBackChannelError: No back-channel for server-initiated requests.
82+
KeyError: ``result_type`` omitted for a non-spec request type.
83+
"""
84+
raw = await self.send_raw_request(req.method, dump_params(req.params), opts)
85+
cls = result_type if result_type is not None else _RESULT_FOR[type(req)]
86+
return cls.model_validate(raw)

src/mcp/server/connection.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""`Connection` — per-client connection state and the standalone outbound channel.
2+
3+
Always present on `Context` (never ``None``), even in stateless deployments.
4+
Holds peer info populated at ``initialize`` time, the per-connection lifespan
5+
output, and an `Outbound` for the standalone stream (the SSE GET stream in
6+
streamable HTTP, or the single duplex stream in stdio).
7+
8+
`notify` is best-effort: it never raises. If there's no standalone channel
9+
(stateless HTTP) or the stream has been dropped, the notification is
10+
debug-logged and silently discarded — server-initiated notifications are
11+
inherently advisory. `send_raw_request` *does* raise `NoBackChannelError` when
12+
there's no channel; `ping` is the only spec-sanctioned standalone request.
13+
"""
14+
15+
import logging
16+
from collections.abc import Mapping
17+
from typing import Any
18+
19+
import anyio
20+
21+
from mcp.server._typed_request import TypedServerRequestMixin
22+
from mcp.shared.dispatcher import CallOptions, Outbound
23+
from mcp.shared.exceptions import NoBackChannelError
24+
from mcp.shared.peer import Meta, dump_params
25+
from mcp.types import ClientCapabilities, Implementation, LoggingLevel
26+
27+
__all__ = ["Connection"]
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> dict[str, Any] | None:
33+
if not meta:
34+
return payload
35+
out = dict(payload or {})
36+
out["_meta"] = meta
37+
return out
38+
39+
40+
class Connection(TypedServerRequestMixin):
41+
"""Per-client connection state and standalone-stream `Outbound`.
42+
43+
Constructed by `ServerRunner` once per connection. The peer-info fields are
44+
``None`` until ``initialize`` completes; ``initialized`` is set then.
45+
"""
46+
47+
def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None:
48+
self._outbound = outbound
49+
self.has_standalone_channel = has_standalone_channel
50+
51+
self.client_info: Implementation | None = None
52+
self.client_capabilities: ClientCapabilities | None = None
53+
self.protocol_version: str | None = None
54+
self.initialized: anyio.Event = anyio.Event()
55+
# TODO: make this generic (Connection[StateT]) once connection_lifespan
56+
# wiring lands in ServerRunner.
57+
self.state: Any = None
58+
59+
async def send_raw_request(
60+
self,
61+
method: str,
62+
params: Mapping[str, Any] | None,
63+
opts: CallOptions | None = None,
64+
) -> dict[str, Any]:
65+
"""Send a raw request on the standalone stream.
66+
67+
Low-level `Outbound` channel. Prefer the typed ``send_request`` (from
68+
`TypedServerRequestMixin`) or the convenience methods below; use this
69+
directly only for off-spec messages.
70+
71+
Raises:
72+
MCPError: The peer responded with an error.
73+
NoBackChannelError: ``has_standalone_channel`` is ``False``.
74+
"""
75+
if not self.has_standalone_channel:
76+
raise NoBackChannelError(method)
77+
return await self._outbound.send_raw_request(method, params, opts)
78+
79+
async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
80+
"""Send a best-effort notification on the standalone stream.
81+
82+
Never raises. If there's no standalone channel or the stream is broken,
83+
the notification is dropped and debug-logged.
84+
"""
85+
if not self.has_standalone_channel:
86+
logger.debug("dropped %s: no standalone channel", method)
87+
return
88+
try:
89+
await self._outbound.notify(method, params)
90+
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
91+
logger.debug("dropped %s: standalone stream closed", method)
92+
93+
async def ping(self, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None:
94+
"""Send a ``ping`` request on the standalone stream.
95+
96+
Raises:
97+
MCPError: The peer responded with an error.
98+
NoBackChannelError: ``has_standalone_channel`` is ``False``.
99+
"""
100+
await self.send_raw_request("ping", dump_params(None, meta), opts)
101+
102+
async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None:
103+
"""Send a ``notifications/message`` log entry on the standalone stream. Best-effort."""
104+
params: dict[str, Any] = {"level": level, "data": data}
105+
if logger is not None:
106+
params["logger"] = logger
107+
await self.notify("notifications/message", _notification_params(params, meta))
108+
109+
async def send_tool_list_changed(self, *, meta: Meta | None = None) -> None:
110+
await self.notify("notifications/tools/list_changed", _notification_params(None, meta))
111+
112+
async def send_prompt_list_changed(self, *, meta: Meta | None = None) -> None:
113+
await self.notify("notifications/prompts/list_changed", _notification_params(None, meta))
114+
115+
async def send_resource_list_changed(self, *, meta: Meta | None = None) -> None:
116+
await self.notify("notifications/resources/list_changed", _notification_params(None, meta))
117+
118+
async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> None:
119+
await self.notify("notifications/resources/updated", _notification_params({"uri": uri}, meta))
120+
121+
def check_capability(self, capability: ClientCapabilities) -> bool:
122+
"""Return whether the connected client declared the given capability.
123+
124+
Returns ``False`` if ``initialize`` hasn't completed yet.
125+
"""
126+
# TODO: redesign — mirrors v1 ServerSession.check_client_capability
127+
# verbatim for parity.
128+
if self.client_capabilities is None:
129+
return False
130+
have = self.client_capabilities
131+
if capability.roots is not None:
132+
if have.roots is None:
133+
return False
134+
if capability.roots.list_changed and not have.roots.list_changed:
135+
return False
136+
if capability.sampling is not None and have.sampling is None:
137+
return False
138+
if capability.elicitation is not None and have.elicitation is None:
139+
return False
140+
if capability.experimental is not None:
141+
if have.experimental is None:
142+
return False
143+
for k in capability.experimental:
144+
if k not in have.experimental:
145+
return False
146+
return True

src/mcp/server/context.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@
55

66
from typing_extensions import TypeVar
77

8+
from mcp.server._typed_request import TypedServerRequestMixin
9+
from mcp.server.connection import Connection
810
from mcp.server.experimental.request_context import Experimental
911
from mcp.server.session import ServerSession
1012
from mcp.shared._context import RequestContext
13+
from mcp.shared.context import BaseContext
14+
from mcp.shared.dispatcher import DispatchContext
1115
from mcp.shared.message import CloseSSEStreamCallback
16+
from mcp.shared.peer import Meta, PeerMixin
17+
from mcp.shared.transport_context import TransportContext
18+
from mcp.types import LoggingLevel, RequestParamsMeta
1219

1320
LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any])
1421
RequestT = TypeVar("RequestT", default=Any)
@@ -21,3 +28,56 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex
2128
request: RequestT | None = None
2229
close_sse_stream: CloseSSEStreamCallback | None = None
2330
close_standalone_sse_stream: CloseSSEStreamCallback | None = None
31+
32+
33+
LifespanT = TypeVar("LifespanT", default=Any, covariant=True)
34+
TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True)
35+
36+
37+
class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]):
38+
"""Server-side per-request context.
39+
40+
Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`),
41+
`PeerMixin` (kwarg-style ``sample``/``elicit_*``/``list_roots``/``ping``),
42+
and `TypedServerRequestMixin` (typed ``send_request(req) -> Result``). Adds
43+
``lifespan`` and ``connection``.
44+
45+
Constructed by `ServerRunner` per inbound request and handed to the user's
46+
handler.
47+
"""
48+
49+
def __init__(
50+
self,
51+
dctx: DispatchContext[TransportT],
52+
*,
53+
lifespan: LifespanT,
54+
connection: Connection,
55+
meta: RequestParamsMeta | None = None,
56+
) -> None:
57+
super().__init__(dctx, meta=meta)
58+
self._lifespan = lifespan
59+
self._connection = connection
60+
61+
@property
62+
def lifespan(self) -> LifespanT:
63+
"""The server-wide lifespan output (what `Server(..., lifespan=...)` yielded)."""
64+
return self._lifespan
65+
66+
@property
67+
def connection(self) -> Connection:
68+
"""The per-client `Connection` for this request's connection."""
69+
return self._connection
70+
71+
async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None:
72+
"""Send a request-scoped ``notifications/message`` log entry.
73+
74+
Uses this request's back-channel (so the entry rides the request's SSE
75+
stream in streamable HTTP), not the standalone stream — use
76+
``ctx.connection.log(...)`` for that.
77+
"""
78+
params: dict[str, Any] = {"level": level, "data": data}
79+
if logger is not None:
80+
params["logger"] = logger
81+
if meta:
82+
params["_meta"] = meta
83+
await self.notify("notifications/message", params)

src/mcp/shared/context.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""`BaseContext` — the user-facing per-request context.
2+
3+
Composition over a `DispatchContext`: forwards the transport metadata, the
4+
back-channel (`send_raw_request`/`notify`), progress reporting, and the cancel
5+
event. Adds `meta` (the inbound request's `_meta` field).
6+
7+
Satisfies `Outbound`, so `PeerMixin` works on it (the server-side `Context`
8+
mixes that in directly). Shared between client and server: the server's
9+
`Context` extends this with `lifespan`/`connection`; `ClientContext` is just an
10+
alias.
11+
"""
12+
13+
from collections.abc import Mapping
14+
from typing import Any, Generic
15+
16+
import anyio
17+
from typing_extensions import TypeVar
18+
19+
from mcp.shared.dispatcher import CallOptions, DispatchContext
20+
from mcp.shared.transport_context import TransportContext
21+
from mcp.types import RequestParamsMeta
22+
23+
__all__ = ["BaseContext"]
24+
25+
TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True)
26+
27+
28+
class BaseContext(Generic[TransportT]):
29+
"""Per-request context wrapping a `DispatchContext`.
30+
31+
`ServerRunner` constructs one per inbound request and passes it to the
32+
user's handler.
33+
"""
34+
35+
def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None:
36+
self._dctx = dctx
37+
self._meta = meta
38+
39+
@property
40+
def transport(self) -> TransportT:
41+
"""Transport-specific metadata for this inbound request."""
42+
return self._dctx.transport
43+
44+
@property
45+
def cancel_requested(self) -> anyio.Event:
46+
"""Set when the peer sends ``notifications/cancelled`` for this request."""
47+
return self._dctx.cancel_requested
48+
49+
@property
50+
def can_send_request(self) -> bool:
51+
"""Whether the back-channel can deliver server-initiated requests."""
52+
return self._dctx.transport.can_send_request
53+
54+
@property
55+
def meta(self) -> RequestParamsMeta | None:
56+
"""The inbound request's ``_meta`` field, if present."""
57+
return self._meta
58+
59+
async def send_raw_request(
60+
self,
61+
method: str,
62+
params: Mapping[str, Any] | None,
63+
opts: CallOptions | None = None,
64+
) -> dict[str, Any]:
65+
"""Send a request to the peer on the back-channel.
66+
67+
Raises:
68+
MCPError: The peer responded with an error.
69+
NoBackChannelError: ``can_send_request`` is ``False``.
70+
"""
71+
return await self._dctx.send_raw_request(method, params, opts)
72+
73+
async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
74+
"""Send a notification to the peer on the back-channel."""
75+
await self._dctx.notify(method, params)
76+
77+
async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
78+
"""Report progress for this request, if the peer supplied a progress token.
79+
80+
A no-op when no token was supplied.
81+
"""
82+
await self._dctx.progress(progress, total, message)

0 commit comments

Comments
 (0)