1818from mcp .server .transport_security import TransportSecuritySettings
1919from mcp .types import Tool
2020
21+ # Mark all tests in this file as integration tests (spawn subprocesses)
22+ pytestmark = [pytest .mark .integration ]
23+
24+
2125logger = logging .getLogger (__name__ )
2226SERVER_NAME = "test_sse_security_server"
2327
@@ -42,16 +46,22 @@ async def on_list_tools(self) -> list[Tool]:
4246 return []
4347
4448
45- def run_server_with_settings (port : int , security_settings : TransportSecuritySettings | None = None ):
49+ def run_server_with_settings (
50+ port : int , security_settings : TransportSecuritySettings | None = None
51+ ):
4652 """Run the SSE server with specified security settings."""
4753 app = SecurityTestServer ()
4854 sse_transport = SseServerTransport ("/messages/" , security_settings )
4955
5056 async def handle_sse (request : Request ):
5157 try :
52- async with sse_transport .connect_sse (request .scope , request .receive , request ._send ) as streams :
58+ async with sse_transport .connect_sse (
59+ request .scope , request .receive , request ._send
60+ ) as streams :
5361 if streams :
54- await app .run (streams [0 ], streams [1 ], app .create_initialization_options ())
62+ await app .run (
63+ streams [0 ], streams [1 ], app .create_initialization_options ()
64+ )
5565 except ValueError as e :
5666 # Validation error was already handled inside connect_sse
5767 logger .debug (f"SSE connection failed validation: { e } " )
@@ -66,9 +76,13 @@ async def handle_sse(request: Request):
6676 uvicorn .run (starlette_app , host = "127.0.0.1" , port = port , log_level = "error" )
6777
6878
69- def start_server_process (port : int , security_settings : TransportSecuritySettings | None = None ):
79+ def start_server_process (
80+ port : int , security_settings : TransportSecuritySettings | None = None
81+ ):
7082 """Start server in a separate process."""
71- process = multiprocessing .Process (target = run_server_with_settings , args = (port , security_settings ))
83+ process = multiprocessing .Process (
84+ target = run_server_with_settings , args = (port , security_settings )
85+ )
7286 process .start ()
7387 # Give server time to start
7488 time .sleep (1 )
@@ -84,7 +98,9 @@ async def test_sse_security_default_settings(server_port: int):
8498 headers = {"Host" : "evil.com" , "Origin" : "http://evil.com" }
8599
86100 async with httpx .AsyncClient (timeout = 5.0 ) as client :
87- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
101+ async with client .stream (
102+ "GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers
103+ ) as response :
88104 assert response .status_code == 200
89105 finally :
90106 process .terminate ()
@@ -95,15 +111,19 @@ async def test_sse_security_default_settings(server_port: int):
95111async def test_sse_security_invalid_host_header (server_port : int ):
96112 """Test SSE with invalid Host header."""
97113 # Enable security by providing settings with an empty allowed_hosts list
98- security_settings = TransportSecuritySettings (enable_dns_rebinding_protection = True , allowed_hosts = ["example.com" ])
114+ security_settings = TransportSecuritySettings (
115+ enable_dns_rebinding_protection = True , allowed_hosts = ["example.com" ]
116+ )
99117 process = start_server_process (server_port , security_settings )
100118
101119 try :
102120 # Test with invalid host header
103121 headers = {"Host" : "evil.com" }
104122
105123 async with httpx .AsyncClient () as client :
106- response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
124+ response = await client .get (
125+ f"http://127.0.0.1:{ server_port } /sse" , headers = headers
126+ )
107127 assert response .status_code == 421
108128 assert response .text == "Invalid Host header"
109129
@@ -117,7 +137,9 @@ async def test_sse_security_invalid_origin_header(server_port: int):
117137 """Test SSE with invalid Origin header."""
118138 # Configure security to allow the host but restrict origins
119139 security_settings = TransportSecuritySettings (
120- enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://localhost:*" ]
140+ enable_dns_rebinding_protection = True ,
141+ allowed_hosts = ["127.0.0.1:*" ],
142+ allowed_origins = ["http://localhost:*" ],
121143 )
122144 process = start_server_process (server_port , security_settings )
123145
@@ -126,7 +148,9 @@ async def test_sse_security_invalid_origin_header(server_port: int):
126148 headers = {"Origin" : "http://evil.com" }
127149
128150 async with httpx .AsyncClient () as client :
129- response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
151+ response = await client .get (
152+ f"http://127.0.0.1:{ server_port } /sse" , headers = headers
153+ )
130154 assert response .status_code == 400
131155 assert response .text == "Invalid Origin header"
132156
@@ -140,7 +164,9 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
140164 """Test POST endpoint with invalid Content-Type header."""
141165 # Configure security to allow the host
142166 security_settings = TransportSecuritySettings (
143- enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://127.0.0.1:*" ]
167+ enable_dns_rebinding_protection = True ,
168+ allowed_hosts = ["127.0.0.1:*" ],
169+ allowed_origins = ["http://127.0.0.1:*" ],
144170 )
145171 process = start_server_process (server_port , security_settings )
146172
@@ -158,7 +184,8 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
158184
159185 # Test POST with missing content type
160186 response = await client .post (
161- f"http://127.0.0.1:{ server_port } /messages/?session_id={ fake_session_id } " , content = "test"
187+ f"http://127.0.0.1:{ server_port } /messages/?session_id={ fake_session_id } " ,
188+ content = "test" ,
162189 )
163190 assert response .status_code == 400
164191 assert response .text == "Invalid Content-Type header"
@@ -180,7 +207,9 @@ async def test_sse_security_disabled(server_port: int):
180207
181208 async with httpx .AsyncClient (timeout = 5.0 ) as client :
182209 # For SSE endpoints, we need to use stream to avoid timeout
183- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
210+ async with client .stream (
211+ "GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers
212+ ) as response :
184213 # Should connect successfully even with invalid host
185214 assert response .status_code == 200
186215
@@ -205,15 +234,19 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
205234
206235 async with httpx .AsyncClient (timeout = 5.0 ) as client :
207236 # For SSE endpoints, we need to use stream to avoid timeout
208- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
237+ async with client .stream (
238+ "GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers
239+ ) as response :
209240 # Should connect successfully with custom host
210241 assert response .status_code == 200
211242
212243 # Test with non-allowed host
213244 headers = {"Host" : "evil.com" }
214245
215246 async with httpx .AsyncClient () as client :
216- response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
247+ response = await client .get (
248+ f"http://127.0.0.1:{ server_port } /sse" , headers = headers
249+ )
217250 assert response .status_code == 421
218251 assert response .text == "Invalid Host header"
219252
@@ -239,15 +272,19 @@ async def test_sse_security_wildcard_ports(server_port: int):
239272
240273 async with httpx .AsyncClient (timeout = 5.0 ) as client :
241274 # For SSE endpoints, we need to use stream to avoid timeout
242- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
275+ async with client .stream (
276+ "GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers
277+ ) as response :
243278 # Should connect successfully with any port
244279 assert response .status_code == 200
245280
246281 headers = {"Origin" : f"http://localhost:{ test_port } " }
247282
248283 async with httpx .AsyncClient (timeout = 5.0 ) as client :
249284 # For SSE endpoints, we need to use stream to avoid timeout
250- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
285+ async with client .stream (
286+ "GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers
287+ ) as response :
251288 # Should connect successfully with any port
252289 assert response .status_code == 200
253290
@@ -261,7 +298,9 @@ async def test_sse_security_post_valid_content_type(server_port: int):
261298 """Test POST endpoint with valid Content-Type headers."""
262299 # Configure security to allow the host
263300 security_settings = TransportSecuritySettings (
264- enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://127.0.0.1:*" ]
301+ enable_dns_rebinding_protection = True ,
302+ allowed_hosts = ["127.0.0.1:*" ],
303+ allowed_origins = ["http://127.0.0.1:*" ],
265304 )
266305 process = start_server_process (server_port , security_settings )
267306
0 commit comments