|
6 | 6 | """ |
7 | 7 |
|
8 | 8 | import logging |
9 | | -from collections.abc import Awaitable, Callable |
| 9 | +from collections.abc import AsyncGenerator, Awaitable, Callable |
10 | 10 | from contextlib import asynccontextmanager |
11 | | -from typing import AsyncGenerator |
12 | 11 |
|
13 | 12 | import anyio |
14 | 13 | from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
|
23 | 22 | ] |
24 | 23 |
|
25 | 24 |
|
| 25 | +async def _handle_error( |
| 26 | + error: Exception, |
| 27 | + onerror: Callable[[Exception], None | Awaitable[None]] | None, |
| 28 | +) -> None: |
| 29 | + """Handle an error by calling the error callback if provided.""" |
| 30 | + if onerror: |
| 31 | + try: |
| 32 | + result = onerror(error) |
| 33 | + if isinstance(result, Awaitable): |
| 34 | + await result |
| 35 | + except Exception as callback_error: # pragma: no cover |
| 36 | + logger.exception("Error in onerror callback", exc_info=callback_error) |
| 37 | + |
| 38 | + |
| 39 | +async def _forward_message( |
| 40 | + message: SessionMessage | Exception, |
| 41 | + write_stream: MemoryObjectSendStream[SessionMessage], |
| 42 | + onerror: Callable[[Exception], None | Awaitable[None]] | None, |
| 43 | + source: str, |
| 44 | +) -> None: |
| 45 | + """Forward a single message, handling exceptions appropriately.""" |
| 46 | + if isinstance(message, SessionMessage): |
| 47 | + await write_stream.send(message) |
| 48 | + elif isinstance(message, Exception): |
| 49 | + logger.debug(f"Exception received from {source}: {message}") |
| 50 | + await _handle_error(message, onerror) |
| 51 | + # Exceptions are not forwarded as messages (write streams only accept SessionMessage) |
| 52 | + |
| 53 | + |
| 54 | +async def _forward_loop( |
| 55 | + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], |
| 56 | + write_stream: MemoryObjectSendStream[SessionMessage], |
| 57 | + onerror: Callable[[Exception], None | Awaitable[None]] | None, |
| 58 | + source: str, |
| 59 | +) -> None: |
| 60 | + """Forward messages from read_stream to write_stream.""" |
| 61 | + try: |
| 62 | + async with read_stream: |
| 63 | + async for message in read_stream: |
| 64 | + try: |
| 65 | + await _forward_message(message, write_stream, onerror, source) |
| 66 | + except anyio.ClosedResourceError: |
| 67 | + logger.debug(f"{source} write stream closed") |
| 68 | + break |
| 69 | + except Exception as exc: |
| 70 | + logger.exception(f"Error forwarding message from {source}", exc_info=exc) |
| 71 | + await _handle_error(exc, onerror) |
| 72 | + except anyio.ClosedResourceError: |
| 73 | + logger.debug(f"{source} read stream closed") |
| 74 | + except Exception as exc: |
| 75 | + logger.exception(f"Error in forward loop from {source}", exc_info=exc) |
| 76 | + await _handle_error(exc, onerror) |
| 77 | + finally: |
| 78 | + # Close write stream when read stream closes |
| 79 | + try: |
| 80 | + await write_stream.aclose() |
| 81 | + except Exception: # pragma: no cover |
| 82 | + # Stream might already be closed |
| 83 | + pass |
| 84 | + |
| 85 | + |
26 | 86 | @asynccontextmanager |
27 | 87 | async def mcp_proxy( |
28 | 88 | transport_to_client: MessageStream, |
@@ -60,111 +120,9 @@ async def mcp_proxy( |
60 | 120 | client_read, client_write = transport_to_client |
61 | 121 | server_read, server_write = transport_to_server |
62 | 122 |
|
63 | | - async def forward_to_server(): |
64 | | - """Forward messages from client to server.""" |
65 | | - try: |
66 | | - async with client_read: |
67 | | - async for message in client_read: |
68 | | - try: |
69 | | - # Forward SessionMessage objects directly |
70 | | - if isinstance(message, SessionMessage): |
71 | | - await server_write.send(message) |
72 | | - # Handle Exception objects via error callback |
73 | | - elif isinstance(message, Exception): |
74 | | - logger.debug(f"Exception received from client: {message}") |
75 | | - if onerror: |
76 | | - try: |
77 | | - result = onerror(message) |
78 | | - if isinstance(result, Awaitable): |
79 | | - await result |
80 | | - except Exception as callback_error: # pragma: no cover |
81 | | - logger.exception("Error in onerror callback", exc_info=callback_error) |
82 | | - # Exceptions are not forwarded as messages (write streams only accept SessionMessage) |
83 | | - except anyio.ClosedResourceError: |
84 | | - logger.debug("Server write stream closed while forwarding from client") |
85 | | - break |
86 | | - except Exception as exc: # pragma: no cover |
87 | | - logger.exception("Error forwarding message from client to server", exc_info=exc) |
88 | | - if onerror: |
89 | | - try: |
90 | | - result = onerror(exc) |
91 | | - if isinstance(result, Awaitable): |
92 | | - await result |
93 | | - except Exception as callback_error: # pragma: no cover |
94 | | - logger.exception("Error in onerror callback", exc_info=callback_error) |
95 | | - except anyio.ClosedResourceError: |
96 | | - logger.debug("Client read stream closed") |
97 | | - except Exception as exc: # pragma: no cover |
98 | | - logger.exception("Error in forward_to_server task", exc_info=exc) |
99 | | - if onerror: |
100 | | - try: |
101 | | - result = onerror(exc) |
102 | | - if isinstance(result, Awaitable): |
103 | | - await result |
104 | | - except Exception as callback_error: # pragma: no cover |
105 | | - logger.exception("Error in onerror callback", exc_info=callback_error) |
106 | | - finally: |
107 | | - # Close server write stream when client read closes |
108 | | - try: |
109 | | - await server_write.aclose() |
110 | | - except Exception: # pragma: no cover |
111 | | - # Stream might already be closed |
112 | | - pass |
113 | | - |
114 | | - async def forward_to_client(): |
115 | | - """Forward messages from server to client.""" |
116 | | - try: |
117 | | - async with server_read: |
118 | | - async for message in server_read: |
119 | | - try: |
120 | | - # Forward SessionMessage objects directly |
121 | | - if isinstance(message, SessionMessage): |
122 | | - await client_write.send(message) |
123 | | - # Handle Exception objects via error callback |
124 | | - elif isinstance(message, Exception): |
125 | | - logger.debug(f"Exception received from server: {message}") |
126 | | - if onerror: |
127 | | - try: |
128 | | - result = onerror(message) |
129 | | - if isinstance(result, Awaitable): |
130 | | - await result |
131 | | - except Exception as callback_error: # pragma: no cover |
132 | | - logger.exception("Error in onerror callback", exc_info=callback_error) |
133 | | - # Exceptions are not forwarded as messages (write streams only accept SessionMessage) |
134 | | - except anyio.ClosedResourceError: |
135 | | - logger.debug("Client write stream closed while forwarding from server") |
136 | | - break |
137 | | - except Exception as exc: # pragma: no cover |
138 | | - logger.exception("Error forwarding message from server to client", exc_info=exc) |
139 | | - if onerror: |
140 | | - try: |
141 | | - result = onerror(exc) |
142 | | - if isinstance(result, Awaitable): |
143 | | - await result |
144 | | - except Exception as callback_error: # pragma: no cover |
145 | | - logger.exception("Error in onerror callback", exc_info=callback_error) |
146 | | - except anyio.ClosedResourceError: |
147 | | - logger.debug("Server read stream closed") |
148 | | - except Exception as exc: # pragma: no cover |
149 | | - logger.exception("Error in forward_to_client task", exc_info=exc) |
150 | | - if onerror: |
151 | | - try: |
152 | | - result = onerror(exc) |
153 | | - if isinstance(result, Awaitable): |
154 | | - await result |
155 | | - except Exception as callback_error: # pragma: no cover |
156 | | - logger.exception("Error in onerror callback", exc_info=callback_error) |
157 | | - finally: |
158 | | - # Close client write stream when server read closes |
159 | | - try: |
160 | | - await client_write.aclose() |
161 | | - except Exception: # pragma: no cover |
162 | | - # Stream might already be closed |
163 | | - pass |
164 | | - |
165 | 123 | async with anyio.create_task_group() as tg: |
166 | | - tg.start_soon(forward_to_server) |
167 | | - tg.start_soon(forward_to_client) |
| 124 | + tg.start_soon(_forward_loop, client_read, server_write, onerror, "client") |
| 125 | + tg.start_soon(_forward_loop, server_read, client_write, onerror, "server") |
168 | 126 | try: |
169 | 127 | yield |
170 | 128 | finally: |
|
0 commit comments