@@ -862,6 +862,116 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth(
862862 # Verify exactly one request was yielded (no double-sending)
863863 assert request_yields == 1 , f"Expected 1 request yield, got { request_yields } "
864864
865+ @pytest .mark .anyio
866+ async def test_token_exchange_accepts_201_status (
867+ self , oauth_provider : OAuthClientProvider , mock_storage : MockTokenStorage
868+ ):
869+ """Test that token exchange accepts both 200 and 201 status codes."""
870+ # Ensure no tokens are stored
871+ oauth_provider .context .current_tokens = None
872+ oauth_provider .context .token_expiry_time = None
873+ oauth_provider ._initialized = True
874+
875+ # Create a test request
876+ test_request = httpx .Request ("GET" , "https://api.example.com/mcp" )
877+
878+ # Mock the auth flow
879+ auth_flow = oauth_provider .async_auth_flow (test_request )
880+
881+ # First request should be the original request without auth header
882+ request = await auth_flow .__anext__ ()
883+ assert "Authorization" not in request .headers
884+
885+ # Send a 401 response to trigger the OAuth flow
886+ response = httpx .Response (
887+ 401 ,
888+ headers = {
889+ "WWW-Authenticate" : 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
890+ },
891+ request = test_request ,
892+ )
893+
894+ # Next request should be to discover protected resource metadata
895+ discovery_request = await auth_flow .asend (response )
896+ assert discovery_request .method == "GET"
897+ assert str (discovery_request .url ) == "https://api.example.com/.well-known/oauth-protected-resource"
898+
899+ # Send a successful discovery response with minimal protected resource metadata
900+ discovery_response = httpx .Response (
901+ 200 ,
902+ content = b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}' ,
903+ request = discovery_request ,
904+ )
905+
906+ # Next request should be to discover OAuth metadata
907+ oauth_metadata_request = await auth_flow .asend (discovery_response )
908+ assert oauth_metadata_request .method == "GET"
909+ assert str (oauth_metadata_request .url ).startswith ("https://auth.example.com/" )
910+ assert "mcp-protocol-version" in oauth_metadata_request .headers
911+
912+ # Send a successful OAuth metadata response
913+ oauth_metadata_response = httpx .Response (
914+ 200 ,
915+ content = (
916+ b'{"issuer": "https://auth.example.com", '
917+ b'"authorization_endpoint": "https://auth.example.com/authorize", '
918+ b'"token_endpoint": "https://auth.example.com/token", '
919+ b'"registration_endpoint": "https://auth.example.com/register"}'
920+ ),
921+ request = oauth_metadata_request ,
922+ )
923+
924+ # Next request should be to register client
925+ registration_request = await auth_flow .asend (oauth_metadata_response )
926+ assert registration_request .method == "POST"
927+ assert str (registration_request .url ) == "https://auth.example.com/register"
928+
929+ # Send a successful registration response with 201 status
930+ registration_response = httpx .Response (
931+ 201 ,
932+ content = b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}' ,
933+ request = registration_request ,
934+ )
935+
936+ # Mock the authorization process
937+ oauth_provider ._perform_authorization_code_grant = mock .AsyncMock (
938+ return_value = ("test_auth_code" , "test_code_verifier" )
939+ )
940+
941+ # Next request should be to exchange token
942+ token_request = await auth_flow .asend (registration_response )
943+ assert token_request .method == "POST"
944+ assert str (token_request .url ) == "https://auth.example.com/token"
945+ assert "code=test_auth_code" in token_request .content .decode ()
946+
947+ # Send a successful token response with 201 status code (test both 200 and 201 are accepted)
948+ token_response = httpx .Response (
949+ 201 ,
950+ content = (
951+ b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
952+ b'"refresh_token": "new_refresh_token"}'
953+ ),
954+ request = token_request ,
955+ )
956+
957+ # Final request should be the original request with auth header
958+ final_request = await auth_flow .asend (token_response )
959+ assert final_request .headers ["Authorization" ] == "Bearer new_access_token"
960+ assert final_request .method == "GET"
961+ assert str (final_request .url ) == "https://api.example.com/mcp"
962+
963+ # Send final success response to properly close the generator
964+ final_response = httpx .Response (200 , request = final_request )
965+ try :
966+ await auth_flow .asend (final_response )
967+ except StopAsyncIteration :
968+ pass # Expected - generator should complete
969+
970+ # Verify tokens were stored
971+ assert oauth_provider .context .current_tokens is not None
972+ assert oauth_provider .context .current_tokens .access_token == "new_access_token"
973+ assert oauth_provider .context .token_expiry_time is not None
974+
865975 @pytest .mark .anyio
866976 async def test_403_insufficient_scope_updates_scope_from_header (
867977 self ,
0 commit comments