Skip to content

Commit 2ab65c0

Browse files
committed
fix: Improve process cleanup in integration tests
- Increase join timeout from 2s to 5s - Add fallback terminate() call for stubborn processes - Add exception handling for cleanup edge cases
1 parent 5641827 commit 2ab65c0

File tree

1 file changed

+41
-15
lines changed

1 file changed

+41
-15
lines changed

tests/server/fastmcp/test_integration.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def run_server_with_transport(module_name: str, port: int, transport: str) -> No
8080
import sys
8181

8282
# Add examples/snippets to Python path for multiprocessing context
83-
snippets_path = os.path.join(os.path.dirname(__file__), "..", "..", "..", "examples", "snippets")
83+
snippets_path = os.path.join(
84+
os.path.dirname(__file__), "..", "..", "..", "examples", "snippets"
85+
)
8486
sys.path.insert(0, os.path.abspath(snippets_path))
8587

8688
# Import the servers module in the multiprocessing context
@@ -129,7 +131,9 @@ def run_server_with_transport(module_name: str, port: int, transport: str) -> No
129131
else:
130132
raise ValueError(f"Invalid transport for test server: {transport}")
131133

132-
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error"))
134+
server = uvicorn.Server(
135+
config=uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error")
136+
)
133137
print(f"Starting {transport} server on port {port}")
134138
server.run()
135139

@@ -169,14 +173,22 @@ def server_transport(request, server_port: int) -> Generator[str, None, None]:
169173
time.sleep(delay)
170174
attempt += 1
171175
else:
172-
raise RuntimeError(f"Server failed to start after {max_attempts} attempts (port {server_port})")
176+
raise RuntimeError(
177+
f"Server failed to start after {max_attempts} attempts (port {server_port})"
178+
)
173179

174180
yield transport
175181

182+
# Aggressive cleanup - kill and force terminate
176183
proc.kill()
177-
proc.join(timeout=2)
184+
proc.join(timeout=5)
178185
if proc.is_alive():
179-
print("Server process failed to terminate")
186+
print("Server process failed to terminate, force killing")
187+
try:
188+
proc.terminate()
189+
proc.join(timeout=2)
190+
except Exception:
191+
pass
180192

181193

182194
# Helper function to create client based on transport
@@ -340,10 +352,14 @@ async def test_basic_prompts(server_transport: str, server_url: str) -> None:
340352

341353
# Test review_code prompt
342354
prompts = await session.list_prompts()
343-
review_prompt = next((p for p in prompts.prompts if p.name == "review_code"), None)
355+
review_prompt = next(
356+
(p for p in prompts.prompts if p.name == "review_code"), None
357+
)
344358
assert review_prompt is not None
345359

346-
prompt_result = await session.get_prompt("review_code", {"code": "def hello():\n print('Hello')"})
360+
prompt_result = await session.get_prompt(
361+
"review_code", {"code": "def hello():\n print('Hello')"}
362+
)
347363
assert isinstance(prompt_result, GetPromptResult)
348364
assert len(prompt_result.messages) == 1
349365
assert isinstance(prompt_result.messages[0].content, TextContent)
@@ -399,16 +415,18 @@ async def test_tool_progress(server_transport: str, server_url: str) -> None:
399415
assert result.capabilities.tools is not None
400416

401417
# Test long_running_task tool that reports progress
402-
tool_result = await session.call_tool("long_running_task", {"task_name": "test", "steps": 3})
418+
tool_result = await session.call_tool(
419+
"long_running_task", {"task_name": "test", "steps": 3}
420+
)
403421
assert len(tool_result.content) == 1
404422
assert isinstance(tool_result.content[0], TextContent)
405423
assert "Task 'test' completed" in tool_result.content[0].text
406424

407425
# Verify that progress notifications or log messages were sent
408426
# Progress can come through either progress notifications or log messages
409-
total_notifications = len(notification_collector.progress_notifications) + len(
410-
notification_collector.log_messages
411-
)
427+
total_notifications = len(
428+
notification_collector.progress_notifications
429+
) + len(notification_collector.log_messages)
412430
assert total_notifications > 0
413431

414432

@@ -429,7 +447,9 @@ async def test_sampling(server_transport: str, server_url: str) -> None:
429447

430448
async with client_cm as client_streams:
431449
read_stream, write_stream = unpack_streams(client_streams)
432-
async with ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session:
450+
async with ClientSession(
451+
read_stream, write_stream, sampling_callback=sampling_callback
452+
) as session:
433453
# Test initialization
434454
result = await session.initialize()
435455
assert isinstance(result, InitializeResult)
@@ -460,7 +480,9 @@ async def test_elicitation(server_transport: str, server_url: str) -> None:
460480

461481
async with client_cm as client_streams:
462482
read_stream, write_stream = unpack_streams(client_streams)
463-
async with ClientSession(read_stream, write_stream, elicitation_callback=elicitation_callback) as session:
483+
async with ClientSession(
484+
read_stream, write_stream, elicitation_callback=elicitation_callback
485+
) as session:
464486
# Test initialization
465487
result = await session.initialize()
466488
assert isinstance(result, InitializeResult)
@@ -506,7 +528,9 @@ async def test_completion(server_transport: str, server_url: str) -> None:
506528
assert len(prompts.prompts) > 0
507529

508530
# Test getting a prompt
509-
prompt_result = await session.get_prompt("review_code", {"language": "python", "code": "def test(): pass"})
531+
prompt_result = await session.get_prompt(
532+
"review_code", {"language": "python", "code": "def test(): pass"}
533+
)
510534
assert len(prompt_result.messages) > 0
511535

512536

@@ -618,7 +642,9 @@ async def test_structured_output(server_transport: str, server_url: str) -> None
618642
assert result.serverInfo.name == "Structured Output Example"
619643

620644
# Test get_weather tool
621-
weather_result = await session.call_tool("get_weather", {"city": "New York"})
645+
weather_result = await session.call_tool(
646+
"get_weather", {"city": "New York"}
647+
)
622648
assert len(weather_result.content) == 1
623649
assert isinstance(weather_result.content[0], TextContent)
624650

0 commit comments

Comments
 (0)