Skip to content

Commit 9bfd6c9

Browse files
committed
PR 2452
1 parent e8e6484 commit 9bfd6c9

7 files changed

Lines changed: 623 additions & 1 deletion

File tree

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""In-memory `Dispatcher` that wires two peers together with no transport.
2+
3+
`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a
4+
request on one side directly invokes the other side's `on_request`. There is no
5+
serialization, no JSON-RPC framing, and no streams. It exists to:
6+
7+
* prove the `Dispatcher` Protocol is implementable without JSON-RPC
8+
* provide a fast substrate for testing the layers above the dispatcher
9+
(`ServerRunner`, `Context`, `Connection`) without wire-level moving parts
10+
* embed a server in-process when the JSON-RPC overhead is unnecessary
11+
12+
Unlike `JSONRPCDispatcher`, exceptions raised in a handler propagate directly
13+
to the caller — there is no exception-to-`ErrorData` boundary here.
14+
"""
15+
16+
from __future__ import annotations
17+
18+
from collections.abc import Awaitable, Callable, Mapping
19+
from dataclasses import dataclass, field
20+
from typing import Any
21+
22+
import anyio
23+
24+
from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT
25+
from mcp.shared.exceptions import MCPError, NoBackChannelError
26+
from mcp.shared.transport_context import TransportContext
27+
from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT
28+
29+
__all__ = ["DirectDispatcher", "create_direct_dispatcher_pair"]
30+
31+
DIRECT_TRANSPORT_KIND = "direct"
32+
33+
34+
_Request = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]]
35+
_Notify = Callable[[str, Mapping[str, Any] | None], Awaitable[None]]
36+
37+
38+
@dataclass
39+
class _DirectDispatchContext:
40+
"""`DispatchContext` for an inbound request on a `DirectDispatcher`.
41+
42+
The back-channel callables target the *originating* side, so a handler's
43+
`send_raw_request` reaches the peer that made the inbound request.
44+
"""
45+
46+
transport: TransportContext
47+
_back_request: _Request
48+
_back_notify: _Notify
49+
_on_progress: ProgressFnT | None = None
50+
cancel_requested: anyio.Event = field(default_factory=anyio.Event)
51+
52+
async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
53+
await self._back_notify(method, params)
54+
55+
async def send_raw_request(
56+
self,
57+
method: str,
58+
params: Mapping[str, Any] | None,
59+
opts: CallOptions | None = None,
60+
) -> dict[str, Any]:
61+
if not self.transport.can_send_request:
62+
raise NoBackChannelError(method)
63+
return await self._back_request(method, params, opts)
64+
65+
async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
66+
if self._on_progress is not None:
67+
await self._on_progress(progress, total, message)
68+
69+
70+
class DirectDispatcher:
71+
"""A `Dispatcher` that calls a peer's handlers directly, in-process.
72+
73+
Two instances are wired together with `create_direct_dispatcher_pair`; each
74+
holds a reference to the other. `send_raw_request` on one awaits the peer's
75+
`on_request`. `run` parks until `close` is called.
76+
"""
77+
78+
def __init__(self, transport_ctx: TransportContext):
79+
self._transport_ctx = transport_ctx
80+
self._peer: DirectDispatcher | None = None
81+
self._on_request: OnRequest | None = None
82+
self._on_notify: OnNotify | None = None
83+
self._ready = anyio.Event()
84+
self._closed = anyio.Event()
85+
86+
def connect_to(self, peer: DirectDispatcher) -> None:
87+
self._peer = peer
88+
89+
async def send_raw_request(
90+
self,
91+
method: str,
92+
params: Mapping[str, Any] | None,
93+
opts: CallOptions | None = None,
94+
) -> dict[str, Any]:
95+
if self._peer is None:
96+
raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()")
97+
return await self._peer._dispatch_request(method, params, opts)
98+
99+
async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
100+
if self._peer is None:
101+
raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()")
102+
await self._peer._dispatch_notify(method, params)
103+
104+
async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None:
105+
self._on_request = on_request
106+
self._on_notify = on_notify
107+
self._ready.set()
108+
await self._closed.wait()
109+
110+
def close(self) -> None:
111+
self._closed.set()
112+
113+
def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispatchContext:
114+
assert self._peer is not None
115+
peer = self._peer
116+
return _DirectDispatchContext(
117+
transport=self._transport_ctx,
118+
_back_request=lambda m, p, o: peer._dispatch_request(m, p, o),
119+
_back_notify=lambda m, p: peer._dispatch_notify(m, p),
120+
_on_progress=on_progress,
121+
)
122+
123+
async def _dispatch_request(
124+
self,
125+
method: str,
126+
params: Mapping[str, Any] | None,
127+
opts: CallOptions | None,
128+
) -> dict[str, Any]:
129+
await self._ready.wait()
130+
assert self._on_request is not None
131+
opts = opts or {}
132+
dctx = self._make_context(on_progress=opts.get("on_progress"))
133+
try:
134+
with anyio.fail_after(opts.get("timeout")):
135+
try:
136+
return await self._on_request(dctx, method, params)
137+
except MCPError:
138+
raise
139+
except Exception as e:
140+
raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e
141+
except TimeoutError:
142+
raise MCPError(
143+
code=REQUEST_TIMEOUT,
144+
message=f"Timed out after {opts.get('timeout')}s waiting for {method!r}",
145+
) from None
146+
147+
async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) -> None:
148+
await self._ready.wait()
149+
assert self._on_notify is not None
150+
dctx = self._make_context()
151+
await self._on_notify(dctx, method, params)
152+
153+
154+
def create_direct_dispatcher_pair(
155+
*,
156+
can_send_request: bool = True,
157+
) -> tuple[DirectDispatcher, DirectDispatcher]:
158+
"""Create two `DirectDispatcher` instances wired to each other.
159+
160+
Args:
161+
can_send_request: Sets `TransportContext.can_send_request` on both
162+
sides. Pass ``False`` to simulate a transport with no back-channel.
163+
164+
Returns:
165+
A ``(left, right)`` pair. Conventionally ``left`` is the client side
166+
and ``right`` is the server side, but the wiring is symmetric.
167+
"""
168+
ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request)
169+
left = DirectDispatcher(ctx)
170+
right = DirectDispatcher(ctx)
171+
left.connect_to(right)
172+
right.connect_to(left)
173+
return left, right

src/mcp/shared/dispatcher.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Dispatcher Protocol — the call/return boundary between transports and handlers.
2+
3+
A Dispatcher turns a duplex message channel into two things:
4+
5+
* an outbound API: ``send_raw_request(method, params)`` and ``notify(method, params)``
6+
* an inbound pump: ``run(on_request, on_notify)`` that drives the receive loop
7+
and invokes the supplied handlers for each incoming request/notification
8+
9+
It is deliberately *not* MCP-aware. Method names are strings, params and
10+
results are ``dict[str, Any]``. The MCP type layer (request/result models,
11+
capability negotiation, ``Context``) sits above this; the wire encoding
12+
(JSON-RPC, gRPC, in-process direct calls) sits below it.
13+
14+
See ``JSONRPCDispatcher`` for the production implementation and
15+
``DirectDispatcher`` for an in-memory implementation used in tests and for
16+
embedding a server in-process.
17+
"""
18+
19+
from collections.abc import Awaitable, Callable, Mapping
20+
from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable
21+
22+
import anyio
23+
24+
from mcp.shared.transport_context import TransportContext
25+
26+
__all__ = [
27+
"CallOptions",
28+
"DispatchContext",
29+
"DispatchMiddleware",
30+
"Dispatcher",
31+
"OnNotify",
32+
"OnRequest",
33+
"Outbound",
34+
"ProgressFnT",
35+
]
36+
37+
TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True)
38+
39+
40+
class ProgressFnT(Protocol):
41+
"""Callback invoked when a progress notification arrives for a pending request."""
42+
43+
async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ...
44+
45+
46+
class CallOptions(TypedDict, total=False):
47+
"""Per-call options for `Outbound.send_raw_request`.
48+
49+
All keys are optional. Dispatchers ignore keys they do not understand.
50+
"""
51+
52+
timeout: float
53+
"""Seconds to wait for a result before raising and sending ``notifications/cancelled``."""
54+
55+
on_progress: ProgressFnT
56+
"""Receive ``notifications/progress`` updates for this request."""
57+
58+
resumption_token: str
59+
"""Opaque token to resume a previously interrupted request (transport-dependent)."""
60+
61+
on_resumption_token: Callable[[str], Awaitable[None]]
62+
"""Receive a resumption token when the transport issues one."""
63+
64+
65+
@runtime_checkable
66+
class Outbound(Protocol):
67+
"""Anything that can send requests and notifications to the peer.
68+
69+
Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel
70+
during an inbound request) extend this. The MCP type layer (`PeerMixin`,
71+
`Connection`, `Context`) builds typed ``send_request`` / convenience methods
72+
on top of this raw channel.
73+
"""
74+
75+
async def send_raw_request(
76+
self,
77+
method: str,
78+
params: Mapping[str, Any] | None,
79+
opts: CallOptions | None = None,
80+
) -> dict[str, Any]:
81+
"""Send a request and await its raw result dict.
82+
83+
Raises:
84+
MCPError: If the peer responded with an error, or the handler
85+
raised. Implementations normalize all handler exceptions to
86+
`MCPError` so callers see a single exception type.
87+
"""
88+
...
89+
90+
async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
91+
"""Send a fire-and-forget notification."""
92+
...
93+
94+
95+
class DispatchContext(Outbound, Protocol[TransportT_co]):
96+
"""Per-request context handed to ``on_request`` / ``on_notify``.
97+
98+
Carries the transport metadata for the inbound message and provides the
99+
back-channel for sending requests/notifications to the peer while handling
100+
it. `send_raw_request` raises `NoBackChannelError` if
101+
``transport.can_send_request`` is ``False``.
102+
"""
103+
104+
@property
105+
def transport(self) -> TransportT_co:
106+
"""Transport-specific metadata for this inbound message."""
107+
...
108+
109+
@property
110+
def cancel_requested(self) -> anyio.Event:
111+
"""Set when the peer sends ``notifications/cancelled`` for this request."""
112+
...
113+
114+
async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
115+
"""Report progress for the inbound request, if the peer supplied a progress token.
116+
117+
A no-op when no token was supplied.
118+
"""
119+
...
120+
121+
122+
OnRequest = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]]
123+
"""Handler for inbound requests: ``(ctx, method, params) -> result``. Raise ``MCPError`` to send an error response."""
124+
125+
OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]]
126+
"""Handler for inbound notifications: ``(ctx, method, params)``."""
127+
128+
DispatchMiddleware = Callable[[OnRequest], OnRequest]
129+
"""Wraps an ``OnRequest`` to produce another ``OnRequest``. Applied outermost-first."""
130+
131+
132+
class Dispatcher(Outbound, Protocol[TransportT_co]):
133+
"""A duplex request/notification channel with call-return semantics.
134+
135+
Implementations own correlation of outbound requests to inbound results, the
136+
receive loop, per-request concurrency, and cancellation/progress wiring.
137+
"""
138+
139+
async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None:
140+
"""Drive the receive loop until the underlying channel closes.
141+
142+
Each inbound request is dispatched to ``on_request`` in its own task;
143+
the returned dict (or raised ``MCPError``) is sent back as the response.
144+
Inbound notifications go to ``on_notify``.
145+
"""
146+
...

src/mcp/shared/exceptions.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Any, cast
44

5-
from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError
5+
from mcp.types import INVALID_REQUEST, URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError
66

77

88
class MCPError(Exception):
@@ -41,6 +41,25 @@ def __str__(self) -> str:
4141
return self.message
4242

4343

44+
class NoBackChannelError(MCPError):
45+
"""Raised when sending a server-initiated request over a transport that cannot deliver it.
46+
47+
Stateless HTTP and JSON-response-mode HTTP have no channel for the server to
48+
push requests (sampling, elicitation, roots/list) to the client. This is
49+
raised by `DispatchContext.send_raw_request` when `transport.can_send_request`
50+
is ``False``, and serializes to an ``INVALID_REQUEST`` error response.
51+
"""
52+
53+
def __init__(self, method: str):
54+
super().__init__(
55+
code=INVALID_REQUEST,
56+
message=(
57+
f"Cannot send {method!r}: this transport context has no back-channel for server-initiated requests."
58+
),
59+
)
60+
self.method = method
61+
62+
4463
class StatelessModeNotSupported(RuntimeError):
4564
"""Raised when attempting to use a method that is not supported in stateless mode.
4665
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Transport-specific metadata attached to each inbound message.
2+
3+
`TransportContext` is the base; each transport defines its own subclass with
4+
whatever fields make sense (HTTP request id, ASGI scope, stdio process handle,
5+
etc.). The dispatcher passes it through opaquely; only the layers above the
6+
dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields.
7+
"""
8+
9+
from dataclasses import dataclass
10+
11+
__all__ = ["TransportContext"]
12+
13+
14+
@dataclass(kw_only=True, frozen=True)
15+
class TransportContext:
16+
"""Base transport metadata for an inbound message.
17+
18+
Subclass per transport and add fields as needed. Instances are immutable.
19+
"""
20+
21+
kind: str
22+
"""Short identifier for the transport (e.g. ``"stdio"``, ``"streamable-http"``)."""
23+
24+
can_send_request: bool
25+
"""Whether the transport can deliver server-initiated requests to the peer.
26+
27+
``False`` for stateless HTTP and HTTP with JSON response mode; ``True`` for
28+
stdio, SSE, and stateful streamable HTTP. When ``False``,
29+
`DispatchContext.send_raw_request` raises `NoBackChannelError`.
30+
"""

0 commit comments

Comments
 (0)