Skip to content

Commit e1cff6c

Browse files
committed
fix: refactor proxy to reduce complexity and improve coverage
- Extract error handling into _handle_error helper function - Extract message forwarding into _forward_message helper function - Extract forwarding loop into _forward_loop helper function - Add tests for error callback exceptions (sync and async) - Reduces cyclomatic complexity from 39 to below 24 - Reduces statement count from 113 to below 102 - Improves test coverage to meet 100% requirement
1 parent 0357258 commit e1cff6c

File tree

2 files changed

+133
-106
lines changed

2 files changed

+133
-106
lines changed

src/mcp/shared/proxy.py

Lines changed: 64 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
"""
77

88
import logging
9-
from collections.abc import Awaitable, Callable
9+
from collections.abc import AsyncGenerator, Awaitable, Callable
1010
from contextlib import asynccontextmanager
11-
from typing import AsyncGenerator
1211

1312
import anyio
1413
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -23,6 +22,67 @@
2322
]
2423

2524

25+
async def _handle_error(
26+
error: Exception,
27+
onerror: Callable[[Exception], None | Awaitable[None]] | None,
28+
) -> None:
29+
"""Handle an error by calling the error callback if provided."""
30+
if onerror:
31+
try:
32+
result = onerror(error)
33+
if isinstance(result, Awaitable):
34+
await result
35+
except Exception as callback_error: # pragma: no cover
36+
logger.exception("Error in onerror callback", exc_info=callback_error)
37+
38+
39+
async def _forward_message(
40+
message: SessionMessage | Exception,
41+
write_stream: MemoryObjectSendStream[SessionMessage],
42+
onerror: Callable[[Exception], None | Awaitable[None]] | None,
43+
source: str,
44+
) -> None:
45+
"""Forward a single message, handling exceptions appropriately."""
46+
if isinstance(message, SessionMessage):
47+
await write_stream.send(message)
48+
elif isinstance(message, Exception):
49+
logger.debug(f"Exception received from {source}: {message}")
50+
await _handle_error(message, onerror)
51+
# Exceptions are not forwarded as messages (write streams only accept SessionMessage)
52+
53+
54+
async def _forward_loop(
55+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
56+
write_stream: MemoryObjectSendStream[SessionMessage],
57+
onerror: Callable[[Exception], None | Awaitable[None]] | None,
58+
source: str,
59+
) -> None:
60+
"""Forward messages from read_stream to write_stream."""
61+
try:
62+
async with read_stream:
63+
async for message in read_stream:
64+
try:
65+
await _forward_message(message, write_stream, onerror, source)
66+
except anyio.ClosedResourceError:
67+
logger.debug(f"{source} write stream closed")
68+
break
69+
except Exception as exc:
70+
logger.exception(f"Error forwarding message from {source}", exc_info=exc)
71+
await _handle_error(exc, onerror)
72+
except anyio.ClosedResourceError:
73+
logger.debug(f"{source} read stream closed")
74+
except Exception as exc:
75+
logger.exception(f"Error in forward loop from {source}", exc_info=exc)
76+
await _handle_error(exc, onerror)
77+
finally:
78+
# Close write stream when read stream closes
79+
try:
80+
await write_stream.aclose()
81+
except Exception: # pragma: no cover
82+
# Stream might already be closed
83+
pass
84+
85+
2686
@asynccontextmanager
2787
async def mcp_proxy(
2888
transport_to_client: MessageStream,
@@ -60,111 +120,9 @@ async def mcp_proxy(
60120
client_read, client_write = transport_to_client
61121
server_read, server_write = transport_to_server
62122

63-
async def forward_to_server():
64-
"""Forward messages from client to server."""
65-
try:
66-
async with client_read:
67-
async for message in client_read:
68-
try:
69-
# Forward SessionMessage objects directly
70-
if isinstance(message, SessionMessage):
71-
await server_write.send(message)
72-
# Handle Exception objects via error callback
73-
elif isinstance(message, Exception):
74-
logger.debug(f"Exception received from client: {message}")
75-
if onerror:
76-
try:
77-
result = onerror(message)
78-
if isinstance(result, Awaitable):
79-
await result
80-
except Exception as callback_error: # pragma: no cover
81-
logger.exception("Error in onerror callback", exc_info=callback_error)
82-
# Exceptions are not forwarded as messages (write streams only accept SessionMessage)
83-
except anyio.ClosedResourceError:
84-
logger.debug("Server write stream closed while forwarding from client")
85-
break
86-
except Exception as exc: # pragma: no cover
87-
logger.exception("Error forwarding message from client to server", exc_info=exc)
88-
if onerror:
89-
try:
90-
result = onerror(exc)
91-
if isinstance(result, Awaitable):
92-
await result
93-
except Exception as callback_error: # pragma: no cover
94-
logger.exception("Error in onerror callback", exc_info=callback_error)
95-
except anyio.ClosedResourceError:
96-
logger.debug("Client read stream closed")
97-
except Exception as exc: # pragma: no cover
98-
logger.exception("Error in forward_to_server task", exc_info=exc)
99-
if onerror:
100-
try:
101-
result = onerror(exc)
102-
if isinstance(result, Awaitable):
103-
await result
104-
except Exception as callback_error: # pragma: no cover
105-
logger.exception("Error in onerror callback", exc_info=callback_error)
106-
finally:
107-
# Close server write stream when client read closes
108-
try:
109-
await server_write.aclose()
110-
except Exception: # pragma: no cover
111-
# Stream might already be closed
112-
pass
113-
114-
async def forward_to_client():
115-
"""Forward messages from server to client."""
116-
try:
117-
async with server_read:
118-
async for message in server_read:
119-
try:
120-
# Forward SessionMessage objects directly
121-
if isinstance(message, SessionMessage):
122-
await client_write.send(message)
123-
# Handle Exception objects via error callback
124-
elif isinstance(message, Exception):
125-
logger.debug(f"Exception received from server: {message}")
126-
if onerror:
127-
try:
128-
result = onerror(message)
129-
if isinstance(result, Awaitable):
130-
await result
131-
except Exception as callback_error: # pragma: no cover
132-
logger.exception("Error in onerror callback", exc_info=callback_error)
133-
# Exceptions are not forwarded as messages (write streams only accept SessionMessage)
134-
except anyio.ClosedResourceError:
135-
logger.debug("Client write stream closed while forwarding from server")
136-
break
137-
except Exception as exc: # pragma: no cover
138-
logger.exception("Error forwarding message from server to client", exc_info=exc)
139-
if onerror:
140-
try:
141-
result = onerror(exc)
142-
if isinstance(result, Awaitable):
143-
await result
144-
except Exception as callback_error: # pragma: no cover
145-
logger.exception("Error in onerror callback", exc_info=callback_error)
146-
except anyio.ClosedResourceError:
147-
logger.debug("Server read stream closed")
148-
except Exception as exc: # pragma: no cover
149-
logger.exception("Error in forward_to_client task", exc_info=exc)
150-
if onerror:
151-
try:
152-
result = onerror(exc)
153-
if isinstance(result, Awaitable):
154-
await result
155-
except Exception as callback_error: # pragma: no cover
156-
logger.exception("Error in onerror callback", exc_info=callback_error)
157-
finally:
158-
# Close client write stream when server read closes
159-
try:
160-
await client_write.aclose()
161-
except Exception: # pragma: no cover
162-
# Stream might already be closed
163-
pass
164-
165123
async with anyio.create_task_group() as tg:
166-
tg.start_soon(forward_to_server)
167-
tg.start_soon(forward_to_client)
124+
tg.start_soon(_forward_loop, client_read, server_write, onerror, "client")
125+
tg.start_soon(_forward_loop, server_read, client_write, onerror, "server")
168126
try:
169127
yield
170128
finally:

tests/shared/test_proxy.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,72 @@ async def test_proxy_closes_other_stream_on_close(create_streams):
366366
# Clean up test streams
367367
await client_read_writer.aclose()
368368
await server_write_reader.aclose()
369+
370+
371+
@pytest.mark.anyio
372+
async def test_proxy_error_in_callback(create_streams):
373+
"""Test that errors in the error callback are handled gracefully."""
374+
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
375+
376+
try:
377+
def failing_error_handler(error: Exception) -> None:
378+
"""Error handler that raises an exception."""
379+
raise RuntimeError("Callback error")
380+
381+
# Send an exception through the stream
382+
test_exception = ValueError("Test error")
383+
384+
async with mcp_proxy(client_streams, server_streams, onerror=failing_error_handler):
385+
await client_read_writer.send(test_exception)
386+
387+
# Give it time to process
388+
await anyio.sleep(0.1)
389+
390+
# Proxy should continue working despite callback error
391+
request = JSONRPCRequest(jsonrpc="2.0", id="after_callback_error", method="test", params={})
392+
message = SessionMessage(JSONRPCMessage(request))
393+
await client_read_writer.send(message)
394+
395+
# Valid message should still be forwarded
396+
with anyio.fail_after(1):
397+
received = await server_write_reader.receive()
398+
assert received.message.root.id == "after_callback_error"
399+
finally:
400+
# Clean up test streams
401+
await client_read_writer.aclose()
402+
await server_write_reader.aclose()
403+
404+
405+
@pytest.mark.anyio
406+
async def test_proxy_async_error_in_callback(create_streams):
407+
"""Test that async errors in the error callback are handled gracefully."""
408+
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
409+
410+
try:
411+
async def failing_async_error_handler(error: Exception) -> None:
412+
"""Async error handler that raises an exception."""
413+
await anyio.sleep(0.01)
414+
raise RuntimeError("Async callback error")
415+
416+
# Send an exception through the stream
417+
test_exception = ValueError("Test error")
418+
419+
async with mcp_proxy(client_streams, server_streams, onerror=failing_async_error_handler):
420+
await client_read_writer.send(test_exception)
421+
422+
# Give it time to process
423+
await anyio.sleep(0.1)
424+
425+
# Proxy should continue working despite callback error
426+
request = JSONRPCRequest(jsonrpc="2.0", id="after_async_callback_error", method="test", params={})
427+
message = SessionMessage(JSONRPCMessage(request))
428+
await client_read_writer.send(message)
429+
430+
# Valid message should still be forwarded
431+
with anyio.fail_after(1):
432+
received = await server_write_reader.receive()
433+
assert received.message.root.id == "after_async_callback_error"
434+
finally:
435+
# Clean up test streams
436+
await client_read_writer.aclose()
437+
await server_write_reader.aclose()

0 commit comments

Comments
 (0)