Skip to content

Commit df56106

Browse files
committed
Fix proxy coverage regressions
1 parent 525440c commit df56106

File tree

1 file changed

+52
-1
lines changed

1 file changed

+52
-1
lines changed

tests/test_proxy.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import anyio
77
import pytest
88

9-
from mcp.proxy import _forward_message, mcp_proxy
9+
from mcp.proxy import _forward_message, _forward_messages, mcp_proxy
1010
from mcp.shared._context_streams import create_context_streams
1111
from mcp.shared.message import SessionMessage
1212
from mcp.types import JSONRPCRequest
@@ -117,6 +117,12 @@ async def __aexit__(
117117
return None
118118

119119

120+
class NestedException(Exception):
121+
def __init__(self, *exceptions: BaseException) -> None:
122+
super().__init__("nested")
123+
self.exceptions = exceptions
124+
125+
120126
def assert_contains_exception(exc: BaseException, expected_type: type[Exception], expected_message: str) -> None:
121127
nested_exceptions = getattr(exc, "exceptions", None)
122128
if nested_exceptions is not None:
@@ -132,6 +138,42 @@ def assert_contains_exception(exc: BaseException, expected_type: type[Exception]
132138
assert expected_message in str(exc)
133139

134140

141+
@pytest.mark.anyio
142+
async def test_static_read_stream_receive_raises_end_of_stream_when_exhausted() -> None:
143+
stream = StaticReadStream()
144+
145+
with pytest.raises(anyio.EndOfStream):
146+
await stream.receive()
147+
148+
149+
@pytest.mark.anyio
150+
async def test_tracking_write_stream_send_raises_configured_error() -> None:
151+
stream = TrackingWriteStream(RuntimeError("write boom"))
152+
153+
with pytest.raises(RuntimeError, match="write boom"):
154+
await stream.send(make_message("client", "client/method"))
155+
156+
157+
@pytest.mark.anyio
158+
async def test_read_stream_with_context_support_methods() -> None:
159+
stream = ReadStreamWithContext(contextvars.copy_context())
160+
161+
assert stream.__aiter__() is stream
162+
assert await stream.__aenter__() is stream
163+
assert await stream.aclose() is None
164+
assert await stream.__aexit__(None, None, None) is None
165+
166+
with pytest.raises(StopAsyncIteration):
167+
await stream.__anext__()
168+
169+
170+
def test_assert_contains_exception_reports_missing_nested_exception() -> None:
171+
exc = NestedException(ValueError("boom"))
172+
173+
with pytest.raises(AssertionError, match="Did not find RuntimeError containing 'missing'"):
174+
assert_contains_exception(exc, RuntimeError, "missing")
175+
176+
135177
@pytest.mark.anyio
136178
async def test_proxy_forwards_messages_bidirectionally() -> None:
137179
client_read_send, client_read = anyio.create_memory_object_stream[SessionMessage | Exception](1)
@@ -324,6 +366,15 @@ async def test_proxy_stops_forwarding_when_target_stream_is_closed() -> None:
324366
assert client_write.closed.is_set()
325367

326368

369+
@pytest.mark.anyio
370+
async def test_forward_messages_stops_on_closed_target_stream() -> None:
371+
await _forward_messages(
372+
StaticReadStream(make_message("client", "client/method")),
373+
TrackingWriteStream(anyio.ClosedResourceError()),
374+
on_error=None,
375+
)
376+
377+
327378
@pytest.mark.anyio
328379
async def test_proxy_closes_target_stream_when_source_stream_is_closed() -> None:
329380
server_write = TrackingWriteStream()

0 commit comments

Comments
 (0)