Skip to content

Commit 3c31190

Browse files
committed
Add reviewed MCP proxy helper
1 parent d5b9155 commit 3c31190

File tree

3 files changed

+367
-0
lines changed

3 files changed

+367
-0
lines changed

src/mcp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .client.session import ClientSession
33
from .client.session_group import ClientSessionGroup
44
from .client.stdio import StdioServerParameters, stdio_client
5+
from .proxy import mcp_proxy
56
from .server.session import ServerSession
67
from .server.stdio import stdio_server
78
from .shared.exceptions import MCPError, UrlElicitationRequiredError
@@ -97,6 +98,7 @@
9798
"LoggingLevel",
9899
"LoggingMessageNotification",
99100
"MCPError",
101+
"mcp_proxy",
100102
"Notification",
101103
"PingRequest",
102104
"ProgressNotification",

src/mcp/proxy.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Provide utilities for proxying messages between two MCP transports."""
2+
3+
from __future__ import annotations
4+
5+
import inspect
6+
from collections.abc import AsyncGenerator, Awaitable, Callable
7+
from contextlib import asynccontextmanager
8+
9+
import anyio
10+
11+
from mcp.shared._stream_protocols import ReadStream, WriteStream
12+
from mcp.shared.message import SessionMessage
13+
14+
MessageStream = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]]
15+
ErrorHandler = Callable[[Exception], None | Awaitable[None]]
16+
17+
18+
@asynccontextmanager
19+
async def mcp_proxy(
20+
transport_to_client: MessageStream,
21+
transport_to_server: MessageStream,
22+
on_error: ErrorHandler | None = None,
23+
) -> AsyncGenerator[None]:
24+
"""Proxy messages bidirectionally between two MCP transports."""
25+
client_read, client_write = transport_to_client
26+
server_read, server_write = transport_to_server
27+
28+
async with anyio.create_task_group() as task_group:
29+
task_group.start_soon(_forward_messages, client_read, server_write, on_error)
30+
task_group.start_soon(_forward_messages, server_read, client_write, on_error)
31+
try:
32+
yield
33+
finally:
34+
task_group.cancel_scope.cancel()
35+
36+
37+
async def _forward_messages(
38+
read_stream: ReadStream[SessionMessage | Exception],
39+
write_stream: WriteStream[SessionMessage],
40+
on_error: ErrorHandler | None,
41+
) -> None:
42+
try:
43+
async with write_stream:
44+
async with read_stream:
45+
async for item in read_stream:
46+
if isinstance(item, Exception):
47+
await _run_error_handler(item, on_error)
48+
continue
49+
50+
try:
51+
await write_stream.send(item)
52+
except anyio.ClosedResourceError:
53+
break
54+
except anyio.ClosedResourceError:
55+
return
56+
57+
58+
async def _run_error_handler(error: Exception, on_error: ErrorHandler | None) -> None:
59+
if on_error is None:
60+
return
61+
62+
try:
63+
result = on_error(error)
64+
if inspect.isawaitable(result):
65+
await result
66+
except Exception:
67+
return

tests/test_proxy.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
from __future__ import annotations
2+
3+
from types import TracebackType
4+
5+
import anyio
6+
import pytest
7+
8+
from mcp.proxy import mcp_proxy
9+
from mcp.shared.message import SessionMessage
10+
from mcp.types import JSONRPCRequest
11+
12+
13+
def make_message(request_id: str, method: str) -> SessionMessage:
14+
return SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params={}))
15+
16+
17+
def assert_request(message: SessionMessage, request_id: str, method: str) -> None:
18+
assert isinstance(message.message, JSONRPCRequest)
19+
assert message.message.id == request_id
20+
assert message.message.method == method
21+
22+
23+
class StaticReadStream:
24+
def __init__(self, *items: SessionMessage | Exception, error: Exception | None = None) -> None:
25+
self._items = list(items)
26+
self._error = error
27+
self.closed = False
28+
29+
async def receive(self) -> SessionMessage | Exception:
30+
try:
31+
return await self.__anext__()
32+
except StopAsyncIteration as exc:
33+
raise anyio.EndOfStream from exc
34+
35+
async def aclose(self) -> None:
36+
self.closed = True
37+
38+
def __aiter__(self) -> StaticReadStream:
39+
return self
40+
41+
async def __anext__(self) -> SessionMessage | Exception:
42+
if self._items:
43+
return self._items.pop(0)
44+
if self._error is not None:
45+
error = self._error
46+
self._error = None
47+
raise error
48+
raise StopAsyncIteration
49+
50+
async def __aenter__(self) -> StaticReadStream:
51+
return self
52+
53+
async def __aexit__(
54+
self,
55+
exc_type: type[BaseException] | None,
56+
exc_val: BaseException | None,
57+
exc_tb: TracebackType | None,
58+
) -> bool | None:
59+
await self.aclose()
60+
return None
61+
62+
63+
class TrackingWriteStream:
64+
def __init__(self, error: Exception | None = None) -> None:
65+
self.items: list[SessionMessage] = []
66+
self.error = error
67+
self.closed = anyio.Event()
68+
69+
async def send(self, item: SessionMessage, /) -> None:
70+
if self.error is not None:
71+
raise self.error
72+
self.items.append(item)
73+
74+
async def aclose(self) -> None:
75+
self.closed.set()
76+
77+
async def __aenter__(self) -> TrackingWriteStream:
78+
return self
79+
80+
async def __aexit__(
81+
self,
82+
exc_type: type[BaseException] | None,
83+
exc_val: BaseException | None,
84+
exc_tb: TracebackType | None,
85+
) -> bool | None:
86+
await self.aclose()
87+
return None
88+
89+
90+
@pytest.mark.anyio
91+
async def test_proxy_forwards_messages_bidirectionally() -> None:
92+
client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
93+
client_write, client_write_read = anyio.create_memory_object_stream[SessionMessage](1)
94+
server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
95+
server_write, server_write_read = anyio.create_memory_object_stream[SessionMessage](1)
96+
97+
async with (
98+
client_read_send,
99+
client_read,
100+
client_write,
101+
client_write_read,
102+
server_read_send,
103+
server_read,
104+
server_write,
105+
server_write_read,
106+
):
107+
async with mcp_proxy((client_read, client_write), (server_read, server_write)):
108+
await client_read_send.send(make_message("client", "client/method"))
109+
await server_read_send.send(make_message("server", "server/method"))
110+
111+
assert_request(await server_write_read.receive(), "client", "client/method")
112+
assert_request(await client_write_read.receive(), "server", "server/method")
113+
114+
115+
@pytest.mark.anyio
116+
async def test_proxy_calls_sync_error_handler_and_continues() -> None:
117+
errors: list[Exception] = []
118+
handled = anyio.Event()
119+
120+
def on_error(error: Exception) -> None:
121+
errors.append(error)
122+
handled.set()
123+
124+
client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
125+
client_write, _client_write_read = anyio.create_memory_object_stream[SessionMessage](1)
126+
server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
127+
server_write, server_write_read = anyio.create_memory_object_stream[SessionMessage](1)
128+
129+
async with (
130+
client_read_send,
131+
client_read,
132+
client_write,
133+
_client_write_read,
134+
server_read_send,
135+
server_read,
136+
server_write,
137+
server_write_read,
138+
):
139+
async with mcp_proxy((client_read, client_write), (server_read, server_write), on_error=on_error):
140+
await client_read_send.send(ValueError("boom"))
141+
await handled.wait()
142+
await client_read_send.send(make_message("after-error", "client/method"))
143+
144+
assert len(errors) == 1
145+
assert isinstance(errors[0], ValueError)
146+
assert str(errors[0]) == "boom"
147+
assert_request(await server_write_read.receive(), "after-error", "client/method")
148+
149+
150+
@pytest.mark.anyio
151+
async def test_proxy_calls_async_error_handler() -> None:
152+
errors: list[Exception] = []
153+
handled = anyio.Event()
154+
155+
async def on_error(error: Exception) -> None:
156+
errors.append(error)
157+
handled.set()
158+
159+
client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
160+
client_write, _client_write_read = anyio.create_memory_object_stream[SessionMessage](1)
161+
server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
162+
server_write, _server_write_read = anyio.create_memory_object_stream[SessionMessage](1)
163+
164+
async with (
165+
client_read_send,
166+
client_read,
167+
client_write,
168+
_client_write_read,
169+
server_read_send,
170+
server_read,
171+
server_write,
172+
_server_write_read,
173+
):
174+
async with mcp_proxy((client_read, client_write), (server_read, server_write), on_error=on_error):
175+
await client_read_send.send(ValueError("async-boom"))
176+
await handled.wait()
177+
178+
assert len(errors) == 1
179+
assert isinstance(errors[0], ValueError)
180+
assert str(errors[0]) == "async-boom"
181+
182+
183+
@pytest.mark.anyio
184+
async def test_proxy_ignores_sync_error_handler_failures() -> None:
185+
def on_error(error: Exception) -> None:
186+
raise RuntimeError(f"handler failed for {error}")
187+
188+
client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
189+
client_write, _client_write_read = anyio.create_memory_object_stream[SessionMessage](1)
190+
server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
191+
server_write, server_write_read = anyio.create_memory_object_stream[SessionMessage](1)
192+
193+
async with (
194+
client_read_send,
195+
client_read,
196+
client_write,
197+
_client_write_read,
198+
server_read_send,
199+
server_read,
200+
server_write,
201+
server_write_read,
202+
):
203+
async with mcp_proxy((client_read, client_write), (server_read, server_write), on_error=on_error):
204+
await client_read_send.send(ValueError("boom"))
205+
await client_read_send.send(make_message("after-sync-handler-error", "client/method"))
206+
assert_request(await server_write_read.receive(), "after-sync-handler-error", "client/method")
207+
208+
209+
@pytest.mark.anyio
210+
async def test_proxy_ignores_async_error_handler_failures() -> None:
211+
async def on_error(error: Exception) -> None:
212+
raise RuntimeError(f"handler failed for {error}")
213+
214+
client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
215+
client_write, _client_write_read = anyio.create_memory_object_stream[SessionMessage](1)
216+
server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
217+
server_write, server_write_read = anyio.create_memory_object_stream[SessionMessage](1)
218+
219+
async with (
220+
client_read_send,
221+
client_read,
222+
client_write,
223+
_client_write_read,
224+
server_read_send,
225+
server_read,
226+
server_write,
227+
server_write_read,
228+
):
229+
async with mcp_proxy((client_read, client_write), (server_read, server_write), on_error=on_error):
230+
await client_read_send.send(ValueError("boom"))
231+
await client_read_send.send(make_message("after-async-handler-error", "client/method"))
232+
assert_request(await server_write_read.receive(), "after-async-handler-error", "client/method")
233+
234+
235+
@pytest.mark.anyio
236+
async def test_proxy_continues_without_error_handler() -> None:
237+
client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
238+
client_write, _client_write_read = anyio.create_memory_object_stream[SessionMessage](1)
239+
server_read_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
240+
server_write, server_write_read = anyio.create_memory_object_stream[SessionMessage](1)
241+
242+
async with (
243+
client_read_send,
244+
client_read,
245+
client_write,
246+
_client_write_read,
247+
server_read_send,
248+
server_read,
249+
server_write,
250+
server_write_read,
251+
):
252+
async with mcp_proxy((client_read, client_write), (server_read, server_write)):
253+
await client_read_send.send(ValueError("boom"))
254+
await client_read_send.send(make_message("after-no-handler", "client/method"))
255+
assert_request(await server_write_read.receive(), "after-no-handler", "client/method")
256+
257+
258+
@pytest.mark.anyio
259+
async def test_proxy_stops_forwarding_when_target_stream_is_closed() -> None:
260+
server_write = TrackingWriteStream(anyio.ClosedResourceError())
261+
client_write = TrackingWriteStream()
262+
263+
async with mcp_proxy(
264+
(StaticReadStream(make_message("client", "client/method")), server_write),
265+
(StaticReadStream(), client_write),
266+
):
267+
await server_write.closed.wait()
268+
269+
assert server_write.items == []
270+
assert server_write.closed.is_set()
271+
assert client_write.closed.is_set()
272+
273+
274+
@pytest.mark.anyio
275+
async def test_proxy_closes_target_stream_when_source_stream_is_closed() -> None:
276+
server_write = TrackingWriteStream()
277+
client_write = TrackingWriteStream()
278+
279+
async with mcp_proxy((StaticReadStream(), server_write), (StaticReadStream(), client_write)):
280+
await server_write.closed.wait()
281+
await client_write.closed.wait()
282+
283+
assert server_write.items == []
284+
assert client_write.items == []
285+
286+
287+
@pytest.mark.anyio
288+
async def test_proxy_handles_closed_resource_error_from_source_stream() -> None:
289+
server_write = TrackingWriteStream()
290+
client_write = TrackingWriteStream()
291+
292+
async with mcp_proxy(
293+
(StaticReadStream(error=anyio.ClosedResourceError()), server_write),
294+
(StaticReadStream(), client_write),
295+
):
296+
await server_write.closed.wait()
297+
298+
assert server_write.items == []

0 commit comments

Comments
 (0)