66import anyio
77import pytest
88
9- from mcp .proxy import _forward_message , mcp_proxy
9+ from mcp .proxy import _forward_message , _forward_messages , mcp_proxy
1010from mcp .shared ._context_streams import create_context_streams
1111from mcp .shared .message import SessionMessage
1212from mcp .types import JSONRPCRequest
@@ -117,6 +117,12 @@ async def __aexit__(
117117 return None
118118
119119
120+ class NestedException (Exception ):
121+ def __init__ (self , * exceptions : BaseException ) -> None :
122+ super ().__init__ ("nested" )
123+ self .exceptions = exceptions
124+
125+
120126def assert_contains_exception (exc : BaseException , expected_type : type [Exception ], expected_message : str ) -> None :
121127 nested_exceptions = getattr (exc , "exceptions" , None )
122128 if nested_exceptions is not None :
@@ -132,6 +138,42 @@ def assert_contains_exception(exc: BaseException, expected_type: type[Exception]
132138 assert expected_message in str (exc )
133139
134140
141+ @pytest .mark .anyio
142+ async def test_static_read_stream_receive_raises_end_of_stream_when_exhausted () -> None :
143+ stream = StaticReadStream ()
144+
145+ with pytest .raises (anyio .EndOfStream ):
146+ await stream .receive ()
147+
148+
149+ @pytest .mark .anyio
150+ async def test_tracking_write_stream_send_raises_configured_error () -> None :
151+ stream = TrackingWriteStream (RuntimeError ("write boom" ))
152+
153+ with pytest .raises (RuntimeError , match = "write boom" ):
154+ await stream .send (make_message ("client" , "client/method" ))
155+
156+
157+ @pytest .mark .anyio
158+ async def test_read_stream_with_context_support_methods () -> None :
159+ stream = ReadStreamWithContext (contextvars .copy_context ())
160+
161+ assert stream .__aiter__ () is stream
162+ assert await stream .__aenter__ () is stream
163+ assert await stream .aclose () is None
164+ assert await stream .__aexit__ (None , None , None ) is None
165+
166+ with pytest .raises (StopAsyncIteration ):
167+ await stream .__anext__ ()
168+
169+
170+ def test_assert_contains_exception_reports_missing_nested_exception () -> None :
171+ exc = NestedException (ValueError ("boom" ))
172+
173+ with pytest .raises (AssertionError , match = "Did not find RuntimeError containing 'missing'" ):
174+ assert_contains_exception (exc , RuntimeError , "missing" )
175+
176+
135177@pytest .mark .anyio
136178async def test_proxy_forwards_messages_bidirectionally () -> None :
137179 client_read_send , client_read = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
@@ -324,6 +366,15 @@ async def test_proxy_stops_forwarding_when_target_stream_is_closed() -> None:
324366 assert client_write .closed .is_set ()
325367
326368
369+ @pytest .mark .anyio
370+ async def test_forward_messages_stops_on_closed_target_stream () -> None :
371+ await _forward_messages (
372+ StaticReadStream (make_message ("client" , "client/method" )),
373+ TrackingWriteStream (anyio .ClosedResourceError ()),
374+ on_error = None ,
375+ )
376+
377+
327378@pytest .mark .anyio
328379async def test_proxy_closes_target_stream_when_source_stream_is_closed () -> None :
329380 server_write = TrackingWriteStream ()
0 commit comments