Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion src/adcp/server/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,45 @@ async def _serve() -> None:
sock.close()


def _expand_allowed_hosts(hosts: Sequence[str]) -> list[str]:
"""Synthesize ``host:*`` siblings for bare hosts.

FastMCP's :class:`TransportSecurityMiddleware` matches the request's
``Host`` header literally against the configured ``allowed_hosts``
list. A bare host like ``acme.localhost`` matches a request without
a port suffix; the same request from a browser hitting
``http://acme.localhost:3001`` carries ``Host: acme.localhost:3001``
and is rejected with ``421 Misdirected Request``.

Adopters had to register both ``acme.localhost`` and
``acme.localhost:*`` explicitly. This helper synthesizes the second
form when the input has no ``:`` separator, mirroring the
port-stripping done in ``InMemorySubdomainTenantRouter`` so the two
surfaces stay symmetric. Hosts that already include ``:`` (already
have an explicit port or wildcard) pass through unchanged.

Idempotent: if the adopter passed both ``acme.localhost`` and
``acme.localhost:*``, the result still contains each only once.

IPv6 literals (bracketed ``[::1]`` or raw ``::1``) contain ``:`` and
pass through without synthesis — no malformed ``::1:*`` siblings.
Adopters running on custom IPv6 hosts pass the explicit
``[::1]:*`` form themselves.
"""
seen: set[str] = set()
result: list[str] = []
for host in hosts:
if host not in seen:
seen.add(host)
result.append(host)
if ":" not in host:
wildcard = f"{host}:*"
if wildcard not in seen:
seen.add(wildcard)
result.append(wildcard)
return result


def create_mcp_server(
handler: ADCPHandler[Any],
*,
Expand Down Expand Up @@ -1587,6 +1626,12 @@ def create_mcp_server(
# tenant table (e.g. :class:`SubdomainTenantMiddleware`) can set
# ``enable_dns_rebinding_protection=False`` so the MCP-layer check
# doesn't duplicate the upstream validation.
#
# ``_expand_allowed_hosts`` synthesizes the ``host:*`` sibling for
# any bare host (no ``:``) so adopters who pass ``acme.localhost``
# also cover requests on ``acme.localhost:3001``. Mirrors the port
# stripping :class:`InMemorySubdomainTenantRouter` does at lookup
# time so the two surfaces stay symmetric.
if (
enable_dns_rebinding_protection is not None
or allowed_hosts is not None
Expand All @@ -1600,7 +1645,10 @@ def create_mcp_server(
if enable_dns_rebinding_protection is not None:
ts.enable_dns_rebinding_protection = enable_dns_rebinding_protection
if allowed_hosts:
ts.allowed_hosts = [*ts.allowed_hosts, *allowed_hosts]
ts.allowed_hosts = [
*ts.allowed_hosts,
*_expand_allowed_hosts(allowed_hosts),
]
if allowed_origins:
ts.allowed_origins = [*ts.allowed_origins, *allowed_origins]
_register_handler_tools(
Expand Down
88 changes: 88 additions & 0 deletions tests/test_serve_transport_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,91 @@ def test_enable_dns_rebinding_protection_false_disables_check() -> None:
ts = mcp.settings.transport_security
assert ts is not None
assert ts.enable_dns_rebinding_protection is False


# ---- Bare-host → ``host:*`` synthesis (issue #518) ----


def test_bare_host_synthesizes_port_wildcard_sibling() -> None:
"""``allowed_hosts=['acme.localhost']`` registers both
``acme.localhost`` and ``acme.localhost:*`` so requests on either
bare or port-suffixed Host headers pass the transport check.

Mirrors the port-stripping :class:`InMemorySubdomainTenantRouter`
does at lookup time — registering once should cover both surfaces.
"""
mcp = create_mcp_server(
_StubHandler(),
name="t",
allowed_hosts=["acme.localhost"],
)
ts = mcp.settings.transport_security
assert ts is not None
assert "acme.localhost" in ts.allowed_hosts
assert "acme.localhost:*" in ts.allowed_hosts


def test_explicit_port_wildcard_passes_through_unchanged() -> None:
"""When the adopter already passes ``acme.localhost:*``, no
further synthesis happens — the bare ``acme.localhost`` is NOT
auto-added because the input has an explicit ``:`` separator
(the adopter signaled they're managing the form themselves)."""
mcp = create_mcp_server(
_StubHandler(),
name="t",
allowed_hosts=["acme.localhost:*"],
)
ts = mcp.settings.transport_security
assert ts is not None
assert "acme.localhost:*" in ts.allowed_hosts
# Adopter passed the port-wildcard form; we don't second-guess.
assert "acme.localhost" not in ts.allowed_hosts


def test_explicit_port_passes_through_unchanged() -> None:
"""A specific port (``acme.localhost:3001``) is left alone — the
``:`` discriminator covers any port-bearing form."""
mcp = create_mcp_server(
_StubHandler(),
name="t",
allowed_hosts=["acme.localhost:3001"],
)
ts = mcp.settings.transport_security
assert ts is not None
assert "acme.localhost:3001" in ts.allowed_hosts
assert "acme.localhost" not in ts.allowed_hosts


def test_mixed_bare_and_explicit_hosts_each_get_correct_treatment() -> None:
"""Bare hosts get the wildcard sibling; explicit-port hosts pass
through. A list with both forms produces the right combined set."""
mcp = create_mcp_server(
_StubHandler(),
name="t",
allowed_hosts=["acme.localhost", "beta.example.com:8443"],
)
ts = mcp.settings.transport_security
assert ts is not None
# Bare host gets both forms.
assert "acme.localhost" in ts.allowed_hosts
assert "acme.localhost:*" in ts.allowed_hosts
# Explicit-port host stays as-is, no bare-host synthesis.
assert "beta.example.com:8443" in ts.allowed_hosts
assert "beta.example.com" not in ts.allowed_hosts


def test_idempotent_when_both_forms_passed() -> None:
"""Adopter who already passes both ``acme.localhost`` and
``acme.localhost:*`` doesn't end up with duplicates after
expansion."""
mcp = create_mcp_server(
_StubHandler(),
name="t",
allowed_hosts=["acme.localhost", "acme.localhost:*"],
)
ts = mcp.settings.transport_security
assert ts is not None
bare_count = sum(1 for h in ts.allowed_hosts if h == "acme.localhost")
wildcard_count = sum(1 for h in ts.allowed_hosts if h == "acme.localhost:*")
assert bare_count == 1
assert wildcard_count == 1
Loading