Skip to content

Commit 72c1c17

Browse files
committed
Simplify auth server metadata discovery fallbacks
1 parent 830fc7e commit 72c1c17

File tree

2 files changed

+56
-392
lines changed

2 files changed

+56
-392
lines changed

src/mcp/client/auth.py

Lines changed: 42 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,6 @@ def should_include_resource_param(self, protocol_version: str | None = None) ->
176176
return protocol_version >= "2025-06-18"
177177

178178

179-
OAuthDiscoveryStack = list[Callable[[], Awaitable[httpx.Request]]]
180-
181-
182179
class OAuthClientProvider(httpx.Auth):
183180
"""
184181
OAuth2 authentication for httpx.
@@ -254,117 +251,32 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
254251
except ValidationError:
255252
pass
256253

257-
def _build_well_known_path(self, pathname: str, well_known_endpoint: str) -> str:
258-
"""Construct well-known path for OAuth metadata discovery."""
259-
well_known_path = f"/.well-known/{well_known_endpoint}{pathname}"
260-
if pathname.endswith("/"):
261-
# Strip trailing slash from pathname to avoid double slashes
262-
well_known_path = well_known_path[:-1]
263-
return well_known_path
264-
265-
def _build_well_known_fallback_url(self, well_known_endpoint: str) -> str:
266-
"""Construct fallback well-known URL for OAuth metadata discovery in legacy servers."""
267-
base_url = getattr(self.context, "discovery_base_url", "")
268-
if not base_url:
269-
raise OAuthFlowError("No base URL available for fallback discovery")
270-
271-
# Fallback to root discovery for legacy servers
272-
return urljoin(base_url, f"/.well-known/{well_known_endpoint}")
273-
274-
def _build_oidc_fallback_path(self, pathname: str, well_known_endpoint: str) -> str:
275-
"""Construct fallback well-known path for OIDC metadata discovery in legacy servers."""
276-
# Strip trailing slash from pathname to avoid double slashes
277-
clean_pathname = pathname[:-1] if pathname.endswith("/") else pathname
278-
# OIDC 1.0 appends the well-known path to the full AS URL
279-
return f"{clean_pathname}/.well-known/{well_known_endpoint}"
280-
281-
def _build_oidc_fallback_url(self, well_known_endpoint: str) -> str:
282-
"""Construct fallback well-known URL for OIDC metadata discovery in legacy servers."""
283-
if self.context.auth_server_url:
284-
auth_server_url = self.context.auth_server_url
285-
else:
286-
auth_server_url = self.context.server_url
287-
288-
parsed = urlparse(auth_server_url)
289-
well_known_path = self._build_oidc_fallback_path(parsed.path, well_known_endpoint)
290-
base_url = f"{parsed.scheme}://{parsed.netloc}"
291-
return urljoin(base_url, well_known_path)
292-
293-
def _should_attempt_fallback(self, response_status: int, discovery_stack: OAuthDiscoveryStack) -> bool:
294-
"""Determine if further fallback should be attempted."""
295-
return response_status == 404 and len(discovery_stack) > 0
296-
297-
async def _try_metadata_discovery(self, url: str) -> httpx.Request:
298-
"""Build metadata discovery request for a specific URL."""
299-
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
300-
301-
async def _discover_well_known_metadata(self, well_known_endpoint: str) -> httpx.Request:
302-
"""Build .well-known metadata discovery request with fallback support."""
303-
if self.context.auth_server_url:
304-
auth_server_url = self.context.auth_server_url
305-
else:
306-
auth_server_url = self.context.server_url
307-
308-
# Per RFC 8414, try path-aware discovery first
254+
def _get_discovery_urls(self) -> list[str]:
255+
"""Generate ordered list of (url, type) tuples for discovery attempts."""
256+
urls: list[str] = []
257+
auth_server_url = self.context.auth_server_url or self.context.server_url
309258
parsed = urlparse(auth_server_url)
310-
well_known_path = self._build_well_known_path(parsed.path, well_known_endpoint)
311259
base_url = f"{parsed.scheme}://{parsed.netloc}"
312-
url = urljoin(base_url, well_known_path)
313260

314-
# Store fallback info for use in response handler
315-
self.context.discovery_base_url = base_url
316-
self.context.discovery_pathname = parsed.path
261+
# RFC 8414: Path-aware OAuth discovery
262+
if parsed.path and parsed.path != "/":
263+
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
264+
urls.append(urljoin(base_url, oauth_path))
317265

318-
return await self._try_metadata_discovery(url)
266+
# OAuth root fallback
267+
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
319268

320-
async def _discover_well_known_metadata_fallback(self, well_known_endpoint: str) -> httpx.Request:
321-
"""Build fallback OAuth metadata discovery request for legacy servers."""
322-
url = self._build_well_known_fallback_url(well_known_endpoint)
323-
return await self._try_metadata_discovery(url)
269+
# RFC 8414 section 5: Path-aware OIDC discovery
270+
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
271+
if parsed.path and parsed.path != "/":
272+
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
273+
urls.append(urljoin(base_url, oidc_path))
324274

325-
async def _discover_oauth_metadata(self) -> httpx.Request:
326-
"""Build OAuth metadata discovery request with fallback support."""
327-
return await self._discover_well_known_metadata("oauth-authorization-server")
275+
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
276+
oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration"
277+
urls.append(oidc_fallback)
328278

329-
async def _discover_oauth_metadata_fallback(self) -> httpx.Request:
330-
"""Build fallback OAuth metadata discovery request for legacy servers."""
331-
return await self._discover_well_known_metadata_fallback("oauth-authorization-server")
332-
333-
async def _discover_oidc_metadata(self) -> httpx.Request:
334-
"""
335-
Build fallback OIDC metadata discovery request.
336-
See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
337-
"""
338-
return await self._discover_well_known_metadata("openid-configuration")
339-
340-
async def _discover_oidc_metadata_fallback(self) -> httpx.Request:
341-
"""
342-
Build fallback OIDC metadata discovery request for legacy servers.
343-
See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
344-
"""
345-
url = self._build_oidc_fallback_url("openid-configuration")
346-
return await self._try_metadata_discovery(url)
347-
348-
async def _handle_oauth_metadata_response(
349-
self, response: httpx.Response, discovery_stack: OAuthDiscoveryStack
350-
) -> bool:
351-
"""Handle OAuth metadata response. Returns True if handled successfully."""
352-
if response.status_code == 200:
353-
try:
354-
content = await response.aread()
355-
metadata = OAuthMetadata.model_validate_json(content)
356-
self.context.oauth_metadata = metadata
357-
# Apply default scope if none specified
358-
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
359-
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)
360-
return True
361-
except ValidationError:
362-
pass
363-
364-
# Check if we should attempt fallback
365-
# True: No fallback needed (either success or non-404 error)
366-
# False: Signal that fallback should be attempted
367-
return not self._should_attempt_fallback(response.status_code, discovery_stack)
279+
return urls
368280

369281
async def _register_client(self) -> httpx.Request | None:
370282
"""Build registration request or skip if already registered."""
@@ -559,25 +471,16 @@ def _add_auth_header(self, request: httpx.Request) -> None:
559471
if self.context.current_tokens and self.context.current_tokens.access_token:
560472
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
561473

562-
def _create_oauth_discovery_stack(self) -> OAuthDiscoveryStack:
563-
"""Create a stack of attempts to discover OAuth metadata."""
564-
discovery_attempts: OAuthDiscoveryStack = [
565-
# Start with path-aware OAuth discovery
566-
self._discover_oauth_metadata,
567-
# If path-aware discovery fails with 404, try fallback to root
568-
self._discover_oauth_metadata_fallback,
569-
# If root discovery fails with 404, fall back to OIDC 1.0 following
570-
# RFC 8414 path-aware semantics (see RFC 8414 section 5)
571-
self._discover_oidc_metadata,
572-
# If path-aware OIDC discovery failed with 404, fall back to OIDC 1.0
573-
# following OIDC 1.0 semantics (see RFC 8414 section 5)
574-
self._discover_oidc_metadata_fallback,
575-
]
576-
577-
# Reverse the list so we can call pop() without remembering we declared
578-
# this stack backwards for readability
579-
discovery_attempts.reverse()
580-
return discovery_attempts
474+
def _create_oauth_metadata_request(self, url: str) -> httpx.Request:
475+
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
476+
477+
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
478+
content = await response.aread()
479+
metadata = OAuthMetadata.model_validate_json(content)
480+
self.context.oauth_metadata = metadata
481+
# Apply default scope if needed
482+
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
483+
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)
581484

582485
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
583486
"""HTTPX auth flow integration."""
@@ -612,12 +515,19 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
612515
await self._handle_protected_resource_response(discovery_response)
613516

614517
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
615-
oauth_discovery_stack = self._create_oauth_discovery_stack()
616-
while len(oauth_discovery_stack) > 0:
617-
oauth_discovery = oauth_discovery_stack.pop()
618-
oauth_request = await oauth_discovery()
619-
oauth_response = yield oauth_request
620-
await self._handle_oauth_metadata_response(oauth_response, oauth_discovery_stack)
518+
discovery_urls = self._get_discovery_urls()
519+
for url in discovery_urls:
520+
request = self._create_oauth_metadata_request(url)
521+
response = yield request
522+
523+
if response.status_code == 200:
524+
try:
525+
await self._handle_oauth_metadata_response(response)
526+
break
527+
except ValidationError:
528+
continue
529+
elif response.status_code != 404:
530+
break # Non-404 error, stop trying
621531

622532
# Step 3: Register client if needed
623533
registration_request = await self._register_client()

0 commit comments

Comments
 (0)