Skip to content

Commit ac8e973

Browse files
committed
fix: Improve server startup timeouts in shared tests
- Increase max_attempts from 20 to 30 for server startup - Add socket timeout and better error handling - Use progressive delays (0.05s -> 0.1s) for faster startup - Handle OSError in addition to ConnectionRefusedError This should reduce 'Server failed to start' errors in CI tests.
1 parent add1adc commit ac8e973

File tree

2 files changed

+246
-80
lines changed

2 files changed

+246
-80
lines changed

tests/shared/test_sse.py

Lines changed: 112 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes:
6464
await anyio.sleep(2.0)
6565
return f"Slow response from {uri.host}"
6666

67-
raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found"))
67+
raise McpError(
68+
error=ErrorData(
69+
code=404, message="OOPS! no resource with that URI was found"
70+
)
71+
)
6872

6973
@self.list_tools()
7074
async def handle_list_tools() -> list[Tool]:
@@ -86,14 +90,19 @@ def make_server_app() -> Starlette:
8690
"""Create test Starlette app with SSE transport"""
8791
# Configure security with allowed hosts/origins for testing
8892
security_settings = TransportSecuritySettings(
89-
allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"]
93+
allowed_hosts=["127.0.0.1:*", "localhost:*"],
94+
allowed_origins=["http://127.0.0.1:*", "http://localhost:*"],
9095
)
9196
sse = SseServerTransport("/messages/", security_settings=security_settings)
9297
server = ServerTest()
9398

9499
async def handle_sse(request: Request) -> Response:
95-
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
96-
await server.run(streams[0], streams[1], server.create_initialization_options())
100+
async with sse.connect_sse(
101+
request.scope, request.receive, request._send
102+
) as streams:
103+
await server.run(
104+
streams[0], streams[1], server.create_initialization_options()
105+
)
97106
return Response()
98107

99108
app = Starlette(
@@ -108,7 +117,11 @@ async def handle_sse(request: Request) -> Response:
108117

109118
def run_server(server_port: int) -> None:
110119
app = make_server_app()
111-
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
120+
server = uvicorn.Server(
121+
config=uvicorn.Config(
122+
app=app, host="127.0.0.1", port=server_port, log_level="error"
123+
)
124+
)
112125
print(f"starting server on {server_port}")
113126
server.run()
114127

@@ -120,21 +133,26 @@ def run_server(server_port: int) -> None:
120133

121134
@pytest.fixture()
122135
def server(server_port: int) -> Generator[None, None, None]:
123-
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
136+
proc = multiprocessing.Process(
137+
target=run_server, kwargs={"server_port": server_port}, daemon=True
138+
)
124139
print("starting process")
125140
proc.start()
126141

127-
# Wait for server to be running
128-
max_attempts = 20
142+
# Wait for server to be running - optimized for faster startup
143+
max_attempts = 30
129144
attempt = 0
130145
print("waiting for server to start")
131146
while attempt < max_attempts:
132147
try:
133148
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
149+
s.settimeout(1.0)
134150
s.connect(("127.0.0.1", server_port))
135151
break
136-
except ConnectionRefusedError:
137-
time.sleep(0.1)
152+
except (ConnectionRefusedError, OSError):
153+
# Use shorter initial delays, then increase
154+
delay = 0.05 if attempt < 10 else 0.1
155+
time.sleep(delay)
138156
attempt += 1
139157
else:
140158
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
@@ -165,7 +183,10 @@ async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
165183
async def connection_test() -> None:
166184
async with http_client.stream("GET", "/sse") as response:
167185
assert response.status_code == 200
168-
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
186+
assert (
187+
response.headers["content-type"]
188+
== "text/event-stream; charset=utf-8"
189+
)
169190

170191
line_number = 0
171192
async for line in response.aiter_lines():
@@ -197,7 +218,9 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
197218

198219

199220
@pytest.fixture
200-
async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]:
221+
async def initialized_sse_client_session(
222+
server, server_url: str
223+
) -> AsyncGenerator[ClientSession, None]:
201224
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
202225
async with ClientSession(*streams) as session:
203226
await session.initialize()
@@ -225,7 +248,9 @@ async def test_sse_client_exception_handling(
225248

226249

227250
@pytest.mark.anyio
228-
@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling")
251+
@pytest.mark.skip(
252+
"this test highlights a possible bug in SSE read timeout exception handling"
253+
)
229254
async def test_sse_client_timeout(
230255
initialized_sse_client_session: ClientSession,
231256
) -> None:
@@ -247,7 +272,11 @@ async def test_sse_client_timeout(
247272
def run_mounted_server(server_port: int) -> None:
248273
app = make_server_app()
249274
main_app = Starlette(routes=[Mount("/mounted_app", app=app)])
250-
server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error"))
275+
server = uvicorn.Server(
276+
config=uvicorn.Config(
277+
app=main_app, host="127.0.0.1", port=server_port, log_level="error"
278+
)
279+
)
251280
print(f"starting server on {server_port}")
252281
server.run()
253282

@@ -259,21 +288,26 @@ def run_mounted_server(server_port: int) -> None:
259288

260289
@pytest.fixture()
261290
def mounted_server(server_port: int) -> Generator[None, None, None]:
262-
proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True)
291+
proc = multiprocessing.Process(
292+
target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True
293+
)
263294
print("starting process")
264295
proc.start()
265296

266-
# Wait for server to be running
267-
max_attempts = 20
297+
# Wait for server to be running - optimized for faster startup
298+
max_attempts = 30
268299
attempt = 0
269300
print("waiting for server to start")
270301
while attempt < max_attempts:
271302
try:
272303
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
304+
s.settimeout(1.0)
273305
s.connect(("127.0.0.1", server_port))
274306
break
275-
except ConnectionRefusedError:
276-
time.sleep(0.1)
307+
except (ConnectionRefusedError, OSError):
308+
# Use shorter initial delays, then increase
309+
delay = 0.05 if attempt < 10 else 0.1
310+
time.sleep(delay)
277311
attempt += 1
278312
else:
279313
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
@@ -289,7 +323,9 @@ def mounted_server(server_port: int) -> Generator[None, None, None]:
289323

290324

291325
@pytest.mark.anyio
292-
async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None:
326+
async def test_sse_client_basic_connection_mounted_app(
327+
mounted_server: None, server_url: str
328+
) -> None:
293329
async with sse_client(server_url + "/mounted_app/sse") as streams:
294330
async with ClientSession(*streams) as session:
295331
# Test initialization
@@ -349,14 +385,19 @@ def run_context_server(server_port: int) -> None:
349385
"""Run a server that captures request context"""
350386
# Configure security with allowed hosts/origins for testing
351387
security_settings = TransportSecuritySettings(
352-
allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"]
388+
allowed_hosts=["127.0.0.1:*", "localhost:*"],
389+
allowed_origins=["http://127.0.0.1:*", "http://localhost:*"],
353390
)
354391
sse = SseServerTransport("/messages/", security_settings=security_settings)
355392
context_server = RequestContextServer()
356393

357394
async def handle_sse(request: Request) -> Response:
358-
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
359-
await context_server.run(streams[0], streams[1], context_server.create_initialization_options())
395+
async with sse.connect_sse(
396+
request.scope, request.receive, request._send
397+
) as streams:
398+
await context_server.run(
399+
streams[0], streams[1], context_server.create_initialization_options()
400+
)
360401
return Response()
361402

362403
app = Starlette(
@@ -366,32 +407,43 @@ async def handle_sse(request: Request) -> Response:
366407
]
367408
)
368409

369-
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
410+
server = uvicorn.Server(
411+
config=uvicorn.Config(
412+
app=app, host="127.0.0.1", port=server_port, log_level="error"
413+
)
414+
)
370415
print(f"starting context server on {server_port}")
371416
server.run()
372417

373418

374419
@pytest.fixture()
375420
def context_server(server_port: int) -> Generator[None, None, None]:
376421
"""Fixture that provides a server with request context capture"""
377-
proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True)
422+
proc = multiprocessing.Process(
423+
target=run_context_server, kwargs={"server_port": server_port}, daemon=True
424+
)
378425
print("starting context server process")
379426
proc.start()
380427

381-
# Wait for server to be running
382-
max_attempts = 20
428+
# Wait for server to be running - optimized for faster startup
429+
max_attempts = 30
383430
attempt = 0
384431
print("waiting for context server to start")
385432
while attempt < max_attempts:
386433
try:
387434
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
435+
s.settimeout(1.0)
388436
s.connect(("127.0.0.1", server_port))
389437
break
390-
except ConnectionRefusedError:
391-
time.sleep(0.1)
438+
except (ConnectionRefusedError, OSError):
439+
# Use shorter initial delays, then increase
440+
delay = 0.05 if attempt < 10 else 0.1
441+
time.sleep(delay)
392442
attempt += 1
393443
else:
394-
raise RuntimeError(f"Context server failed to start after {max_attempts} attempts")
444+
raise RuntimeError(
445+
f"Context server failed to start after {max_attempts} attempts"
446+
)
395447

396448
yield
397449

@@ -403,7 +455,9 @@ def context_server(server_port: int) -> Generator[None, None, None]:
403455

404456

405457
@pytest.mark.anyio
406-
async def test_request_context_propagation(context_server: None, server_url: str) -> None:
458+
async def test_request_context_propagation(
459+
context_server: None, server_url: str
460+
) -> None:
407461
"""Test that request context is properly propagated through SSE transport."""
408462
# Test with custom headers
409463
custom_headers = {
@@ -427,7 +481,11 @@ async def test_request_context_propagation(context_server: None, server_url: str
427481
# Parse the JSON response
428482

429483
assert len(tool_result.content) == 1
430-
headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}")
484+
headers_data = json.loads(
485+
tool_result.content[0].text
486+
if tool_result.content[0].type == "text"
487+
else "{}"
488+
)
431489

432490
# Verify headers were propagated
433491
assert headers_data.get("authorization") == "Bearer test-token"
@@ -452,11 +510,15 @@ async def test_request_context_isolation(context_server: None, server_url: str)
452510
await session.initialize()
453511

454512
# Call the tool that echoes context
455-
tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
513+
tool_result = await session.call_tool(
514+
"echo_context", {"request_id": f"request-{i}"}
515+
)
456516

457517
assert len(tool_result.content) == 1
458518
context_data = json.loads(
459-
tool_result.content[0].text if tool_result.content[0].type == "text" else "{}"
519+
tool_result.content[0].text
520+
if tool_result.content[0].type == "text"
521+
else "{}"
460522
)
461523
contexts.append(context_data)
462524

@@ -480,11 +542,19 @@ def test_sse_message_id_coercion():
480542
"""
481543
json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
482544
msg = types.JSONRPCMessage.model_validate_json(json_message)
483-
assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id="123")))
545+
assert msg == snapshot(
546+
types.JSONRPCMessage(
547+
root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id="123")
548+
)
549+
)
484550

485551
json_message = '{"jsonrpc": "2.0", "id": 123, "method": "ping", "params": null}'
486552
msg = types.JSONRPCMessage.model_validate_json(json_message)
487-
assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)))
553+
assert msg == snapshot(
554+
types.JSONRPCMessage(
555+
root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)
556+
)
557+
)
488558

489559

490560
@pytest.mark.parametrize(
@@ -502,11 +572,15 @@ def test_sse_message_id_coercion():
502572
("/messages/#fragment", ValueError),
503573
],
504574
)
505-
def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]):
575+
def test_sse_server_transport_endpoint_validation(
576+
endpoint: str, expected_result: str | type[Exception]
577+
):
506578
"""Test that SseServerTransport properly validates and normalizes endpoints."""
507579
if isinstance(expected_result, type) and issubclass(expected_result, Exception):
508580
# Test invalid endpoints that should raise an exception
509-
with pytest.raises(expected_result, match="is not a relative path.*expecting a relative path"):
581+
with pytest.raises(
582+
expected_result, match="is not a relative path.*expecting a relative path"
583+
):
510584
SseServerTransport(endpoint)
511585
else:
512586
# Test valid endpoints that should normalize correctly

0 commit comments

Comments
 (0)