@@ -631,6 +631,168 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
631631 assert "client_id=test_client" in content
632632 assert "client_secret=test_secret" in content
633633
634+ @pytest .mark .anyio
635+ async def test_prepare_request_with_refresh_refreshes_expired_token (
636+ self , oauth_provider : OAuthClientProvider , mock_storage : MockTokenStorage , valid_tokens : OAuthToken
637+ ):
638+ """Test preflight refresh for streaming requests that cannot drive OAuth inline."""
639+
640+ class FailingAuth (httpx .Auth ):
641+ async def async_auth_flow (self , request : httpx .Request ): # pragma: no cover
642+ raise AssertionError ("preflight refresh should bypass client auth" )
643+ yield request
644+
645+ oauth_provider .context .current_tokens = valid_tokens
646+ oauth_provider .context .token_expiry_time = time .time () - 1
647+ oauth_provider .context .client_info = OAuthClientInformationFull (
648+ client_id = "test_client" ,
649+ client_secret = "test_secret" ,
650+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
651+ token_endpoint_auth_method = "client_secret_post" ,
652+ )
653+ oauth_provider ._initialized = True
654+
655+ requests : list [httpx .Request ] = []
656+
657+ async def handler (request : httpx .Request ) -> httpx .Response :
658+ requests .append (request )
659+ return httpx .Response (
660+ 200 ,
661+ json = {
662+ "access_token" : "refreshed_access_token" ,
663+ "token_type" : "Bearer" ,
664+ "expires_in" : 3600 ,
665+ "refresh_token" : "refreshed_refresh_token" ,
666+ },
667+ request = request ,
668+ )
669+
670+ request = httpx .Request (
671+ "GET" ,
672+ "https://api.example.com/v1/mcp/sse" ,
673+ headers = {"mcp-protocol-version" : "2025-06-18" },
674+ )
675+
676+ async with httpx .AsyncClient (transport = httpx .MockTransport (handler ), auth = FailingAuth ()) as client :
677+ await oauth_provider ._prepare_request_with_refresh (client , request ) # type: ignore[reportPrivateUsage]
678+
679+ assert len (requests ) == 1
680+ assert requests [0 ].method == "POST"
681+ assert str (requests [0 ].url ) == "https://api.example.com/token"
682+ assert "grant_type=refresh_token" in requests [0 ].content .decode ()
683+ assert "resource=" in requests [0 ].content .decode ()
684+ assert request .headers ["Authorization" ] == "Bearer refreshed_access_token"
685+ assert oauth_provider .context .current_tokens is not None
686+ assert oauth_provider .context .current_tokens .access_token == "refreshed_access_token"
687+ assert mock_storage ._tokens is not None
688+ assert mock_storage ._tokens .access_token == "refreshed_access_token"
689+
690+ @pytest .mark .anyio
691+ async def test_prepare_request_with_refresh_skips_valid_token (
692+ self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
693+ ):
694+ """Test preflight refresh is a no-op while the current token is still valid."""
695+ oauth_provider .context .current_tokens = valid_tokens
696+ oauth_provider .context .token_expiry_time = time .time () + 1800
697+ oauth_provider .context .client_info = OAuthClientInformationFull (
698+ client_id = "test_client" ,
699+ client_secret = "test_secret" ,
700+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
701+ )
702+ oauth_provider ._initialized = True
703+
704+ requests : list [httpx .Request ] = []
705+
706+ async def handler (request : httpx .Request ) -> httpx .Response : # pragma: no cover
707+ requests .append (request )
708+ return httpx .Response (500 , request = request )
709+
710+ request = httpx .Request ("GET" , "https://api.example.com/v1/mcp/sse" )
711+
712+ async with httpx .AsyncClient (transport = httpx .MockTransport (handler )) as client :
713+ await oauth_provider ._prepare_request_with_refresh (client , request ) # type: ignore[reportPrivateUsage]
714+
715+ assert requests == []
716+ assert request .headers ["Authorization" ] == "Bearer test_access_token"
717+ assert oauth_provider .context .current_tokens is valid_tokens
718+
719+ @pytest .mark .anyio
720+ async def test_prepare_request_with_refresh_initializes_storage (
721+ self , oauth_provider : OAuthClientProvider , mock_storage : MockTokenStorage , valid_tokens : OAuthToken
722+ ):
723+ """Test preflight refresh loads persisted OAuth state before preparing the request."""
724+ client_info = OAuthClientInformationFull (
725+ client_id = "test_client" ,
726+ client_secret = "test_secret" ,
727+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
728+ )
729+ await mock_storage .set_tokens (valid_tokens )
730+ await mock_storage .set_client_info (client_info )
731+
732+ request = httpx .Request ("GET" , "https://api.example.com/v1/mcp/sse" )
733+
734+ async with httpx .AsyncClient (transport = httpx .MockTransport (lambda request : httpx .Response (500 ))) as client :
735+ await oauth_provider ._prepare_request_with_refresh (client , request ) # type: ignore[reportPrivateUsage]
736+
737+ assert request .headers ["Authorization" ] == "Bearer test_access_token"
738+ assert oauth_provider .context .current_tokens is valid_tokens
739+ assert oauth_provider .context .client_info is client_info
740+
741+ @pytest .mark .anyio
742+ async def test_prepare_request_with_refresh_skips_without_refresh_token (self , oauth_provider : OAuthClientProvider ):
743+ """Test preflight refresh leaves the request alone when refresh is not possible."""
744+ oauth_provider .context .current_tokens = OAuthToken (
745+ access_token = "expired_access_token" ,
746+ refresh_token = None ,
747+ expires_in = 1 ,
748+ )
749+ oauth_provider .context .token_expiry_time = time .time () - 1
750+ oauth_provider .context .client_info = OAuthClientInformationFull (
751+ client_id = "test_client" ,
752+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
753+ )
754+ oauth_provider ._initialized = True
755+
756+ requests : list [httpx .Request ] = []
757+
758+ async def handler (request : httpx .Request ) -> httpx .Response : # pragma: no cover
759+ requests .append (request )
760+ return httpx .Response (500 , request = request )
761+
762+ request = httpx .Request ("GET" , "https://api.example.com/v1/mcp/sse" )
763+
764+ async with httpx .AsyncClient (transport = httpx .MockTransport (handler )) as client :
765+ await oauth_provider ._prepare_request_with_refresh (client , request ) # type: ignore[reportPrivateUsage]
766+
767+ assert requests == []
768+ assert "Authorization" not in request .headers
769+
770+ @pytest .mark .anyio
771+ async def test_prepare_request_with_refresh_keeps_request_unauthenticated_after_refresh_failure (
772+ self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
773+ ):
774+ """Test failed preflight refresh does not add a stale bearer header."""
775+ oauth_provider .context .current_tokens = valid_tokens
776+ oauth_provider .context .token_expiry_time = time .time () - 1
777+ oauth_provider .context .client_info = OAuthClientInformationFull (
778+ client_id = "test_client" ,
779+ client_secret = "test_secret" ,
780+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
781+ token_endpoint_auth_method = "client_secret_post" ,
782+ )
783+ oauth_provider ._initialized = True
784+
785+ async def handler (request : httpx .Request ) -> httpx .Response :
786+ return httpx .Response (400 , request = request )
787+
788+ request = httpx .Request ("GET" , "https://api.example.com/v1/mcp/sse" )
789+
790+ async with httpx .AsyncClient (transport = httpx .MockTransport (handler )) as client :
791+ await oauth_provider ._prepare_request_with_refresh (client , request ) # type: ignore[reportPrivateUsage]
792+
793+ assert "Authorization" not in request .headers
794+ assert oauth_provider .context .current_tokens is None
795+
634796 @pytest .mark .anyio
635797 async def test_basic_auth_token_exchange (self , oauth_provider : OAuthClientProvider ):
636798 """Test token exchange with client_secret_basic authentication."""
0 commit comments