Skip to content

Commit a2deae1

Browse files
committed
Update test_proxy.py
1 parent e5bdd4c commit a2deae1

File tree

1 file changed

+32
-33
lines changed

1 file changed

+32
-33
lines changed

tests/shared/test_proxy.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
StreamPair = tuple[ReadStream, WriteStream]
1818
WriterReaderPair = tuple[MemoryObjectSendStream[SessionMessage | Exception], MemoryObjectReceiveStream[SessionMessage]]
1919
StreamsFixtureReturn = tuple[StreamPair, StreamPair, WriterReaderPair, WriterReaderPair]
20+
CreateStreamsFixture = Callable[[], StreamsFixtureReturn]
2021

2122

2223
@pytest.fixture
@@ -63,7 +64,7 @@ def _create() -> StreamsFixtureReturn:
6364

6465

6566
@pytest.mark.anyio
66-
async def test_proxy_forwards_client_to_server(create_streams):
67+
async def test_proxy_forwards_client_to_server(create_streams: CreateStreamsFixture) -> None:
6768
"""Test that messages from client are forwarded to server."""
6869
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
6970

@@ -79,16 +80,16 @@ async def test_proxy_forwards_client_to_server(create_streams):
7980
# Verify it arrives at server
8081
with anyio.fail_after(1):
8182
received = await server_write_reader.receive()
82-
assert received.message.root.id == "1"
83-
assert received.message.root.method == "test_method"
83+
assert received.message.root.id == "1" # type: ignore[attr-defined]
84+
assert received.message.root.method == "test_method" # type: ignore[attr-defined]
8485
finally:
8586
# Clean up test streams
8687
await client_read_writer.aclose()
8788
await server_write_reader.aclose()
8889

8990

9091
@pytest.mark.anyio
91-
async def test_proxy_forwards_server_to_client(create_streams):
92+
async def test_proxy_forwards_server_to_client(create_streams: CreateStreamsFixture) -> None:
9293
"""Test that messages from server are forwarded to client."""
9394
client_streams, server_streams, (_, client_write_reader), (server_read_writer, _) = create_streams()
9495

@@ -104,16 +105,16 @@ async def test_proxy_forwards_server_to_client(create_streams):
104105
# Verify it arrives at client
105106
with anyio.fail_after(1):
106107
received = await client_write_reader.receive()
107-
assert received.message.root.id == "2"
108-
assert received.message.root.method == "server_method"
108+
assert received.message.root.id == "2" # type: ignore[attr-defined]
109+
assert received.message.root.method == "server_method" # type: ignore[attr-defined]
109110
finally:
110111
# Clean up test streams
111112
await server_read_writer.aclose()
112113
await client_write_reader.aclose()
113114

114115

115116
@pytest.mark.anyio
116-
async def test_proxy_bidirectional_forwarding(create_streams):
117+
async def test_proxy_bidirectional_forwarding(create_streams: CreateStreamsFixture) -> None:
117118
"""Test that proxy forwards messages in both directions simultaneously."""
118119
(
119120
client_streams,
@@ -146,11 +147,11 @@ async def test_proxy_bidirectional_forwarding(create_streams):
146147
with anyio.fail_after(1):
147148
# Client message should arrive at server
148149
received_at_server = await server_write_reader.receive()
149-
assert received_at_server.message.root.id == "client_1"
150+
assert received_at_server.message.root.id == "client_1" # type: ignore[attr-defined]
150151

151152
# Server message should arrive at client
152153
received_at_client = await client_write_reader.receive()
153-
assert received_at_client.message.root.id == "server_1"
154+
assert received_at_client.message.root.id == "server_1" # type: ignore[attr-defined]
154155
finally:
155156
# Clean up ALL 8 streams
156157
await client_read_writer.aclose()
@@ -164,12 +165,12 @@ async def test_proxy_bidirectional_forwarding(create_streams):
164165

165166

166167
@pytest.mark.anyio
167-
async def test_proxy_error_handling(create_streams):
168+
async def test_proxy_error_handling(create_streams: CreateStreamsFixture) -> None:
168169
"""Test that errors are caught and onerror callback is invoked."""
169170
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
170171

171172
try:
172-
errors = []
173+
errors: list[Exception] = []
173174

174175
def error_handler(error: Exception) -> None:
175176
"""Collect errors."""
@@ -195,12 +196,12 @@ def error_handler(error: Exception) -> None:
195196

196197

197198
@pytest.mark.anyio
198-
async def test_proxy_async_error_handler(create_streams):
199+
async def test_proxy_async_error_handler(create_streams: CreateStreamsFixture) -> None:
199200
"""Test that async error handlers work."""
200201
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
201202

202203
try:
203-
errors = []
204+
errors: list[Exception] = []
204205

205206
async def async_error_handler(error: Exception) -> None:
206207
"""Collect errors asynchronously."""
@@ -226,12 +227,12 @@ async def async_error_handler(error: Exception) -> None:
226227

227228

228229
@pytest.mark.anyio
229-
async def test_proxy_continues_after_error(create_streams):
230+
async def test_proxy_continues_after_error(create_streams: CreateStreamsFixture) -> None:
230231
"""Test that proxy continues forwarding after an error."""
231232
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
232233

233234
try:
234-
errors = []
235+
errors: list[Exception] = []
235236

236237
def error_handler(error: Exception) -> None:
237238
errors.append(error)
@@ -248,7 +249,7 @@ def error_handler(error: Exception) -> None:
248249
# Valid message should still be forwarded
249250
with anyio.fail_after(1):
250251
received = await server_write_reader.receive()
251-
assert received.message.root.id == "after_error"
252+
assert received.message.root.id == "after_error" # type: ignore[attr-defined]
252253

253254
# Error should have been captured
254255
assert len(errors) == 1
@@ -259,7 +260,7 @@ def error_handler(error: Exception) -> None:
259260

260261

261262
@pytest.mark.anyio
262-
async def test_proxy_cleans_up_streams(create_streams):
263+
async def test_proxy_cleans_up_streams(create_streams: CreateStreamsFixture) -> None:
263264
"""Test that proxy exits cleanly and doesn't interfere with stream lifecycle."""
264265
(
265266
client_streams,
@@ -287,7 +288,7 @@ async def test_proxy_cleans_up_streams(create_streams):
287288

288289

289290
@pytest.mark.anyio
290-
async def test_proxy_multiple_messages(create_streams):
291+
async def test_proxy_multiple_messages(create_streams: CreateStreamsFixture) -> None:
291292
"""Test that proxy can forward multiple messages."""
292293
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
293294

@@ -303,21 +304,21 @@ async def test_proxy_multiple_messages(create_streams):
303304
with anyio.fail_after(1):
304305
for i in range(5):
305306
received = await server_write_reader.receive()
306-
assert received.message.root.id == str(i)
307-
assert received.message.root.method == f"method_{i}"
307+
assert received.message.root.id == str(i) # type: ignore[attr-defined]
308+
assert received.message.root.method == f"method_{i}" # type: ignore[attr-defined]
308309
finally:
309310
# Clean up test streams
310311
await client_read_writer.aclose()
311312
await server_write_reader.aclose()
312313

313314

314315
@pytest.mark.anyio
315-
async def test_proxy_handles_closed_resource_error(create_streams):
316+
async def test_proxy_handles_closed_resource_error(create_streams: CreateStreamsFixture) -> None:
316317
"""Test that proxy handles ClosedResourceError gracefully."""
317318
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
318319

319320
try:
320-
errors = []
321+
errors: list[Exception] = []
321322

322323
def error_handler(error: Exception) -> None:
323324
errors.append(error)
@@ -340,13 +341,13 @@ def error_handler(error: Exception) -> None:
340341

341342

342343
@pytest.mark.anyio
343-
async def test_proxy_closes_other_stream_on_close(create_streams):
344+
async def test_proxy_closes_other_stream_on_close(create_streams: CreateStreamsFixture) -> None:
344345
"""Test that when one stream closes, the other is also closed."""
345346
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
346347

347348
try:
348-
client_read, client_write = client_streams
349-
server_read, server_write = server_streams
349+
client_read, _client_write = client_streams
350+
_server_read, server_write = server_streams
350351

351352
async with mcp_proxy(client_streams, server_streams):
352353
# Close the client read stream
@@ -369,7 +370,7 @@ async def test_proxy_closes_other_stream_on_close(create_streams):
369370

370371

371372
@pytest.mark.anyio
372-
async def test_proxy_error_in_callback(create_streams):
373+
async def test_proxy_error_in_callback(create_streams: CreateStreamsFixture) -> None:
373374
"""Test that errors in the error callback are handled gracefully."""
374375
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
375376

@@ -396,15 +397,15 @@ def failing_error_handler(error: Exception) -> None:
396397
# Valid message should still be forwarded
397398
with anyio.fail_after(1):
398399
received = await server_write_reader.receive()
399-
assert received.message.root.id == "after_callback_error"
400+
assert received.message.root.id == "after_callback_error" # type: ignore[attr-defined]
400401
finally:
401402
# Clean up test streams
402403
await client_read_writer.aclose()
403404
await server_write_reader.aclose()
404405

405406

406407
@pytest.mark.anyio
407-
async def test_proxy_async_error_in_callback(create_streams):
408+
async def test_proxy_async_error_in_callback(create_streams: CreateStreamsFixture) -> None:
408409
"""Test that async errors in the error callback are handled gracefully."""
409410
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
410411

@@ -432,15 +433,15 @@ async def failing_async_error_handler(error: Exception) -> None:
432433
# Valid message should still be forwarded
433434
with anyio.fail_after(1):
434435
received = await server_write_reader.receive()
435-
assert received.message.root.id == "after_async_callback_error"
436+
assert received.message.root.id == "after_async_callback_error" # type: ignore[attr-defined]
436437
finally:
437438
# Clean up test streams
438439
await client_read_writer.aclose()
439440
await server_write_reader.aclose()
440441

441442

442443
@pytest.mark.anyio
443-
async def test_proxy_without_error_handler(create_streams):
444+
async def test_proxy_without_error_handler(create_streams: CreateStreamsFixture) -> None:
444445
"""Test that proxy works without an error handler (covers onerror=None branch)."""
445446
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
446447

@@ -462,10 +463,8 @@ async def test_proxy_without_error_handler(create_streams):
462463
# Valid message should still be forwarded
463464
with anyio.fail_after(1):
464465
received = await server_write_reader.receive()
465-
assert received.message.root.id == "after_exception_no_handler"
466+
assert received.message.root.id == "after_exception_no_handler" # type: ignore[attr-defined]
466467
finally:
467468
# Clean up test streams
468469
await client_read_writer.aclose()
469470
await server_write_reader.aclose()
470-
471-

0 commit comments

Comments
 (0)