Skip to content

Commit 15dbfd5

Browse files
committed
fix(oauth): narrow async_auth_flow lock scope to avoid blocking long-poll requests
Previously the entire `OAuthClientProvider.async_auth_flow` body ran under `self.context.lock`, including the `yield request` that hands the request off to httpx. For requests that complete quickly this is fine, but a GET SSE long-poll holds the lock for the full SSE lifetime — which means any concurrent POST (e.g. `tools/call`) is blocked waiting for the lock, producing a ~16s perceived stall on lazy MCP connections that use OAuth. This commit splits the single coarse lock into purpose-specific scopes: Phase 1 (context.lock): initialize state, capture protocol_version, and decide whether a refresh is needed. Short-held; no HTTP I/O. Phase 2 (refresh_lock, new): single-flight token refresh. The refresh request `yield` happens outside any lock. A double-check inside `context.lock` ensures concurrent waiters do not redundantly refresh after another coroutine completed one. Phase 3 (no lock): add the auth header and yield the actual request. GET SSE long-polls and other long-running requests no longer block unrelated traffic. Phase 4 (context.lock): 401 / 403 full OAuth re-auth path. Conservatively kept under lock because this path is rare and its yielded sub-requests (metadata discovery, registration, token exchange) hit the AS, not the resource server. A future refactor can narrow this further. Lock additions: - `OAuthContext.refresh_lock: anyio.Lock` provides single-flight refresh so concurrent requests trigger at most one token refresh. Behavior changes: - Concurrent requests through the same `OAuthClientProvider` no longer serialize at the lock. GET SSE long-polls and POSTs now proceed in parallel. - Token refresh remains serialized (via `refresh_lock`), preserving the invariant that only one refresh request is in flight at a time. - Public API and behavior are otherwise unchanged. Related upstream issue: #1326
1 parent f475344 commit 15dbfd5

1 file changed

Lines changed: 61 additions & 16 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,18 @@ class OAuthContext:
114114
token_expiry_time: float | None = None
115115

116116
# State
117+
#
118+
# `lock` guards short-lived reads/writes of provider state (initialization
119+
# flag, token cache mutation, protocol_version assignment). It is held only
120+
# while mutating state and is released before any HTTP request is yielded
121+
# so a long-running request (e.g. GET SSE long-poll) does not block
122+
# unrelated concurrent requests.
123+
#
124+
# `refresh_lock` provides single-flight semantics for token refresh: only
125+
# one concurrent refresh fires; other waiters block on this lock, then
126+
# re-check the token cache and proceed without re-refreshing.
117127
lock: anyio.Lock = field(default_factory=anyio.Lock)
128+
refresh_lock: anyio.Lock = field(default_factory=anyio.Lock)
118129

119130
def get_authorization_base_url(self, server_url: str) -> str:
120131
"""Extract base URL by removing path component."""
@@ -504,7 +515,17 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
504515
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")
505516

506517
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
507-
"""HTTPX auth flow integration."""
518+
"""HTTPX auth flow integration.
519+
520+
Lock scope:
521+
``self.context.lock`` is held only while reading/mutating provider
522+
state. The actual HTTP request yield (which may be a long-poll GET
523+
SSE stream) runs outside any lock so concurrent unrelated requests
524+
are not blocked. ``self.context.refresh_lock`` provides
525+
single-flight semantics for token refresh.
526+
"""
527+
# === Phase 1: state read + refresh decision (brief context.lock) ===
528+
needs_refresh = False
508529
async with self.context.lock:
509530
if not self._initialized:
510531
await self._initialize() # pragma: no cover
@@ -513,20 +534,43 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
513534
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
514535

515536
if not self.context.is_token_valid() and self.context.can_refresh_token():
516-
# Try to refresh token
517-
refresh_request = await self._refresh_token() # pragma: no cover
518-
refresh_response = yield refresh_request # pragma: no cover
519-
520-
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
521-
# Refresh failed, need full re-authentication
522-
self._initialized = False
523-
524-
if self.context.is_token_valid():
525-
self._add_auth_header(request)
526-
527-
response = yield request
528-
529-
if response.status_code == 401:
537+
needs_refresh = True
538+
539+
# === Phase 2: single-flight token refresh (yield outside context.lock) ===
540+
if needs_refresh:
541+
async with self.context.refresh_lock:
542+
# Re-check under context.lock: another coroutine may already have
543+
# refreshed while we were waiting on refresh_lock.
544+
async with self.context.lock:
545+
still_invalid = (
546+
not self.context.is_token_valid()
547+
and self.context.can_refresh_token()
548+
)
549+
if still_invalid:
550+
refresh_request = await self._refresh_token() # pragma: no cover
551+
if still_invalid:
552+
# yield runs outside any lock so a long network round trip
553+
# does not block unrelated concurrent requests.
554+
refresh_response = yield refresh_request # pragma: no cover
555+
async with self.context.lock:
556+
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
557+
# Refresh failed; fall through to 401 handling below.
558+
self._initialized = False
559+
560+
# === Phase 3: send request (no lock; safe for long-poll GET SSE) ===
561+
if self.context.is_token_valid():
562+
self._add_auth_header(request)
563+
564+
response = yield request
565+
566+
# === Phase 4: 401 / 403 full OAuth flow ===
567+
# NOTE: Phase 4 yields multiple sub-requests (discovery, registration,
568+
# token exchange) under context.lock. This is the existing behavior and
569+
# is acceptable because the 401 path is exceptional and not concurrent
570+
# with steady-state traffic. A future refactor could narrow the lock
571+
# here in the same pattern as Phase 1-2.
572+
if response.status_code == 401:
573+
async with self.context.lock:
530574
# Perform full OAuth flow
531575
try:
532576
# OAuth flow must be inline due to generator constraints
@@ -619,7 +663,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
619663
# Retry with new tokens
620664
self._add_auth_header(request)
621665
yield request
622-
elif response.status_code == 403:
666+
elif response.status_code == 403:
667+
async with self.context.lock:
623668
# Step 1: Extract error field from WWW-Authenticate header
624669
error = extract_field_from_www_auth(response, "error")
625670

0 commit comments

Comments
 (0)