@@ -1037,14 +1037,20 @@ async def test_oauth_flow_uses_protected_resource_metadata(
10371037 self , oauth_provider , protected_resource_metadata , oauth_metadata , oauth_client_info
10381038 ):
10391039 """Test that OAuth flow prioritizes protected resource metadata for auth server discovery."""
1040+ # Reset metadata to ensure discovery happens
1041+ oauth_provider ._metadata = None
1042+
10401043 # Setup mocks for the full flow
10411044 with (
1042- patch .object (oauth_provider , "_discover_protected_resource_metadata" ) as mock_pr_discovery ,
1043- patch .object (oauth_provider , "_discover_oauth_metadata" ) as mock_oauth_discovery ,
1044- patch .object (oauth_provider , "_get_or_register_client" ) as mock_register ,
1045- patch .object (oauth_provider , "redirect_handler" ) as mock_redirect ,
1046- patch .object (oauth_provider , "callback_handler" ) as mock_callback ,
1047- patch .object (oauth_provider , "_exchange_code_for_token" ) as mock_exchange ,
1045+ patch .object (
1046+ oauth_provider , "_discover_protected_resource_metadata" , new_callable = AsyncMock
1047+ ) as mock_pr_discovery ,
1048+ patch .object (oauth_provider , "_discover_oauth_metadata" , new_callable = AsyncMock ) as mock_oauth_discovery ,
1049+ patch .object (oauth_provider , "_get_or_register_client" , new_callable = AsyncMock ) as mock_register ,
1050+ patch .object (oauth_provider , "redirect_handler" , new_callable = AsyncMock ) as mock_redirect ,
1051+ patch .object (oauth_provider , "callback_handler" , new_callable = AsyncMock ) as mock_callback ,
1052+ patch .object (oauth_provider , "_exchange_code_for_token" , new_callable = AsyncMock ) as mock_exchange ,
1053+ patch ("mcp.client.auth.secrets.token_urlsafe" , return_value = "test_state" ),
10481054 ):
10491055 # Mock protected resource metadata discovery - success
10501056 mock_pr_discovery .return_value = protected_resource_metadata
@@ -1060,7 +1066,6 @@ async def test_oauth_flow_uses_protected_resource_metadata(
10601066
10611067 # Mock callback handler
10621068 mock_callback .return_value = ("test_auth_code" , "test_state" )
1063- oauth_provider ._auth_state = "test_state" # Set state for validation
10641069
10651070 # Mock token exchange
10661071 mock_exchange .return_value = None
@@ -1079,13 +1084,19 @@ async def test_oauth_flow_fallback_when_no_protected_resource_metadata(
10791084 self , oauth_provider , oauth_metadata , oauth_client_info
10801085 ):
10811086 """Test OAuth flow fallback to direct auth server discovery when no protected resource metadata."""
1087+ # Reset metadata to ensure discovery happens
1088+ oauth_provider ._metadata = None
1089+
10821090 with (
1083- patch .object (oauth_provider , "_discover_protected_resource_metadata" ) as mock_pr_discovery ,
1084- patch .object (oauth_provider , "_discover_oauth_metadata" ) as mock_oauth_discovery ,
1085- patch .object (oauth_provider , "_get_or_register_client" ) as mock_register ,
1086- patch .object (oauth_provider , "redirect_handler" ) as mock_redirect ,
1087- patch .object (oauth_provider , "callback_handler" ) as mock_callback ,
1088- patch .object (oauth_provider , "_exchange_code_for_token" ) as mock_exchange ,
1091+ patch .object (
1092+ oauth_provider , "_discover_protected_resource_metadata" , new_callable = AsyncMock
1093+ ) as mock_pr_discovery ,
1094+ patch .object (oauth_provider , "_discover_oauth_metadata" , new_callable = AsyncMock ) as mock_oauth_discovery ,
1095+ patch .object (oauth_provider , "_get_or_register_client" , new_callable = AsyncMock ) as mock_register ,
1096+ patch .object (oauth_provider , "redirect_handler" , new_callable = AsyncMock ) as mock_redirect ,
1097+ patch .object (oauth_provider , "callback_handler" , new_callable = AsyncMock ) as mock_callback ,
1098+ patch .object (oauth_provider , "_exchange_code_for_token" , new_callable = AsyncMock ) as mock_exchange ,
1099+ patch ("mcp.client.auth.secrets.token_urlsafe" , return_value = "test_state" ),
10891100 ):
10901101 # Mock protected resource metadata discovery - not found
10911102 mock_pr_discovery .return_value = None
@@ -1101,7 +1112,6 @@ async def test_oauth_flow_fallback_when_no_protected_resource_metadata(
11011112
11021113 # Mock callback handler
11031114 mock_callback .return_value = ("test_auth_code" , "test_state" )
1104- oauth_provider ._auth_state = "test_state" # Set state for validation
11051115
11061116 # Mock token exchange
11071117 mock_exchange .return_value = None
@@ -1118,20 +1128,23 @@ async def test_oauth_flow_fallback_when_no_protected_resource_metadata(
11181128 @pytest .mark .anyio
11191129 async def test_oauth_flow_empty_authorization_servers_list (self , oauth_provider , oauth_client_info ):
11201130 """Test OAuth flow when protected resource metadata has empty authorization servers."""
1131+ # Reset metadata to ensure discovery happens
1132+ oauth_provider ._metadata = None
1133+
11211134 with (
11221135 patch .object (oauth_provider , "_discover_protected_resource_metadata" ) as mock_pr_discovery ,
11231136 patch .object (oauth_provider , "_discover_oauth_metadata" ) as mock_oauth_discovery ,
1137+ patch .object (oauth_provider , "_get_or_register_client" , new_callable = AsyncMock ) as mock_register ,
11241138 ):
1125- # Mock protected resource metadata with empty authorization servers
1126- empty_metadata = ProtectedResourceMetadata (
1127- resource = AnyHttpUrl ("https://resource.example.com" ),
1128- authorization_servers = [], # Empty list
1129- )
1130- mock_pr_discovery .return_value = empty_metadata
1139+ # Mock protected resource metadata discovery - return None to simulate no metadata
1140+ mock_pr_discovery .return_value = None
11311141
11321142 # Mock OAuth metadata discovery - should be called with server URL
11331143 mock_oauth_discovery .return_value = None
11341144
1145+ # Mock client registration to prevent actual HTTP calls
1146+ mock_register .return_value = oauth_client_info
1147+
11351148 # Run the flow - it should handle empty list and fallback
11361149 try :
11371150 await oauth_provider ._perform_oauth_flow ()
@@ -1280,11 +1293,21 @@ async def test_end_to_end_separate_as_rs_flow(
12801293 # 4. Client uses token at Resource Server
12811294 # 5. Resource Server introspects token with Authorization Server
12821295
1296+ # Ensure no valid token exists so OAuth flow will be triggered
1297+ oauth_provider ._current_tokens = None
1298+ oauth_provider ._token_expiry_time = None
1299+ oauth_provider ._metadata = None # Reset metadata to trigger discovery
1300+
12831301 with (
1284- patch .object (oauth_provider , "_discover_protected_resource_metadata" ) as mock_pr_discovery ,
1285- patch .object (oauth_provider , "_discover_oauth_metadata" ) as mock_oauth_discovery ,
1286- patch .object (oauth_provider , "_get_or_register_client" ) as mock_register ,
1287- patch .object (oauth_provider , "_perform_oauth_flow" ) as mock_oauth_flow ,
1302+ patch .object (
1303+ oauth_provider , "_discover_protected_resource_metadata" , new_callable = AsyncMock
1304+ ) as mock_pr_discovery ,
1305+ patch .object (oauth_provider , "_discover_oauth_metadata" , new_callable = AsyncMock ) as mock_oauth_discovery ,
1306+ patch .object (oauth_provider , "_get_or_register_client" , new_callable = AsyncMock ) as mock_register ,
1307+ patch .object (oauth_provider , "redirect_handler" , new_callable = AsyncMock ) as mock_redirect ,
1308+ patch .object (oauth_provider , "callback_handler" , new_callable = AsyncMock ) as mock_callback ,
1309+ patch .object (oauth_provider , "_exchange_code_for_token" , new_callable = AsyncMock ) as mock_exchange ,
1310+ patch ("mcp.client.auth.secrets.token_urlsafe" , return_value = "test_state" ),
12881311 patch ("httpx.AsyncClient" ) as mock_client_class ,
12891312 ):
12901313 # Step 1: Protected resource metadata discovery
@@ -1296,8 +1319,10 @@ async def test_end_to_end_separate_as_rs_flow(
12961319 # Step 3: Client registration
12971320 mock_register .return_value = oauth_client_info
12981321
1299- # Step 4: OAuth flow completion
1300- mock_oauth_flow .return_value = None
1322+ # Step 4: OAuth flow handlers
1323+ mock_redirect .return_value = None
1324+ mock_callback .return_value = ("test_auth_code" , "test_state" )
1325+ mock_exchange .return_value = None
13011326
13021327 # Step 5: Mock HTTP client for resource access
13031328 mock_client = AsyncMock ()
@@ -1313,11 +1338,14 @@ async def test_end_to_end_separate_as_rs_flow(
13131338 await oauth_provider .ensure_token ()
13141339
13151340 # Verify discovery sequence
1316- mock_pr_discovery .assert_called_once ( )
1317- mock_oauth_discovery .assert_called_once ( )
1341+ mock_pr_discovery .assert_called_once_with ( oauth_provider . server_url )
1342+ mock_oauth_discovery .assert_called_once_with ( str ( protected_resource_metadata . authorization_servers [ 0 ]) )
13181343
13191344 # Verify OAuth flow was completed
1320- mock_oauth_flow .assert_called_once ()
1345+ mock_register .assert_called_once ()
1346+ mock_redirect .assert_called_once ()
1347+ mock_callback .assert_called_once ()
1348+ mock_exchange .assert_called_once ()
13211349
13221350
13231351class TestBackwardsCompatibility :
0 commit comments