Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 65 additions & 18 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,18 @@ class OAuthContext:
token_expiry_time: float | None = None

# State
#
# `lock` guards short-lived reads/writes of provider state (initialization
# flag, token cache mutation, protocol_version assignment). It is held only
# while mutating state and is released before any HTTP request is yielded
# so a long-running request (e.g. GET SSE long-poll) does not block
# unrelated concurrent requests.
#
# `refresh_lock` provides single-flight semantics for token refresh: only
# one concurrent refresh fires; other waiters block on this lock, then
# re-check the token cache and proceed without re-refreshing.
lock: anyio.Lock = field(default_factory=anyio.Lock)
refresh_lock: anyio.Lock = field(default_factory=anyio.Lock)

def get_authorization_base_url(self, server_url: str) -> str:
"""Extract base URL by removing path component."""
Expand Down Expand Up @@ -452,7 +463,7 @@ async def _refresh_token(self) -> httpx.Request:

return httpx.Request("POST", token_url, data=refresh_data, headers=headers)

async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover
async def _handle_refresh_response(self, response: httpx.Response) -> bool:
"""Handle token refresh response. Returns True if successful."""
if response.status_code != 200:
logger.warning(f"Token refresh failed: {response.status_code}")
Expand Down Expand Up @@ -504,29 +515,64 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
"""HTTPX auth flow integration.

Lock scope:
``self.context.lock`` is held only while reading/mutating provider
state. The actual HTTP request yield (which may be a long-poll GET
SSE stream) runs outside any lock so concurrent unrelated requests
are not blocked. ``self.context.refresh_lock`` provides
single-flight semantics for token refresh.
"""
# === Phase 1: state read + refresh decision (brief context.lock) ===
needs_refresh = False
async with self.context.lock:
if not self._initialized:
await self._initialize() # pragma: no cover

# Capture protocol version from request headers
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)

if not self.context.is_token_valid() and self.context.can_refresh_token():
# Try to refresh token
refresh_request = await self._refresh_token() # pragma: no cover
refresh_response = yield refresh_request # pragma: no cover

if not await self._handle_refresh_response(refresh_response): # pragma: no cover
# Refresh failed, need full re-authentication
self._initialized = False

if self.context.is_token_valid():
self._add_auth_header(request)

response = yield request

if response.status_code == 401:
# pragma: no branch — coverage.py on Python 3.10/3.11 (sys.settrace
# backend) cannot reliably track both arms of compound boolean
# predicates inside an ``async with`` block in an async generator.
# Python 3.12+ (sys.monitoring) handles this correctly; the pragmas
# below are workarounds for the legacy backend only.
if not self.context.is_token_valid() and self.context.can_refresh_token(): # pragma: no branch
needs_refresh = True

# === Phase 2: single-flight token refresh (yield outside context.lock) ===
if needs_refresh:
async with self.context.refresh_lock:
# Re-check under context.lock: another coroutine may already have
# refreshed while we were waiting on refresh_lock.
refresh_request: httpx.Request | None = None
async with self.context.lock:
if not self.context.is_token_valid() and self.context.can_refresh_token(): # pragma: no branch
refresh_request = await self._refresh_token()
if refresh_request is not None: # pragma: no branch
# yield runs outside any lock so a long network round trip
# does not block unrelated concurrent requests.
refresh_response = yield refresh_request
async with self.context.lock:
if not await self._handle_refresh_response(refresh_response): # pragma: no branch
# Refresh failed; fall through to 401 handling below.
self._initialized = False

# === Phase 3: send request (no lock; safe for long-poll GET SSE) ===
if self.context.is_token_valid():
self._add_auth_header(request)

response = yield request

# === Phase 4: 401 / 403 full OAuth flow ===
# NOTE: Phase 4 yields multiple sub-requests (discovery, registration,
# token exchange) under context.lock. This is the existing behavior and
# is acceptable because the 401 path is exceptional and not concurrent
# with steady-state traffic. A future refactor could narrow the lock
# here in the same pattern as Phase 1-2.
if response.status_code == 401:
async with self.context.lock:
# Perform full OAuth flow
try:
# OAuth flow must be inline due to generator constraints
Expand Down Expand Up @@ -619,7 +665,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Retry with new tokens
self._add_auth_header(request)
yield request
elif response.status_code == 403:
elif response.status_code == 403:
async with self.context.lock:
# Step 1: Extract error field from WWW-Authenticate header
error = extract_field_from_www_auth(response, "error")

Expand Down
Loading
Loading