Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 48 additions & 29 deletions src/adcp/server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,41 +637,60 @@ def __init__(self, app: Any, config: BearerTokenAuth) -> None:

async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
# Lifespan + websocket pass through unchanged. Auth applies to
# HTTP requests only.
# HTTP requests only. No contextvar changes — these scopes are
# not dispatched to skill handlers that read auth vars.
if scope.get("type") != "http":
await self._app(scope, receive, send)
return

# CORS preflight is part of the public surface — browser-origin
# clients send ``OPTIONS`` before any auth'd POST. Returning 401
# here breaks the preflight and the buyer never gets a chance to
# retry with a token. Pass through; let the inner app's CORS
# handler (or operator-supplied ``asgi_middleware``) respond.
if scope.get("method") == "OPTIONS":
await self._app(scope, receive, send)
return
principal_token = None
tenant_token = None
metadata_token = None
try:
# CORS preflight and A2A discovery are part of the public surface.
# Set contextvars to None to prevent stale values from an enclosing
# task context from leaking into downstream code on these paths —
# mirrors BearerTokenAuthMiddleware's discovery branch.
if scope.get("method") == "OPTIONS" or scope.get("path", "") in _A2A_DISCOVERY_PATHS:
principal_token = current_principal.set(None)
tenant_token = current_tenant.set(None)
metadata_token = current_principal_metadata.set(None)
await self._app(scope, receive, send)
return

path = scope.get("path", "")
if path in _A2A_DISCOVERY_PATHS:
principal = self._authenticate_scope(scope)
if principal is None:
await self._send_unauthenticated(send)
return

# Stash both the duck-typed user (for DefaultServerCallContextBuilder)
# and the raw Principal (for downstream code reading scope['auth']).
# Mutating the scope dict before delegating propagates state to
# nested apps without copying.
scope["user"] = _A2AAuthenticatedUser(
display_name=principal.caller_identity,
tenant_id=principal.tenant_id,
principal_metadata=dict(principal.metadata) if principal.metadata else None,
)
scope["auth"] = principal
# Populate the same module-level ContextVars that BearerTokenAuthMiddleware
# sets on the MCP path. auth_context_factory and adopter code that reads
# current_principal.get() directly see the authenticated identity on A2A
# exactly as they do on MCP. Reset unconditionally in finally so a later
# task sharing this context can't read a stale principal.
principal_token = current_principal.set(principal.caller_identity)
tenant_token = current_tenant.set(principal.tenant_id)
metadata_token = current_principal_metadata.set(
dict(principal.metadata) if principal.metadata else None
)
await self._app(scope, receive, send)
return

principal = self._authenticate_scope(scope)
if principal is None:
await self._send_unauthenticated(send)
return

# Stash both the duck-typed user (for DefaultServerCallContextBuilder)
# and the raw Principal (for downstream code reading scope['auth']).
# Mutating the scope dict before delegating propagates state to
# nested apps without copying.
scope["user"] = _A2AAuthenticatedUser(
display_name=principal.caller_identity,
tenant_id=principal.tenant_id,
principal_metadata=dict(principal.metadata) if principal.metadata else None,
)
scope["auth"] = principal
await self._app(scope, receive, send)
finally:
if principal_token is not None:
current_principal.reset(principal_token)
if tenant_token is not None:
current_tenant.reset(tenant_token)
if metadata_token is not None:
current_principal_metadata.reset(metadata_token)

def _authenticate_scope(self, scope: Any) -> Principal | None:
"""Read + validate the bearer header off raw ASGI scope.
Expand Down
73 changes: 73 additions & 0 deletions tests/test_serve_auth_both.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,28 @@ async def inner(scope: Any, _receive: Any, _send: Any) -> None:
assert "auth" in passed_scope
assert passed_scope["auth"].caller_identity == "p-acme"

@pytest.mark.asyncio
async def test_valid_token_sets_current_principal_contextvar(self):
"""On auth success, current_principal must be populated inside the
inner app and reset to None after __call__ returns (#590 regression)."""
from adcp.server.auth import current_principal, current_tenant

captured: dict[str, str | None] = {}

async def inner(scope: Any, _receive: Any, _send: Any) -> None:
captured["principal"] = current_principal.get()
captured["tenant"] = current_tenant.get()

mw = A2ABearerAuthMiddleware(inner, _auth())
scope = self._scope(headers=[(b"authorization", b"Bearer good-token")])
await mw(scope, lambda: None, lambda _: None)

assert captured["principal"] == "p-acme"
assert captured["tenant"] == "acme"
# Verify reset-in-finally: contextvar must be cleared after __call__ returns.
assert current_principal.get() is None
assert current_tenant.get() is None

@pytest.mark.asyncio
async def test_missing_header_returns_401(self):
sent: list[dict] = []
Expand Down Expand Up @@ -302,6 +324,57 @@ async def test_a2a_jsonrpc_authenticated_passes_through() -> None:
assert response.status_code == 200


@pytest.mark.asyncio
async def test_a2a_auth_populates_current_principal_contextvar() -> None:
"""A2ABearerAuthMiddleware must set current_principal contextvar so
auth_context_factory and adopter code reading it directly see the
authenticated identity on A2A — same as MCP (regression for #590).

Verifies both that the var is populated inside the handler AND that it is
reset to None after the request completes (try/finally contract)."""
from adcp.server.a2a_server import create_a2a_server
from adcp.server.auth import current_principal, current_tenant

observed: dict[str, str | None] = {}

class _ContextCaptureHandler(ADCPHandler):
async def get_adcp_capabilities(self, params: Any, context: Any = None) -> dict[str, Any]:
return {"adcp": {"major_versions": [3]}, "supported_protocols": ["media_buy"]}

async def get_products(self, params: Any, context: Any = None) -> dict[str, Any]:
observed["principal"] = current_principal.get()
observed["tenant"] = current_tenant.get()
return {"products": []}

inner = create_a2a_server(_ContextCaptureHandler(), name="ctx-test", validation=None)
app = A2ABearerAuthMiddleware(inner, _auth())
body = {
"jsonrpc": "2.0",
"id": "1",
"method": "message/send",
"params": {
"message": {
"messageId": "m1",
"role": "user",
"parts": [{"kind": "data", "data": {"skill": "get_products", "parameters": {}}}],
}
},
}
async with LifespanManager(inner):
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.post(
"/", json=body, headers={"Authorization": "Bearer good-token"}
)
assert response.status_code == 200
assert observed.get("principal") == "p-acme", "current_principal not set on A2A path"
assert observed.get("tenant") == "acme", "current_tenant not set on A2A path"
# Verify reset-in-finally: contextvar must be None after the request.
assert current_principal.get() is None
assert current_tenant.get() is None


# ===========================================================================
# transport="both": the regression case from #558
# ===========================================================================
Expand Down
Loading