Skip to content

Commit c24d7a1

Browse files
fix(auth): avoid SSE OAuth refresh deadlock
1 parent 9773a3f commit c24d7a1

4 files changed

Lines changed: 463 additions & 0 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,29 @@ def _add_auth_header(self, request: httpx.Request) -> None:
483483
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
484484
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
485485

486+
async def _prepare_request_with_refresh(self, client: httpx.AsyncClient, request: httpx.Request) -> None:
487+
"""Refresh stored tokens and add an auth header for requests sent outside the auth flow."""
488+
async with self.context.lock:
489+
if not self._initialized:
490+
await self._initialize()
491+
492+
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
493+
494+
if self.context.is_token_valid():
495+
self._add_auth_header(request)
496+
return
497+
498+
if not self.context.can_refresh_token():
499+
return
500+
501+
refresh_request = await self._refresh_token()
502+
refresh_response = await client.send(refresh_request, auth=None)
503+
504+
if not await self._handle_refresh_response(refresh_response):
505+
return
506+
507+
self._add_auth_header(request)
508+
486509
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
487510
content = await response.aread()
488511
metadata = OAuthMetadata.model_validate_json(content)

src/mcp/client/sse.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from httpx_sse import SSEError, aconnect_sse
1212

1313
import mcp.types as types
14+
from mcp.client.auth import OAuthClientProvider
1415
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
1516
from mcp.shared.message import SessionMessage
1617

@@ -65,10 +66,19 @@ async def sse_client(
6566
async with httpx_client_factory(
6667
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
6768
) as client:
69+
sse_request_kwargs: dict[str, Any] = {}
70+
if isinstance(auth, OAuthClientProvider):
71+
sse_request = httpx.Request("GET", url, headers=headers)
72+
await auth._prepare_request_with_refresh(client, sse_request) # pyright: ignore[reportPrivateUsage]
73+
if "Authorization" in sse_request.headers:
74+
sse_request_kwargs["headers"] = dict(sse_request.headers)
75+
sse_request_kwargs["auth"] = None
76+
6877
async with aconnect_sse(
6978
client,
7079
"GET",
7180
url,
81+
**sse_request_kwargs,
7282
) as event_source:
7383
event_source.response.raise_for_status()
7484
logger.debug("SSE connection established")

tests/client/test_auth.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,168 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
631631
assert "client_id=test_client" in content
632632
assert "client_secret=test_secret" in content
633633

634+
@pytest.mark.anyio
635+
async def test_prepare_request_with_refresh_refreshes_expired_token(
636+
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken
637+
):
638+
"""Test preflight refresh for streaming requests that cannot drive OAuth inline."""
639+
640+
class FailingAuth(httpx.Auth):
641+
async def async_auth_flow(self, request: httpx.Request): # pragma: no cover
642+
raise AssertionError("preflight refresh should bypass client auth")
643+
yield request
644+
645+
oauth_provider.context.current_tokens = valid_tokens
646+
oauth_provider.context.token_expiry_time = time.time() - 1
647+
oauth_provider.context.client_info = OAuthClientInformationFull(
648+
client_id="test_client",
649+
client_secret="test_secret",
650+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
651+
token_endpoint_auth_method="client_secret_post",
652+
)
653+
oauth_provider._initialized = True
654+
655+
requests: list[httpx.Request] = []
656+
657+
async def handler(request: httpx.Request) -> httpx.Response:
658+
requests.append(request)
659+
return httpx.Response(
660+
200,
661+
json={
662+
"access_token": "refreshed_access_token",
663+
"token_type": "Bearer",
664+
"expires_in": 3600,
665+
"refresh_token": "refreshed_refresh_token",
666+
},
667+
request=request,
668+
)
669+
670+
request = httpx.Request(
671+
"GET",
672+
"https://api.example.com/v1/mcp/sse",
673+
headers={"mcp-protocol-version": "2025-06-18"},
674+
)
675+
676+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler), auth=FailingAuth()) as client:
677+
await oauth_provider._prepare_request_with_refresh(client, request) # type: ignore[reportPrivateUsage]
678+
679+
assert len(requests) == 1
680+
assert requests[0].method == "POST"
681+
assert str(requests[0].url) == "https://api.example.com/token"
682+
assert "grant_type=refresh_token" in requests[0].content.decode()
683+
assert "resource=" in requests[0].content.decode()
684+
assert request.headers["Authorization"] == "Bearer refreshed_access_token"
685+
assert oauth_provider.context.current_tokens is not None
686+
assert oauth_provider.context.current_tokens.access_token == "refreshed_access_token"
687+
assert mock_storage._tokens is not None
688+
assert mock_storage._tokens.access_token == "refreshed_access_token"
689+
690+
@pytest.mark.anyio
691+
async def test_prepare_request_with_refresh_skips_valid_token(
692+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
693+
):
694+
"""Test preflight refresh is a no-op while the current token is still valid."""
695+
oauth_provider.context.current_tokens = valid_tokens
696+
oauth_provider.context.token_expiry_time = time.time() + 1800
697+
oauth_provider.context.client_info = OAuthClientInformationFull(
698+
client_id="test_client",
699+
client_secret="test_secret",
700+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
701+
)
702+
oauth_provider._initialized = True
703+
704+
requests: list[httpx.Request] = []
705+
706+
async def handler(request: httpx.Request) -> httpx.Response: # pragma: no cover
707+
requests.append(request)
708+
return httpx.Response(500, request=request)
709+
710+
request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")
711+
712+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
713+
await oauth_provider._prepare_request_with_refresh(client, request) # type: ignore[reportPrivateUsage]
714+
715+
assert requests == []
716+
assert request.headers["Authorization"] == "Bearer test_access_token"
717+
assert oauth_provider.context.current_tokens is valid_tokens
718+
719+
@pytest.mark.anyio
720+
async def test_prepare_request_with_refresh_initializes_storage(
721+
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken
722+
):
723+
"""Test preflight refresh loads persisted OAuth state before preparing the request."""
724+
client_info = OAuthClientInformationFull(
725+
client_id="test_client",
726+
client_secret="test_secret",
727+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
728+
)
729+
await mock_storage.set_tokens(valid_tokens)
730+
await mock_storage.set_client_info(client_info)
731+
732+
request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")
733+
734+
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda request: httpx.Response(500))) as client:
735+
await oauth_provider._prepare_request_with_refresh(client, request) # type: ignore[reportPrivateUsage]
736+
737+
assert request.headers["Authorization"] == "Bearer test_access_token"
738+
assert oauth_provider.context.current_tokens is valid_tokens
739+
assert oauth_provider.context.client_info is client_info
740+
741+
@pytest.mark.anyio
742+
async def test_prepare_request_with_refresh_skips_without_refresh_token(self, oauth_provider: OAuthClientProvider):
743+
"""Test preflight refresh leaves the request alone when refresh is not possible."""
744+
oauth_provider.context.current_tokens = OAuthToken(
745+
access_token="expired_access_token",
746+
refresh_token=None,
747+
expires_in=1,
748+
)
749+
oauth_provider.context.token_expiry_time = time.time() - 1
750+
oauth_provider.context.client_info = OAuthClientInformationFull(
751+
client_id="test_client",
752+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
753+
)
754+
oauth_provider._initialized = True
755+
756+
requests: list[httpx.Request] = []
757+
758+
async def handler(request: httpx.Request) -> httpx.Response: # pragma: no cover
759+
requests.append(request)
760+
return httpx.Response(500, request=request)
761+
762+
request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")
763+
764+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
765+
await oauth_provider._prepare_request_with_refresh(client, request) # type: ignore[reportPrivateUsage]
766+
767+
assert requests == []
768+
assert "Authorization" not in request.headers
769+
770+
@pytest.mark.anyio
771+
async def test_prepare_request_with_refresh_keeps_request_unauthenticated_after_refresh_failure(
772+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
773+
):
774+
"""Test failed preflight refresh does not add a stale bearer header."""
775+
oauth_provider.context.current_tokens = valid_tokens
776+
oauth_provider.context.token_expiry_time = time.time() - 1
777+
oauth_provider.context.client_info = OAuthClientInformationFull(
778+
client_id="test_client",
779+
client_secret="test_secret",
780+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
781+
token_endpoint_auth_method="client_secret_post",
782+
)
783+
oauth_provider._initialized = True
784+
785+
async def handler(request: httpx.Request) -> httpx.Response:
786+
return httpx.Response(400, request=request)
787+
788+
request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")
789+
790+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
791+
await oauth_provider._prepare_request_with_refresh(client, request) # type: ignore[reportPrivateUsage]
792+
793+
assert "Authorization" not in request.headers
794+
assert oauth_provider.context.current_tokens is None
795+
634796
@pytest.mark.anyio
635797
async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider):
636798
"""Test token exchange with client_secret_basic authentication."""

0 commit comments

Comments
 (0)