Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions docs/proxying.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Proxying MCP Transports

The `mcp_proxy()` helper bridges two MCP transports and forwards messages in both directions.

It is useful when you want to put a transport boundary between an MCP client and an upstream MCP server without
rewriting the forwarding loop yourself.

## What It Does

`mcp_proxy()` takes two transport pairs:

- a transport facing the downstream client
- a transport facing the upstream server

While the context manager is active, it:

- forwards `SessionMessage` objects from client to server
- forwards `SessionMessage` objects from server to client
- sends transport exceptions to an optional `on_error` callback
- closes the paired write side when the corresponding read side stops

## What It Does Not Do

`mcp_proxy()` is a transport relay, not a full proxy server.

It does not add:

- authentication
- authorization
- request or response rewriting
- routing across multiple upstream servers
- retries or buffering policies
- metrics or tracing by default

If you need those behaviors, build them around the helper.

## Weather Service Example

This example proxies a small weather service. The upstream service is defined with `MCPServer` and exposed over
streamable HTTP. The proxy bridges a downstream transport to that upstream transport.

- `get_weather(city)` for a structured weather snapshot
- `get_weather_alerts(region)` for active alerts

The client talks only to the downstream side of the proxy.

```python
import anyio
import uvicorn

from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamable_http_client
from mcp.proxy import mcp_proxy
from mcp.server.mcpserver import MCPServer
from mcp.shared.memory import create_client_server_memory_streams


app = MCPServer("Weather Service")


@app.tool()
def get_weather(city: str) -> dict[str, str | float]:
return {
"city": city,
"temperature_c": 22.5,
"condition": "partly cloudy",
"wind_speed_kmh": 12.3,
}


@app.tool()
def get_weather_alerts(region: str) -> dict[str, object]:
return {
"region": region,
"alerts": [{"severity": "medium", "title": "Heat advisory"}],
}


async def main() -> None:
starlette_app = app.streamable_http_app(streamable_http_path="/mcp")
config = uvicorn.Config(starlette_app, host="127.0.0.1", port=8765, log_level="warning")
upstream_server = uvicorn.Server(config)

async with (
create_client_server_memory_streams() as (client_streams, proxy_client_streams),
streamable_http_client("http://127.0.0.1:8765/mcp") as proxy_server_streams,
anyio.create_task_group() as tg,
):
tg.start_soon(upstream_server.serve)

async with mcp_proxy(
proxy_client_streams,
proxy_server_streams,
):
async with ClientSession(client_streams[0], client_streams[1]) as session:
await session.initialize()
weather = await session.call_tool("get_weather", {"city": "London"})
alerts = await session.call_tool("get_weather_alerts", {"region": "California"})

print(weather.content[0].text)
print(alerts.content[0].text)

upstream_server.should_exit = True
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you add this?

tg.cancel_scope.cancel()


anyio.run(main)
```

## Error Handling

Use `on_error` to observe transport-level exceptions:

```python
async with mcp_proxy(
downstream_transport,
upstream_transport,
on_error=handle_transport_error,
):
...
```

`on_error` is keyword-only. It may be either:

- an async callable
- a sync callable, which will run in a worker thread

Exceptions raised by `on_error` are swallowed. Transport exceptions still terminate the proxy instead of being silently
consumed.

## When To Use It

`mcp_proxy()` is a good fit when you are:

- exposing an upstream MCP server through a different transport boundary
- inserting middleware-like behavior between two MCP transports
- building a local relay for testing or development
- experimenting with transport adapters

If all you need is to test a server directly, prefer [`Client`](testing.md), which already provides an in-memory
transport for that use case.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ nav:
- Documentation:
- Concepts: concepts.md
- Low-Level Server: low-level-server.md
- Proxying Transports: proxying.md
- Authorization: authorization.md
- Testing: testing.md
- Experimental:
Expand Down
2 changes: 2 additions & 0 deletions src/mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .client.session import ClientSession
from .client.session_group import ClientSessionGroup
from .client.stdio import StdioServerParameters, stdio_client
from .proxy import mcp_proxy
from .server.session import ServerSession
from .server.stdio import stdio_server
from .shared.exceptions import MCPError, UrlElicitationRequiredError
Expand Down Expand Up @@ -97,6 +98,7 @@
"LoggingLevel",
"LoggingMessageNotification",
"MCPError",
"mcp_proxy",
"Notification",
"PingRequest",
"ProgressNotification",
Expand Down
99 changes: 99 additions & 0 deletions src/mcp/proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Provide utilities for proxying messages between two MCP transports."""

from __future__ import annotations

import contextvars
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from functools import partial
from typing import Any, Protocol, cast

import anyio
from anyio import to_thread

from mcp.shared._callable_inspection import is_async_callable
from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.message import SessionMessage

MessageStream = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]]
ErrorHandler = Callable[[Exception], None | Awaitable[None]]


class ContextualWriteStream(Protocol):
async def send_with_context(self, context: contextvars.Context, item: SessionMessage | Exception) -> None: ...


@asynccontextmanager
async def mcp_proxy(
transport_to_client: MessageStream,
transport_to_server: MessageStream,
*,
on_error: ErrorHandler | None = None,
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
on_error: ErrorHandler | None = None,
*,
on_error: ErrorHandler | None = None,

) -> AsyncGenerator[None]:
"""Proxy messages bidirectionally between two MCP transports."""
client_read, client_write = transport_to_client
server_read, server_write = transport_to_server

async with anyio.create_task_group() as task_group:
task_group.start_soon(_forward_messages, client_read, server_write, on_error)
task_group.start_soon(_forward_messages, server_read, client_write, on_error)
try:
yield
finally:
task_group.cancel_scope.cancel()


async def _forward_messages(
read_stream: ReadStream[SessionMessage | Exception],
write_stream: WriteStream[SessionMessage],
on_error: ErrorHandler | None,
) -> None:
try:
async with write_stream:
async with read_stream:
async for item in read_stream:
if isinstance(item, Exception):
await _run_error_handler(item, on_error)
raise item

try:
await _forward_message(item, write_stream, read_stream)
except anyio.ClosedResourceError:
break
except anyio.ClosedResourceError:
return


async def _forward_message(
item: SessionMessage,
write_stream: WriteStream[SessionMessage],
read_stream: ReadStream[SessionMessage | Exception],
) -> None:
sender_context: contextvars.Context | None = getattr(read_stream, "last_context", None)
context_write_stream = cast(ContextualWriteStream | None, _get_contextual_write_stream(write_stream))

if sender_context is not None and context_write_stream is not None:
await context_write_stream.send_with_context(sender_context, item)
return

await write_stream.send(item)


def _get_contextual_write_stream(write_stream: WriteStream[SessionMessage]) -> Any:
send_with_context = getattr(write_stream, "send_with_context", None)
if callable(send_with_context):
return write_stream
return None


async def _run_error_handler(error: Exception, on_error: ErrorHandler | None) -> None:
if on_error is None:
return

try:
if is_async_callable(on_error):
await cast(Awaitable[None], on_error(error))
else:
await to_thread.run_sync(partial(on_error, error))
except Exception:
return
14 changes: 2 additions & 12 deletions src/mcp/server/mcpserver/tools/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import functools
import inspect
from collections.abc import Callable
from functools import cached_property
from typing import TYPE_CHECKING, Any
Expand All @@ -11,6 +9,7 @@
from mcp.server.mcpserver.exceptions import ToolError
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata
from mcp.shared._callable_inspection import is_async_callable
from mcp.shared.exceptions import UrlElicitationRequiredError
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
from mcp.types import Icon, ToolAnnotations
Expand Down Expand Up @@ -63,7 +62,7 @@ def from_function(
raise ValueError("You must provide a name for lambda functions")

func_doc = description or fn.__doc__ or ""
is_async = _is_async_callable(fn)
is_async = is_async_callable(fn)

if context_kwarg is None: # pragma: no branch
context_kwarg = find_context_parameter(fn)
Expand Down Expand Up @@ -118,12 +117,3 @@ async def run(
raise
except Exception as e:
raise ToolError(f"Error executing tool {self.name}: {e}") from e


def _is_async_callable(obj: Any) -> bool:
while isinstance(obj, functools.partial): # pragma: lax no cover
obj = obj.func

return inspect.iscoroutinefunction(obj) or (
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
)
14 changes: 14 additions & 0 deletions src/mcp/shared/_callable_inspection.py
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want a new file! Isn't this implemented somewhere in this repo already?

Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from __future__ import annotations

import functools
import inspect
from typing import Any


def is_async_callable(obj: Any) -> bool:
while isinstance(obj, functools.partial): # pragma: lax no cover
obj = obj.func

return inspect.iscoroutinefunction(obj) or (
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
)
3 changes: 3 additions & 0 deletions src/mcp/shared/_context_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def __init__(self, inner: MemoryObjectSendStream[_Envelope[T]]) -> None:
async def send(self, item: T) -> None:
await self._inner.send((contextvars.copy_context(), item))

async def send_with_context(self, context: contextvars.Context, item: T) -> None:
await self._inner.send((context, item))

def close(self) -> None:
self._inner.close()

Expand Down
Loading
Loading