Skip to content

Commit 69211a3

Browse files
bokelleyclaude
andauthored
feat(server): add asgi_middleware param to serve() (#441)
* feat(server): asgi_middleware accepts Callable factories alongside tuple form Closes #415 https://claude.ai/code/session_01GttQNCHuVSVUyRW5knbitw * fix(server): use Callable[..., Any] in tuple annotation, add ASGIMiddlewareEntry alias Fixes the tuple form annotation so functools.partial is accepted as the first element without a mypy error. Adds ASGIMiddlewareEntry type alias to reduce repetition across five annotation sites. Adds a logger.warning when asgi_middleware is passed with transport='stdio' (where it is silently ignored). https://claude.ai/code/session_01GttQNCHuVSVUyRW5knbitw * feat(server): export ASGIMiddlewareEntry from adcp.server Mirrors the SkillMiddleware export pattern so users can annotate their own asgi_middleware lists with the canonical type alias. https://claude.ai/code/session_01GttQNCHuVSVUyRW5knbitw --------- Co-authored-by: Claude <noreply@anthropic.com>
1 parent abf0d65 commit 69211a3

3 files changed

Lines changed: 124 additions & 28 deletions

File tree

src/adcp/server/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ async def get_products(params, context=None):
130130
update_media_buy_response,
131131
)
132132
from adcp.server.serve import (
133+
ASGIMiddlewareEntry,
133134
ContextFactory,
134135
RequestMetadata,
135136
SkillMiddleware,
@@ -186,6 +187,7 @@ async def get_products(params, context=None):
186187
# A2A integration
187188
"ADCPAgentExecutor",
188189
"MessageParser",
190+
"ASGIMiddlewareEntry",
189191
"SkillMiddleware",
190192
"create_a2a_server",
191193
# Bearer-token auth middleware (seller-facing recipe)

src/adcp/server/serve.py

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,19 @@ def build_context(meta: RequestMetadata) -> ToolContext:
406406
mcp = create_mcp_server(MyAgent(), context_factory=build_context)
407407
"""
408408

409+
ASGIMiddlewareEntry = tuple[Callable[..., Any], dict[str, Any]] | Callable[..., Any]
410+
"""A single ASGI middleware entry for :func:`serve`'s ``asgi_middleware`` param.
411+
412+
Each entry is either:
413+
414+
- A ``(callable, kwargs)`` tuple — invoked as ``callable(app, **kwargs)``.
415+
Both plain class constructors and :func:`functools.partial` instances work
416+
as the first element.
417+
- A bare callable factory ``f(app) -> app`` — invoked as ``factory(app)``.
418+
419+
Both forms can be mixed in the same list.
420+
"""
421+
409422

410423
def serve(
411424
handler: ADCPHandler[Any] | Any,
@@ -420,7 +433,7 @@ def serve(
420433
task_store: TaskStore | None = None,
421434
push_config_store: PushNotificationConfigStore | None = None,
422435
middleware: Sequence[SkillMiddleware] | None = None,
423-
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None = None,
436+
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None = None,
424437
message_parser: MessageParser | None = None,
425438
advertise_all: bool = False,
426439
max_request_size: int | None = None,
@@ -472,23 +485,40 @@ def serve(
472485
rate limiting, tracing. Composes outermost-first. See
473486
:data:`SkillMiddleware` for the signature and composition
474487
semantics.
475-
asgi_middleware: Optional sequence of ``(MiddlewareClass, kwargs)``
476-
tuples — Starlette-shape ASGI middleware applied to the
477-
outer HTTP app before uvicorn binds. Use for cross-cutting
478-
HTTP concerns the SDK does not own: tenant resolution
479-
(:class:`adcp.server.SubdomainTenantMiddleware`), CORS,
480-
request-id propagation, IP allowlists, custom auth.
481-
Composes outermost-first — the first entry sees every
482-
request before later entries. Each class is invoked as
483-
``cls(app, **kwargs)``. Applied on every HTTP transport
484-
(``streamable-http``, ``a2a``, ``both``); ignored on
485-
``stdio``.
488+
asgi_middleware: Optional sequence of ASGI middleware entries
489+
applied to the outer HTTP app before uvicorn binds. Use for
490+
cross-cutting HTTP concerns the SDK does not own: tenant
491+
resolution (:class:`adcp.server.SubdomainTenantMiddleware`),
492+
CORS, request-id propagation, IP allowlists, custom auth.
493+
Composes outermost-first — the first entry sees every request
494+
before later entries. Applied on every HTTP transport
495+
(``streamable-http``, ``sse``, ``a2a``, ``both``); ignored
496+
on ``stdio``.
497+
498+
Each entry is either a ``(MiddlewareClass, kwargs)`` tuple
499+
invoked as ``cls(app, **kwargs)``, or a callable factory
500+
``f(app) -> app``. Both forms can appear in the same list.
486501
487502
Middleware sees ``lifespan`` and ``websocket`` scopes in
488503
addition to ``http`` — guard non-HTTP scopes by passing
489504
them through unchanged (``if scope['type'] != 'http':
490505
await self.app(scope, receive, send); return``) so the
491506
framework's lifespan composition still runs.
507+
508+
Example (tuple form)::
509+
510+
from starlette.middleware.cors import CORSMiddleware
511+
serve(handler, asgi_middleware=[
512+
(CORSMiddleware, {"allow_origins": ["*"]}),
513+
])
514+
515+
Example (callable factory form, e.g. with ``functools.partial``)::
516+
517+
import functools
518+
from starlette.middleware.cors import CORSMiddleware
519+
serve(handler, asgi_middleware=[
520+
functools.partial(CORSMiddleware, allow_origins=["*"]),
521+
])
492522
message_parser: Optional
493523
:data:`~adcp.server.a2a_server.MessageParser` callable for
494524
alternative A2A wire shapes (A2A transport only). The
@@ -690,11 +720,11 @@ async def force_account_status(self, account_id, status):
690720

691721

692722
def _prepend_debug_endpoint(
693-
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None,
723+
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None,
694724
*,
695725
enable_debug_endpoints: bool,
696726
debug_traffic_source: Callable[[], dict[str, int]] | None,
697-
) -> Sequence[tuple[type, dict[str, Any]]] | None:
727+
) -> Sequence[ASGIMiddlewareEntry] | None:
698728
"""Prepend :class:`DebugTrafficMiddleware` to the asgi_middleware
699729
sequence when debug endpoints are enabled.
700730
@@ -728,21 +758,27 @@ def _prepend_debug_endpoint(
728758

729759
def _apply_asgi_middleware(
730760
app: Any,
731-
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None,
761+
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None,
732762
) -> Any:
733763
"""Wrap ``app`` with operator-supplied Starlette-style ASGI middleware.
734764
735-
Each entry is ``(MiddlewareClass, kwargs)`` and is invoked as
736-
``cls(app, **kwargs)``. Composition is outermost-first — the first
737-
entry sees every request before later entries — so we wrap in
738-
reverse, matching :meth:`Starlette.add_middleware` semantics.
765+
Each entry is either ``(MiddlewareClass, kwargs)`` invoked as
766+
``cls(app, **kwargs)``, or a callable factory ``f(app) -> app`` invoked
767+
as ``factory(app)``. Both forms can appear in the same list. Composition
768+
is outermost-first — the first entry sees every request before later
769+
entries — so we wrap in reverse, matching :meth:`Starlette.add_middleware`
770+
semantics.
739771
740772
No-op when the sequence is empty or ``None``.
741773
"""
742774
if not asgi_middleware:
743775
return app
744-
for cls, kwargs in reversed(list(asgi_middleware)):
745-
app = cls(app, **kwargs)
776+
for entry in reversed(list(asgi_middleware)):
777+
if isinstance(entry, tuple):
778+
cls, kwargs = entry
779+
app = cls(app, **kwargs)
780+
else:
781+
app = entry(app)
746782
return app
747783

748784

@@ -952,7 +988,7 @@ def _serve_mcp(
952988
test_controller: TestControllerStore | None,
953989
context_factory: ContextFactory | None = None,
954990
middleware: Sequence[SkillMiddleware] | None = None,
955-
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None = None,
991+
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None = None,
956992
advertise_all: bool = False,
957993
max_request_size: int | None = None,
958994
streaming_responses: bool = False,
@@ -985,24 +1021,28 @@ def _serve_mcp(
9851021
_run_mcp_http(
9861022
mcp,
9871023
transport=transport,
988-
max_request_size=max_request_size,
9891024
asgi_middleware=asgi_middleware,
1025+
max_request_size=max_request_size,
9901026
discovery_name=name,
9911027
discovery_base_url=base_url,
9921028
discovery_specialisms=specialisms,
9931029
discovery_description=description,
9941030
)
9951031
else:
9961032
# stdio — no listening socket, nothing to configure.
1033+
if asgi_middleware:
1034+
logger.warning(
1035+
"asgi_middleware is ignored on transport='stdio'; " "ASGI middleware will not run"
1036+
)
9971037
mcp.run(transport=transport)
9981038

9991039

10001040
def _run_mcp_http(
10011041
mcp: Any,
10021042
*,
10031043
transport: str,
1044+
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None = None,
10041045
max_request_size: int | None = None,
1005-
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None = None,
10061046
discovery_name: str = "adcp-agent",
10071047
discovery_base_url: str | None = None,
10081048
discovery_specialisms: list[str] | None = None,
@@ -1080,7 +1120,7 @@ def _serve_a2a(
10801120
task_store: TaskStore | None = None,
10811121
push_config_store: PushNotificationConfigStore | None = None,
10821122
middleware: Sequence[SkillMiddleware] | None = None,
1083-
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None = None,
1123+
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None = None,
10841124
message_parser: MessageParser | None = None,
10851125
advertise_all: bool = False,
10861126
max_request_size: int | None = None,
@@ -1287,7 +1327,7 @@ def _serve_mcp_and_a2a(
12871327
task_store: TaskStore | None = None,
12881328
push_config_store: PushNotificationConfigStore | None = None,
12891329
middleware: Sequence[SkillMiddleware] | None = None,
1290-
asgi_middleware: Sequence[tuple[type, dict[str, Any]]] | None = None,
1330+
asgi_middleware: Sequence[ASGIMiddlewareEntry] | None = None,
12911331
message_parser: MessageParser | None = None,
12921332
advertise_all: bool = False,
12931333
max_request_size: int | None = None,

tests/test_serve_asgi_middleware.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
Operators wiring tenant routing, CORS, request-id propagation, and
44
custom auth use this kwarg to layer Starlette-style ASGI middleware
55
on the outer HTTP app before uvicorn binds. The kwarg accepts a
6-
sequence of ``(MiddlewareClass, kwargs)`` tuples and composes
7-
outermost-first.
6+
sequence of ``(MiddlewareClass, kwargs)`` tuples, callable factories,
7+
or a mix of both, and composes outermost-first.
88
"""
99

1010
from __future__ import annotations
1111

12+
import functools
13+
1214
from adcp.server.serve import _apply_asgi_middleware
1315

1416

@@ -65,3 +67,55 @@ def test_apply_asgi_middleware_passes_kwargs_through():
6567
assert isinstance(wrapped, _TaggingMiddleware)
6668
assert wrapped.name == "audit"
6769
assert wrapped.app is app
70+
71+
72+
def test_apply_asgi_middleware_callable_factory():
73+
"""Callable factory form ``f(app) -> app`` is accepted."""
74+
app = _NoOpAsgi()
75+
76+
def cors_factory(inner):
77+
return _TaggingMiddleware(inner, name="cors")
78+
79+
wrapped = _apply_asgi_middleware(app, [cors_factory])
80+
assert isinstance(wrapped, _TaggingMiddleware)
81+
assert wrapped.name == "cors"
82+
assert wrapped.app is app
83+
84+
85+
def test_apply_asgi_middleware_callable_factory_with_partial():
86+
"""``functools.partial`` is a valid callable factory."""
87+
app = _NoOpAsgi()
88+
factory = functools.partial(_TaggingMiddleware, name="partial-cors")
89+
wrapped = _apply_asgi_middleware(app, [factory])
90+
assert isinstance(wrapped, _TaggingMiddleware)
91+
assert wrapped.name == "partial-cors"
92+
assert wrapped.app is app
93+
94+
95+
def test_apply_asgi_middleware_mixed_tuple_and_callable_preserves_order():
96+
"""Mixed list composes outermost-first regardless of entry type.
97+
98+
Given ``[tuple_entry("outer"), callable("middle"), tuple_entry("inner")]``,
99+
the result must be outer → middle → inner → app, verified by walking
100+
the ``.app`` chain.
101+
"""
102+
app = _NoOpAsgi()
103+
104+
def middle_factory(inner):
105+
return _TaggingMiddleware(inner, name="middle")
106+
107+
wrapped = _apply_asgi_middleware(
108+
app,
109+
[
110+
(_TaggingMiddleware, {"name": "outer"}),
111+
middle_factory,
112+
(_TaggingMiddleware, {"name": "inner"}),
113+
],
114+
)
115+
assert isinstance(wrapped, _TaggingMiddleware)
116+
assert wrapped.name == "outer"
117+
assert isinstance(wrapped.app, _TaggingMiddleware)
118+
assert wrapped.app.name == "middle"
119+
assert isinstance(wrapped.app.app, _TaggingMiddleware)
120+
assert wrapped.app.app.name == "inner"
121+
assert wrapped.app.app.app is app

0 commit comments

Comments
 (0)