Skip to content

Commit 5462caa

Browse files
BSmick6claude
andcommitted
fix(security): warn on empty allowed_hosts and improve 421/403 response bodies
When TransportSecuritySettings is constructed with DNS rebinding protection enabled but allowed_hosts=[], every request is silently rejected with a bare HTTP 421 — hard to diagnose without reading the SDK source. Add a model_validator that emits a logger.warning at construction time pointing users at allowed_hosts configuration. Also include the received header value and a configuration hint in the 421/403 response bodies. Removes all pragma: no cover markers from transport_security.py by adding a direct unit test file (test_transport_security.py) that exercises all branches without subprocesses. Updates the existing integration tests to use substring matching now that the response bodies carry extra context. Closes #2688 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 3eb5799 commit 5462caa

4 files changed

Lines changed: 223 additions & 21 deletions

File tree

src/mcp/server/transport_security.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import logging
44

5-
from pydantic import BaseModel, Field
5+
from pydantic import BaseModel, Field, model_validator
66
from starlette.requests import Request
77
from starlette.responses import Response
8+
from typing_extensions import Self
89

910
logger = logging.getLogger(__name__)
1011

@@ -31,6 +32,17 @@ class TransportSecuritySettings(BaseModel):
3132
Only applies when `enable_dns_rebinding_protection` is `True`.
3233
"""
3334

35+
@model_validator(mode="after")
36+
def _warn_if_protection_enabled_with_empty_allowlist(self) -> Self:
37+
if self.enable_dns_rebinding_protection and not self.allowed_hosts:
38+
logger.warning(
39+
"TransportSecuritySettings has DNS rebinding protection enabled but "
40+
"allowed_hosts is empty — all requests will be rejected with HTTP 421. "
41+
"Set allowed_hosts to your server's hostname(s), e.g. "
42+
'TransportSecuritySettings(allowed_hosts=["your-host.example.com:*"])'
43+
)
44+
return self
45+
3446

3547
# TODO(Marcelo): This should be a proper ASGI middleware. I'm sad to see this.
3648
class TransportSecurityMiddleware:
@@ -40,7 +52,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
4052
# If not specified, disable DNS rebinding protection by default for backwards compatibility
4153
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)
4254

43-
def _validate_host(self, host: str | None) -> bool: # pragma: no cover
55+
def _validate_host(self, host: str | None) -> bool:
4456
"""Validate the Host header against allowed values."""
4557
if not host:
4658
logger.warning("Missing Host header in request")
@@ -62,7 +74,7 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover
6274
logger.warning(f"Invalid Host header: {host}")
6375
return False
6476

65-
def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
77+
def _validate_origin(self, origin: str | None) -> bool:
6678
"""Validate the Origin header against allowed values."""
6779
# Origin can be absent for same-origin requests
6880
if not origin:
@@ -94,7 +106,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
94106
Returns None if validation passes, or an error Response if validation fails.
95107
"""
96108
# Always validate Content-Type for POST requests
97-
if is_post: # pragma: no branch
109+
if is_post:
98110
content_type = request.headers.get("content-type")
99111
if not self._validate_content_type(content_type):
100112
return Response("Invalid Content-Type header", status_code=400)
@@ -103,14 +115,22 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
103115
if not self.settings.enable_dns_rebinding_protection:
104116
return None
105117

106-
# Validate Host header # pragma: no cover
107-
host = request.headers.get("host") # pragma: no cover
108-
if not self._validate_host(host): # pragma: no cover
109-
return Response("Invalid Host header", status_code=421) # pragma: no cover
110-
111-
# Validate Origin header # pragma: no cover
112-
origin = request.headers.get("origin") # pragma: no cover
113-
if not self._validate_origin(origin): # pragma: no cover
114-
return Response("Invalid Origin header", status_code=403) # pragma: no cover
115-
116-
return None # pragma: no cover
118+
# Validate Host header
119+
host = request.headers.get("host")
120+
if not self._validate_host(host):
121+
return Response(
122+
f"Invalid Host header: {host!r}. "
123+
"Configure TransportSecuritySettings(allowed_hosts=[...]) with your server's hostname.",
124+
status_code=421,
125+
)
126+
127+
# Validate Origin header
128+
origin = request.headers.get("origin")
129+
if not self._validate_origin(origin):
130+
return Response(
131+
f"Invalid Origin header: {origin!r}. "
132+
"Configure TransportSecuritySettings(allowed_origins=[...]) with your server's origin.",
133+
status_code=403,
134+
)
135+
136+
return None

tests/server/test_sse_security.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ async def test_sse_security_invalid_host_header(server_port: int):
105105
async with httpx.AsyncClient() as client:
106106
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
107107
assert response.status_code == 421
108-
assert response.text == "Invalid Host header"
108+
assert "Invalid Host header" in response.text
109109

110110
finally:
111111
process.terminate()
@@ -128,7 +128,7 @@ async def test_sse_security_invalid_origin_header(server_port: int):
128128
async with httpx.AsyncClient() as client:
129129
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
130130
assert response.status_code == 403
131-
assert response.text == "Invalid Origin header"
131+
assert "Invalid Origin header" in response.text
132132

133133
finally:
134134
process.terminate()
@@ -215,7 +215,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
215215
async with httpx.AsyncClient() as client:
216216
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
217217
assert response.status_code == 421
218-
assert response.text == "Invalid Host header"
218+
assert "Invalid Host header" in response.text
219219

220220
finally:
221221
process.terminate()

tests/server/test_streamable_http_security.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ async def test_streamable_http_security_invalid_host_header(server_port: int):
126126
headers=headers,
127127
)
128128
assert response.status_code == 421
129-
assert response.text == "Invalid Host header"
129+
assert "Invalid Host header" in response.text
130130

131131
finally:
132132
process.terminate()
@@ -154,7 +154,7 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int):
154154
headers=headers,
155155
)
156156
assert response.status_code == 403
157-
assert response.text == "Invalid Origin header"
157+
assert "Invalid Origin header" in response.text
158158

159159
finally:
160160
process.terminate()
@@ -269,7 +269,7 @@ async def test_streamable_http_security_get_request(server_port: int):
269269
async with httpx.AsyncClient(timeout=5.0) as client:
270270
response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers)
271271
assert response.status_code == 421
272-
assert response.text == "Invalid Host header"
272+
assert "Invalid Host header" in response.text
273273

274274
# Test GET request with valid host header
275275
headers = {
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
"""Unit tests for TransportSecuritySettings and TransportSecurityMiddleware."""
2+
3+
import logging
4+
5+
import pytest
6+
from starlette.requests import Request
7+
8+
from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings
9+
10+
11+
def make_request(headers: dict[str, str], method: str = "GET") -> Request:
12+
scope = {
13+
"type": "http",
14+
"method": method,
15+
"headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
16+
"path": "/",
17+
"query_string": b"",
18+
}
19+
return Request(scope)
20+
21+
22+
# ---------------------------------------------------------------------------
23+
# TransportSecuritySettings — construction-time warning
24+
# ---------------------------------------------------------------------------
25+
26+
27+
def test_no_warning_when_protection_disabled(caplog: pytest.LogCaptureFixture) -> None:
28+
with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"):
29+
TransportSecuritySettings(enable_dns_rebinding_protection=False)
30+
assert not caplog.records
31+
32+
33+
def test_no_warning_when_allowed_hosts_populated(caplog: pytest.LogCaptureFixture) -> None:
34+
with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"):
35+
TransportSecuritySettings(
36+
enable_dns_rebinding_protection=True,
37+
allowed_hosts=["example.com"],
38+
)
39+
assert not caplog.records
40+
41+
42+
def test_warning_when_protection_enabled_with_empty_allowed_hosts(caplog: pytest.LogCaptureFixture) -> None:
43+
with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"):
44+
TransportSecuritySettings(enable_dns_rebinding_protection=True)
45+
assert len(caplog.records) == 1
46+
assert "allowed_hosts is empty" in caplog.records[0].message
47+
assert "HTTP 421" in caplog.records[0].message
48+
assert "allowed_hosts=" in caplog.records[0].message
49+
50+
51+
# ---------------------------------------------------------------------------
52+
# TransportSecurityMiddleware._validate_host
53+
# ---------------------------------------------------------------------------
54+
55+
56+
def test_validate_host_missing_host() -> None:
57+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"]))
58+
assert m._validate_host(None) is False
59+
60+
61+
def test_validate_host_exact_match() -> None:
62+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"]))
63+
assert m._validate_host("example.com") is True
64+
65+
66+
def test_validate_host_exact_no_match() -> None:
67+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"]))
68+
assert m._validate_host("other.com") is False
69+
70+
71+
def test_validate_host_port_wildcard_match() -> None:
72+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"]))
73+
assert m._validate_host("localhost:8080") is True
74+
75+
76+
def test_validate_host_port_wildcard_different_base() -> None:
77+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"]))
78+
assert m._validate_host("other:8080") is False
79+
80+
81+
def test_validate_host_port_wildcard_no_port() -> None:
82+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"]))
83+
assert m._validate_host("localhost") is False
84+
85+
86+
# ---------------------------------------------------------------------------
87+
# TransportSecurityMiddleware._validate_origin
88+
# ---------------------------------------------------------------------------
89+
90+
91+
def test_validate_origin_absent_is_allowed() -> None:
92+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"]))
93+
assert m._validate_origin(None) is True
94+
95+
96+
def test_validate_origin_exact_match() -> None:
97+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"]))
98+
assert m._validate_origin("http://example.com") is True
99+
100+
101+
def test_validate_origin_exact_no_match() -> None:
102+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"]))
103+
assert m._validate_origin("http://other.com") is False
104+
105+
106+
def test_validate_origin_port_wildcard_match() -> None:
107+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://localhost:*"]))
108+
assert m._validate_origin("http://localhost:3000") is True
109+
110+
111+
def test_validate_origin_port_wildcard_different_base() -> None:
112+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://localhost:*"]))
113+
assert m._validate_origin("http://other:3000") is False
114+
115+
116+
# ---------------------------------------------------------------------------
117+
# TransportSecurityMiddleware.validate_request
118+
# ---------------------------------------------------------------------------
119+
120+
121+
@pytest.mark.anyio
122+
async def test_validate_request_post_valid_content_type() -> None:
123+
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
124+
request = make_request({"content-type": "application/json"}, method="POST")
125+
assert await m.validate_request(request, is_post=True) is None
126+
127+
128+
@pytest.mark.anyio
129+
async def test_validate_request_post_invalid_content_type() -> None:
130+
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
131+
request = make_request({"content-type": "text/plain"}, method="POST")
132+
response = await m.validate_request(request, is_post=True)
133+
assert response is not None
134+
assert response.status_code == 400
135+
136+
137+
@pytest.mark.anyio
138+
async def test_validate_request_get_skips_content_type() -> None:
139+
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
140+
request = make_request({})
141+
assert await m.validate_request(request, is_post=False) is None
142+
143+
144+
@pytest.mark.anyio
145+
async def test_validate_request_protection_disabled_allows_any_host() -> None:
146+
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
147+
request = make_request({"host": "attacker.example.com"})
148+
assert await m.validate_request(request) is None
149+
150+
151+
@pytest.mark.anyio
152+
async def test_validate_request_valid_host_and_no_origin() -> None:
153+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"]))
154+
request = make_request({"host": "example.com"})
155+
assert await m.validate_request(request) is None
156+
157+
158+
@pytest.mark.anyio
159+
async def test_validate_request_invalid_host_returns_421_with_detail() -> None:
160+
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"]))
161+
request = make_request({"host": "attacker.com"})
162+
response = await m.validate_request(request)
163+
assert response is not None
164+
assert response.status_code == 421
165+
assert b"attacker.com" in response.body
166+
assert b"allowed_hosts" in response.body
167+
168+
169+
@pytest.mark.anyio
170+
async def test_validate_request_invalid_origin_returns_403_with_detail() -> None:
171+
m = TransportSecurityMiddleware(
172+
TransportSecuritySettings(
173+
allowed_hosts=["example.com"],
174+
allowed_origins=["http://example.com"],
175+
)
176+
)
177+
request = make_request({"host": "example.com", "origin": "http://attacker.com"})
178+
response = await m.validate_request(request)
179+
assert response is not None
180+
assert response.status_code == 403
181+
assert b"attacker.com" in response.body
182+
assert b"allowed_origins" in response.body

0 commit comments

Comments
 (0)