Skip to content

Commit 54ea6a5

Browse files
committed
test: add regression coverage for OAuthClientProvider concurrent requests
Two tests in a new TestConcurrentRequestsDoNotDeadlock class exercise the behavior the previous commit fixes: 1. ``test_concurrent_request_not_blocked_by_pending_long_running_request`` drives one async_auth_flow generator to its yield (= simulating a GET SSE long-poll suspended waiting for the next event) and then opens a second concurrent flow on the same provider. The second flow must reach its own yield within a short timeout — i.e., the lock release between Phase 1 and Phase 3 lets it through. Pre-fix, the second generator would block on context.lock indefinitely. 2. ``test_concurrent_token_refresh_is_single_flight`` exercises the refresh_lock single-flight path. A first flow performs the refresh yield; a second flow started after the refresh completes observes the freshly-updated token in Phase 1 and proceeds directly to its own request yield without issuing a second refresh. Also: tighten the refresh_request unbound-after-conditional-write pattern in async_auth_flow so pyright recognizes it as definitely assigned at the yield site (was: derived from a boolean predicate; now: typed as ``httpx.Request | None`` and checked explicitly).
1 parent 15dbfd5 commit 54ea6a5

2 files changed

Lines changed: 118 additions & 6 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -541,14 +541,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
541541
async with self.context.refresh_lock:
542542
# Re-check under context.lock: another coroutine may already have
543543
# refreshed while we were waiting on refresh_lock.
544+
refresh_request: httpx.Request | None = None
544545
async with self.context.lock:
545-
still_invalid = (
546-
not self.context.is_token_valid()
547-
and self.context.can_refresh_token()
548-
)
549-
if still_invalid:
546+
if not self.context.is_token_valid() and self.context.can_refresh_token():
550547
refresh_request = await self._refresh_token() # pragma: no cover
551-
if still_invalid:
548+
if refresh_request is not None:
552549
# yield runs outside any lock so a long network round trip
553550
# does not block unrelated concurrent requests.
554551
refresh_response = yield refresh_request # pragma: no cover

tests/client/test_auth.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Tests for refactored OAuth client authentication implementation."""
22

33
import base64
4+
import contextlib
45
import time
56
from unittest import mock
67
from urllib.parse import parse_qs, quote, unquote, urlparse
78

9+
import anyio
810
import httpx
911
import pytest
1012
from inline_snapshot import Is, snapshot
@@ -2618,3 +2620,116 @@ async def callback_handler() -> tuple[str, str | None]:
26182620
await auth_flow.asend(final_response)
26192621
except StopAsyncIteration:
26202622
pass
2623+
2624+
2625+
class TestConcurrentRequestsDoNotDeadlock:
2626+
"""Regression tests for #1326.
2627+
2628+
Ensures that ``OAuthClientProvider.async_auth_flow`` does not serialize
2629+
concurrent unrelated requests behind a long-running one (e.g. GET SSE
2630+
long-poll). The fix narrows ``context.lock`` to state mutation only; the
2631+
actual ``yield request`` runs outside any lock.
2632+
"""
2633+
2634+
@pytest.mark.anyio
2635+
async def test_concurrent_request_not_blocked_by_pending_long_running_request(
2636+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2637+
):
2638+
"""A second request must reach its yield while the first is still
2639+
suspended at its yield (= simulating a server-side long-poll).
2640+
2641+
Before this fix, ``async_auth_flow`` held ``context.lock`` across
2642+
``yield request``. A GET SSE long-poll would therefore hold the lock
2643+
for the entire SSE lifetime, blocking any concurrent request waiting
2644+
on the same provider's lock and producing a multi-second stall.
2645+
"""
2646+
# Set up valid tokens so neither refresh (Phase 2) nor full OAuth
2647+
# flow (Phase 4) is triggered — we want to exercise the steady-state
2648+
# Phase 3 yield path that previously held the lock.
2649+
oauth_provider.context.current_tokens = valid_tokens
2650+
oauth_provider.context.token_expiry_time = time.time() + 1800
2651+
oauth_provider.context.client_info = OAuthClientInformationFull(
2652+
client_id="test_client_id",
2653+
client_secret="test_client_secret",
2654+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2655+
)
2656+
oauth_provider._initialized = True
2657+
2658+
# Flow 1: simulate a slow request. Drive it to its yield, then
2659+
# deliberately do not send a response — it stays suspended at the
2660+
# yield, just like a GET SSE long-poll waiting for the next event.
2661+
slow_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
2662+
slow_flow = oauth_provider.async_auth_flow(slow_request)
2663+
yielded_slow = await slow_flow.__anext__()
2664+
assert yielded_slow.headers.get("Authorization") == "Bearer test_access_token"
2665+
2666+
# Flow 2: a concurrent request on the same provider. With the fix,
2667+
# context.lock is not held during Flow 1's yield, so Flow 2 reaches
2668+
# its yield almost immediately. Without the fix, this would block
2669+
# until Flow 1 receives a response — i.e., it would hit the timeout.
2670+
fast_request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2671+
fast_flow = oauth_provider.async_auth_flow(fast_request)
2672+
with anyio.fail_after(1.0):
2673+
yielded_fast = await fast_flow.__anext__()
2674+
assert yielded_fast.headers.get("Authorization") == "Bearer test_access_token"
2675+
2676+
# Clean up both generators in deterministic order.
2677+
with contextlib.suppress(StopAsyncIteration):
2678+
await fast_flow.asend(httpx.Response(200, request=yielded_fast))
2679+
with contextlib.suppress(StopAsyncIteration):
2680+
await slow_flow.asend(httpx.Response(200, request=yielded_slow))
2681+
2682+
@pytest.mark.anyio
2683+
async def test_concurrent_token_refresh_is_single_flight(
2684+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2685+
):
2686+
"""When concurrent requests both observe an expired token, only one
2687+
refresh request is sent: ``refresh_lock`` provides single-flight
2688+
semantics so the second waiter re-checks state and proceeds without
2689+
re-triggering refresh.
2690+
"""
2691+
# Mark the token as expired so the next auth_flow run enters Phase 2.
2692+
oauth_provider.context.current_tokens = valid_tokens
2693+
oauth_provider.context.token_expiry_time = time.time() - 100 # expired
2694+
oauth_provider.context.client_info = OAuthClientInformationFull(
2695+
client_id="test_client_id",
2696+
client_secret="test_client_secret",
2697+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2698+
)
2699+
oauth_provider._initialized = True
2700+
2701+
# Flow A: drive it to the refresh yield and suspend there.
2702+
request_a = httpx.Request("GET", "https://api.example.com/v1/mcp")
2703+
flow_a = oauth_provider.async_auth_flow(request_a)
2704+
refresh_a = await flow_a.__anext__()
2705+
assert "grant_type=refresh_token" in refresh_a.read().decode()
2706+
2707+
# Complete Flow A's refresh with a fresh token.
2708+
refresh_response = httpx.Response(
2709+
200,
2710+
content=(
2711+
b'{"access_token": "new_access_token", "token_type": "Bearer", '
2712+
b'"expires_in": 3600, "refresh_token": "new_refresh_token"}'
2713+
),
2714+
request=refresh_a,
2715+
)
2716+
request_a_post = await flow_a.asend(refresh_response)
2717+
assert request_a_post.headers.get("Authorization") == "Bearer new_access_token"
2718+
2719+
# Flow B starts after Flow A's refresh has completed. Because token
2720+
# state was updated under context.lock, Flow B observes the fresh
2721+
# token in Phase 1, skips Phase 2 entirely, and reaches its yield
2722+
# directly. No second refresh request is sent.
2723+
request_b = httpx.Request("POST", "https://api.example.com/v1/mcp")
2724+
flow_b = oauth_provider.async_auth_flow(request_b)
2725+
with anyio.fail_after(1.0):
2726+
request_b_yielded = await flow_b.__anext__()
2727+
assert request_b_yielded.headers.get("Authorization") == "Bearer new_access_token"
2728+
# Confirm Flow B yielded the original POST, not a refresh request.
2729+
assert request_b_yielded.method == "POST"
2730+
2731+
# Clean up.
2732+
with contextlib.suppress(StopAsyncIteration):
2733+
await flow_b.asend(httpx.Response(200, request=request_b_yielded))
2734+
with contextlib.suppress(StopAsyncIteration):
2735+
await flow_a.asend(httpx.Response(200, request=request_a_post))

0 commit comments

Comments
 (0)