Skip to content

Commit d0ec057

Browse files
committed
fix: Add integration markers to all multiprocessing tests
- Mark tests/shared/test_sse.py as integration (spawns subprocesses) - Mark tests/shared/test_streamable_http.py as integration (spawns subprocesses) - Mark tests/shared/test_ws.py as integration (spawns subprocesses) - Mark tests/server/test_sse_security.py as integration (spawns subprocesses) - Mark tests/server/test_streamable_http_security.py as integration (spawns subprocesses) - Mark tests/client/test_stdio.py as integration (spawns subprocesses) This ensures all subprocess-spawning tests run sequentially on Windows to prevent parallelization conflicts that cause test hangs.
1 parent 565ea48 commit d0ec057

File tree

6 files changed

+275
-81
lines changed

6 files changed

+275
-81
lines changed

tests/client/test_stdio.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
2020
from tests.shared.test_win32_utils import escape_path_for_python
2121

22+
# Mark all tests in this file as integration tests (spawn subprocesses)
23+
pytestmark = [pytest.mark.integration]
24+
2225
# Timeout for cleanup of processes that ignore SIGTERM
2326
# This timeout ensures the test fails quickly if the cleanup logic doesn't have
2427
# proper fallback mechanisms (SIGINT/SIGKILL) for processes that ignore SIGTERM
@@ -63,14 +66,20 @@ async def test_stdio_client():
6366
break
6467

6568
assert len(read_messages) == 2
66-
assert read_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))
67-
assert read_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}))
69+
assert read_messages[0] == JSONRPCMessage(
70+
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
71+
)
72+
assert read_messages[1] == JSONRPCMessage(
73+
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
74+
)
6875

6976

7077
@pytest.mark.anyio
7178
async def test_stdio_client_bad_path():
7279
"""Check that the connection doesn't hang if process errors."""
73-
server_params = StdioServerParameters(command="python", args=["-c", "non-existent-file.py"])
80+
server_params = StdioServerParameters(
81+
command="python", args=["-c", "non-existent-file.py"]
82+
)
7483
async with stdio_client(server_params) as (read_stream, write_stream):
7584
async with ClientSession(read_stream, write_stream) as session:
7685
# The session should raise an error when the connection closes
@@ -158,7 +167,9 @@ async def test_stdio_client_universal_cleanup():
158167

159168

160169
@pytest.mark.anyio
161-
@pytest.mark.skipif(sys.platform == "win32", reason="Windows signal handling is different")
170+
@pytest.mark.skipif(
171+
sys.platform == "win32", reason="Windows signal handling is different"
172+
)
162173
async def test_stdio_client_sigint_only_process():
163174
"""
164175
Test cleanup with a process that ignores SIGTERM but responds to SIGINT.
@@ -251,7 +262,9 @@ class TestChildProcessCleanup:
251262
"""
252263

253264
@pytest.mark.anyio
254-
@pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default")
265+
@pytest.mark.filterwarnings(
266+
"ignore::ResourceWarning" if sys.platform == "win32" else "default"
267+
)
255268
async def test_basic_child_process_cleanup(self):
256269
"""
257270
Test basic parent-child process cleanup.
@@ -300,7 +313,9 @@ async def test_basic_child_process_cleanup(self):
300313
print("\nStarting child process termination test...")
301314

302315
# Start the parent process
303-
proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script])
316+
proc = await _create_platform_compatible_process(
317+
sys.executable, ["-c", parent_script]
318+
)
304319

305320
# Wait for processes to start
306321
await anyio.sleep(0.5)
@@ -314,7 +329,9 @@ async def test_basic_child_process_cleanup(self):
314329
await anyio.sleep(0.3)
315330
size_after_wait = os.path.getsize(marker_file)
316331
assert size_after_wait > initial_size, "Child process should be writing"
317-
print(f"Child is writing (file grew from {initial_size} to {size_after_wait} bytes)")
332+
print(
333+
f"Child is writing (file grew from {initial_size} to {size_after_wait} bytes)"
334+
)
318335

319336
# Terminate using our function
320337
print("Terminating process and children...")
@@ -330,9 +347,9 @@ async def test_basic_child_process_cleanup(self):
330347
final_size = os.path.getsize(marker_file)
331348

332349
print(f"After cleanup: file size {size_after_cleanup} -> {final_size}")
333-
assert final_size == size_after_cleanup, (
334-
f"Child process still running! File grew by {final_size - size_after_cleanup} bytes"
335-
)
350+
assert (
351+
final_size == size_after_cleanup
352+
), f"Child process still running! File grew by {final_size - size_after_cleanup} bytes"
336353

337354
print("SUCCESS: Child process was properly terminated")
338355

@@ -345,7 +362,9 @@ async def test_basic_child_process_cleanup(self):
345362
pass
346363

347364
@pytest.mark.anyio
348-
@pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default")
365+
@pytest.mark.filterwarnings(
366+
"ignore::ResourceWarning" if sys.platform == "win32" else "default"
367+
)
349368
async def test_nested_process_tree(self):
350369
"""
351370
Test nested process tree cleanup (parent → child → grandchild).
@@ -405,13 +424,19 @@ async def test_nested_process_tree(self):
405424
)
406425

407426
# Start the parent process
408-
proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script])
427+
proc = await _create_platform_compatible_process(
428+
sys.executable, ["-c", parent_script]
429+
)
409430

410431
# Let all processes start
411432
await anyio.sleep(1.0)
412433

413434
# Verify all are writing
414-
for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]:
435+
for file_path, name in [
436+
(parent_file, "parent"),
437+
(child_file, "child"),
438+
(grandchild_file, "grandchild"),
439+
]:
415440
if os.path.exists(file_path):
416441
initial_size = os.path.getsize(file_path)
417442
await anyio.sleep(0.3)
@@ -425,7 +450,11 @@ async def test_nested_process_tree(self):
425450

426451
# Verify all stopped
427452
await anyio.sleep(0.5)
428-
for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]:
453+
for file_path, name in [
454+
(parent_file, "parent"),
455+
(child_file, "child"),
456+
(grandchild_file, "grandchild"),
457+
]:
429458
if os.path.exists(file_path):
430459
size1 = os.path.getsize(file_path)
431460
await anyio.sleep(0.3)
@@ -443,7 +472,9 @@ async def test_nested_process_tree(self):
443472
pass
444473

445474
@pytest.mark.anyio
446-
@pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default")
475+
@pytest.mark.filterwarnings(
476+
"ignore::ResourceWarning" if sys.platform == "win32" else "default"
477+
)
447478
async def test_early_parent_exit(self):
448479
"""
449480
Test cleanup when parent exits during termination sequence.
@@ -487,7 +518,9 @@ def handle_term(sig, frame):
487518
)
488519

489520
# Start the parent process
490-
proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script])
521+
proc = await _create_platform_compatible_process(
522+
sys.executable, ["-c", parent_script]
523+
)
491524

492525
# Let child start writing
493526
await anyio.sleep(0.5)

tests/server/test_sse_security.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
from mcp.server.transport_security import TransportSecuritySettings
1919
from mcp.types import Tool
2020

21+
# Mark all tests in this file as integration tests (spawn subprocesses)
22+
pytestmark = [pytest.mark.integration]
23+
24+
2125
logger = logging.getLogger(__name__)
2226
SERVER_NAME = "test_sse_security_server"
2327

@@ -42,16 +46,22 @@ async def on_list_tools(self) -> list[Tool]:
4246
return []
4347

4448

45-
def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None):
49+
def run_server_with_settings(
50+
port: int, security_settings: TransportSecuritySettings | None = None
51+
):
4652
"""Run the SSE server with specified security settings."""
4753
app = SecurityTestServer()
4854
sse_transport = SseServerTransport("/messages/", security_settings)
4955

5056
async def handle_sse(request: Request):
5157
try:
52-
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
58+
async with sse_transport.connect_sse(
59+
request.scope, request.receive, request._send
60+
) as streams:
5361
if streams:
54-
await app.run(streams[0], streams[1], app.create_initialization_options())
62+
await app.run(
63+
streams[0], streams[1], app.create_initialization_options()
64+
)
5565
except ValueError as e:
5666
# Validation error was already handled inside connect_sse
5767
logger.debug(f"SSE connection failed validation: {e}")
@@ -66,9 +76,13 @@ async def handle_sse(request: Request):
6676
uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error")
6777

6878

69-
def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None):
79+
def start_server_process(
80+
port: int, security_settings: TransportSecuritySettings | None = None
81+
):
7082
"""Start server in a separate process."""
71-
process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings))
83+
process = multiprocessing.Process(
84+
target=run_server_with_settings, args=(port, security_settings)
85+
)
7286
process.start()
7387
# Give server time to start
7488
time.sleep(1)
@@ -84,7 +98,9 @@ async def test_sse_security_default_settings(server_port: int):
8498
headers = {"Host": "evil.com", "Origin": "http://evil.com"}
8599

86100
async with httpx.AsyncClient(timeout=5.0) as client:
87-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
101+
async with client.stream(
102+
"GET", f"http://127.0.0.1:{server_port}/sse", headers=headers
103+
) as response:
88104
assert response.status_code == 200
89105
finally:
90106
process.terminate()
@@ -95,15 +111,19 @@ async def test_sse_security_default_settings(server_port: int):
95111
async def test_sse_security_invalid_host_header(server_port: int):
96112
"""Test SSE with invalid Host header."""
97113
# Enable security by providing settings with an empty allowed_hosts list
98-
security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"])
114+
security_settings = TransportSecuritySettings(
115+
enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]
116+
)
99117
process = start_server_process(server_port, security_settings)
100118

101119
try:
102120
# Test with invalid host header
103121
headers = {"Host": "evil.com"}
104122

105123
async with httpx.AsyncClient() as client:
106-
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
124+
response = await client.get(
125+
f"http://127.0.0.1:{server_port}/sse", headers=headers
126+
)
107127
assert response.status_code == 421
108128
assert response.text == "Invalid Host header"
109129

@@ -117,7 +137,9 @@ async def test_sse_security_invalid_origin_header(server_port: int):
117137
"""Test SSE with invalid Origin header."""
118138
# Configure security to allow the host but restrict origins
119139
security_settings = TransportSecuritySettings(
120-
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"]
140+
enable_dns_rebinding_protection=True,
141+
allowed_hosts=["127.0.0.1:*"],
142+
allowed_origins=["http://localhost:*"],
121143
)
122144
process = start_server_process(server_port, security_settings)
123145

@@ -126,7 +148,9 @@ async def test_sse_security_invalid_origin_header(server_port: int):
126148
headers = {"Origin": "http://evil.com"}
127149

128150
async with httpx.AsyncClient() as client:
129-
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
151+
response = await client.get(
152+
f"http://127.0.0.1:{server_port}/sse", headers=headers
153+
)
130154
assert response.status_code == 400
131155
assert response.text == "Invalid Origin header"
132156

@@ -140,7 +164,9 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
140164
"""Test POST endpoint with invalid Content-Type header."""
141165
# Configure security to allow the host
142166
security_settings = TransportSecuritySettings(
143-
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
167+
enable_dns_rebinding_protection=True,
168+
allowed_hosts=["127.0.0.1:*"],
169+
allowed_origins=["http://127.0.0.1:*"],
144170
)
145171
process = start_server_process(server_port, security_settings)
146172

@@ -158,7 +184,8 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
158184

159185
# Test POST with missing content type
160186
response = await client.post(
161-
f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test"
187+
f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}",
188+
content="test",
162189
)
163190
assert response.status_code == 400
164191
assert response.text == "Invalid Content-Type header"
@@ -180,7 +207,9 @@ async def test_sse_security_disabled(server_port: int):
180207

181208
async with httpx.AsyncClient(timeout=5.0) as client:
182209
# For SSE endpoints, we need to use stream to avoid timeout
183-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
210+
async with client.stream(
211+
"GET", f"http://127.0.0.1:{server_port}/sse", headers=headers
212+
) as response:
184213
# Should connect successfully even with invalid host
185214
assert response.status_code == 200
186215

@@ -205,15 +234,19 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
205234

206235
async with httpx.AsyncClient(timeout=5.0) as client:
207236
# For SSE endpoints, we need to use stream to avoid timeout
208-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
237+
async with client.stream(
238+
"GET", f"http://127.0.0.1:{server_port}/sse", headers=headers
239+
) as response:
209240
# Should connect successfully with custom host
210241
assert response.status_code == 200
211242

212243
# Test with non-allowed host
213244
headers = {"Host": "evil.com"}
214245

215246
async with httpx.AsyncClient() as client:
216-
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
247+
response = await client.get(
248+
f"http://127.0.0.1:{server_port}/sse", headers=headers
249+
)
217250
assert response.status_code == 421
218251
assert response.text == "Invalid Host header"
219252

@@ -239,15 +272,19 @@ async def test_sse_security_wildcard_ports(server_port: int):
239272

240273
async with httpx.AsyncClient(timeout=5.0) as client:
241274
# For SSE endpoints, we need to use stream to avoid timeout
242-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
275+
async with client.stream(
276+
"GET", f"http://127.0.0.1:{server_port}/sse", headers=headers
277+
) as response:
243278
# Should connect successfully with any port
244279
assert response.status_code == 200
245280

246281
headers = {"Origin": f"http://localhost:{test_port}"}
247282

248283
async with httpx.AsyncClient(timeout=5.0) as client:
249284
# For SSE endpoints, we need to use stream to avoid timeout
250-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
285+
async with client.stream(
286+
"GET", f"http://127.0.0.1:{server_port}/sse", headers=headers
287+
) as response:
251288
# Should connect successfully with any port
252289
assert response.status_code == 200
253290

@@ -261,7 +298,9 @@ async def test_sse_security_post_valid_content_type(server_port: int):
261298
"""Test POST endpoint with valid Content-Type headers."""
262299
# Configure security to allow the host
263300
security_settings = TransportSecuritySettings(
264-
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
301+
enable_dns_rebinding_protection=True,
302+
allowed_hosts=["127.0.0.1:*"],
303+
allowed_origins=["http://127.0.0.1:*"],
265304
)
266305
process = start_server_process(server_port, security_settings)
267306

tests/server/test_streamable_http_security.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from mcp.server import Server
1818
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
1919
from mcp.server.transport_security import TransportSecuritySettings
20+
21+
# Mark all tests in this file as integration tests (spawn subprocesses)
22+
pytestmark = [pytest.mark.integration]
2023
from mcp.types import Tool
2124

2225
logger = logging.getLogger(__name__)

tests/shared/test_sse.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
Tool,
3333
)
3434

35+
# Mark all tests in this file as integration tests (spawn subprocesses)
36+
pytestmark = [pytest.mark.integration]
37+
3538
SERVER_NAME = "test_server_for_SSE"
3639

3740

0 commit comments

Comments
 (0)