Skip to content

Commit beef439

Browse files
committed
fix tests
1 parent 395c3ac commit beef439

File tree

3 files changed

+94
-51
lines changed

3 files changed

+94
-51
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ async def sse_endpoint(request: Request) -> Response:
788788
# Add protected resource metadata endpoint if configured as RS
789789
if self.settings.auth and self.settings.auth.authorization_servers:
790790
from mcp.server.auth.routes import create_protected_resource_routes
791-
791+
792792
routes.extend(
793793
create_protected_resource_routes(
794794
resource_url=self.settings.auth.issuer_url,

tests/client/test_auth.py

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,14 +1037,20 @@ async def test_oauth_flow_uses_protected_resource_metadata(
10371037
self, oauth_provider, protected_resource_metadata, oauth_metadata, oauth_client_info
10381038
):
10391039
"""Test that OAuth flow prioritizes protected resource metadata for auth server discovery."""
1040+
# Reset metadata to ensure discovery happens
1041+
oauth_provider._metadata = None
1042+
10401043
# Setup mocks for the full flow
10411044
with (
1042-
patch.object(oauth_provider, "_discover_protected_resource_metadata") as mock_pr_discovery,
1043-
patch.object(oauth_provider, "_discover_oauth_metadata") as mock_oauth_discovery,
1044-
patch.object(oauth_provider, "_get_or_register_client") as mock_register,
1045-
patch.object(oauth_provider, "redirect_handler") as mock_redirect,
1046-
patch.object(oauth_provider, "callback_handler") as mock_callback,
1047-
patch.object(oauth_provider, "_exchange_code_for_token") as mock_exchange,
1045+
patch.object(
1046+
oauth_provider, "_discover_protected_resource_metadata", new_callable=AsyncMock
1047+
) as mock_pr_discovery,
1048+
patch.object(oauth_provider, "_discover_oauth_metadata", new_callable=AsyncMock) as mock_oauth_discovery,
1049+
patch.object(oauth_provider, "_get_or_register_client", new_callable=AsyncMock) as mock_register,
1050+
patch.object(oauth_provider, "redirect_handler", new_callable=AsyncMock) as mock_redirect,
1051+
patch.object(oauth_provider, "callback_handler", new_callable=AsyncMock) as mock_callback,
1052+
patch.object(oauth_provider, "_exchange_code_for_token", new_callable=AsyncMock) as mock_exchange,
1053+
patch("mcp.client.auth.secrets.token_urlsafe", return_value="test_state"),
10481054
):
10491055
# Mock protected resource metadata discovery - success
10501056
mock_pr_discovery.return_value = protected_resource_metadata
@@ -1060,7 +1066,6 @@ async def test_oauth_flow_uses_protected_resource_metadata(
10601066

10611067
# Mock callback handler
10621068
mock_callback.return_value = ("test_auth_code", "test_state")
1063-
oauth_provider._auth_state = "test_state" # Set state for validation
10641069

10651070
# Mock token exchange
10661071
mock_exchange.return_value = None
@@ -1079,13 +1084,19 @@ async def test_oauth_flow_fallback_when_no_protected_resource_metadata(
10791084
self, oauth_provider, oauth_metadata, oauth_client_info
10801085
):
10811086
"""Test OAuth flow fallback to direct auth server discovery when no protected resource metadata."""
1087+
# Reset metadata to ensure discovery happens
1088+
oauth_provider._metadata = None
1089+
10821090
with (
1083-
patch.object(oauth_provider, "_discover_protected_resource_metadata") as mock_pr_discovery,
1084-
patch.object(oauth_provider, "_discover_oauth_metadata") as mock_oauth_discovery,
1085-
patch.object(oauth_provider, "_get_or_register_client") as mock_register,
1086-
patch.object(oauth_provider, "redirect_handler") as mock_redirect,
1087-
patch.object(oauth_provider, "callback_handler") as mock_callback,
1088-
patch.object(oauth_provider, "_exchange_code_for_token") as mock_exchange,
1091+
patch.object(
1092+
oauth_provider, "_discover_protected_resource_metadata", new_callable=AsyncMock
1093+
) as mock_pr_discovery,
1094+
patch.object(oauth_provider, "_discover_oauth_metadata", new_callable=AsyncMock) as mock_oauth_discovery,
1095+
patch.object(oauth_provider, "_get_or_register_client", new_callable=AsyncMock) as mock_register,
1096+
patch.object(oauth_provider, "redirect_handler", new_callable=AsyncMock) as mock_redirect,
1097+
patch.object(oauth_provider, "callback_handler", new_callable=AsyncMock) as mock_callback,
1098+
patch.object(oauth_provider, "_exchange_code_for_token", new_callable=AsyncMock) as mock_exchange,
1099+
patch("mcp.client.auth.secrets.token_urlsafe", return_value="test_state"),
10891100
):
10901101
# Mock protected resource metadata discovery - not found
10911102
mock_pr_discovery.return_value = None
@@ -1101,7 +1112,6 @@ async def test_oauth_flow_fallback_when_no_protected_resource_metadata(
11011112

11021113
# Mock callback handler
11031114
mock_callback.return_value = ("test_auth_code", "test_state")
1104-
oauth_provider._auth_state = "test_state" # Set state for validation
11051115

11061116
# Mock token exchange
11071117
mock_exchange.return_value = None
@@ -1118,20 +1128,23 @@ async def test_oauth_flow_fallback_when_no_protected_resource_metadata(
11181128
@pytest.mark.anyio
11191129
async def test_oauth_flow_empty_authorization_servers_list(self, oauth_provider, oauth_client_info):
11201130
"""Test OAuth flow when protected resource metadata has empty authorization servers."""
1131+
# Reset metadata to ensure discovery happens
1132+
oauth_provider._metadata = None
1133+
11211134
with (
11221135
patch.object(oauth_provider, "_discover_protected_resource_metadata") as mock_pr_discovery,
11231136
patch.object(oauth_provider, "_discover_oauth_metadata") as mock_oauth_discovery,
1137+
patch.object(oauth_provider, "_get_or_register_client", new_callable=AsyncMock) as mock_register,
11241138
):
1125-
# Mock protected resource metadata with empty authorization servers
1126-
empty_metadata = ProtectedResourceMetadata(
1127-
resource=AnyHttpUrl("https://resource.example.com"),
1128-
authorization_servers=[], # Empty list
1129-
)
1130-
mock_pr_discovery.return_value = empty_metadata
1139+
# Mock protected resource metadata discovery - return None to simulate no metadata
1140+
mock_pr_discovery.return_value = None
11311141

11321142
# Mock OAuth metadata discovery - should be called with server URL
11331143
mock_oauth_discovery.return_value = None
11341144

1145+
# Mock client registration to prevent actual HTTP calls
1146+
mock_register.return_value = oauth_client_info
1147+
11351148
# Run the flow - it should handle empty list and fallback
11361149
try:
11371150
await oauth_provider._perform_oauth_flow()
@@ -1280,11 +1293,21 @@ async def test_end_to_end_separate_as_rs_flow(
12801293
# 4. Client uses token at Resource Server
12811294
# 5. Resource Server introspects token with Authorization Server
12821295

1296+
# Ensure no valid token exists so OAuth flow will be triggered
1297+
oauth_provider._current_tokens = None
1298+
oauth_provider._token_expiry_time = None
1299+
oauth_provider._metadata = None # Reset metadata to trigger discovery
1300+
12831301
with (
1284-
patch.object(oauth_provider, "_discover_protected_resource_metadata") as mock_pr_discovery,
1285-
patch.object(oauth_provider, "_discover_oauth_metadata") as mock_oauth_discovery,
1286-
patch.object(oauth_provider, "_get_or_register_client") as mock_register,
1287-
patch.object(oauth_provider, "_perform_oauth_flow") as mock_oauth_flow,
1302+
patch.object(
1303+
oauth_provider, "_discover_protected_resource_metadata", new_callable=AsyncMock
1304+
) as mock_pr_discovery,
1305+
patch.object(oauth_provider, "_discover_oauth_metadata", new_callable=AsyncMock) as mock_oauth_discovery,
1306+
patch.object(oauth_provider, "_get_or_register_client", new_callable=AsyncMock) as mock_register,
1307+
patch.object(oauth_provider, "redirect_handler", new_callable=AsyncMock) as mock_redirect,
1308+
patch.object(oauth_provider, "callback_handler", new_callable=AsyncMock) as mock_callback,
1309+
patch.object(oauth_provider, "_exchange_code_for_token", new_callable=AsyncMock) as mock_exchange,
1310+
patch("mcp.client.auth.secrets.token_urlsafe", return_value="test_state"),
12881311
patch("httpx.AsyncClient") as mock_client_class,
12891312
):
12901313
# Step 1: Protected resource metadata discovery
@@ -1296,8 +1319,10 @@ async def test_end_to_end_separate_as_rs_flow(
12961319
# Step 3: Client registration
12971320
mock_register.return_value = oauth_client_info
12981321

1299-
# Step 4: OAuth flow completion
1300-
mock_oauth_flow.return_value = None
1322+
# Step 4: OAuth flow handlers
1323+
mock_redirect.return_value = None
1324+
mock_callback.return_value = ("test_auth_code", "test_state")
1325+
mock_exchange.return_value = None
13011326

13021327
# Step 5: Mock HTTP client for resource access
13031328
mock_client = AsyncMock()
@@ -1313,11 +1338,14 @@ async def test_end_to_end_separate_as_rs_flow(
13131338
await oauth_provider.ensure_token()
13141339

13151340
# Verify discovery sequence
1316-
mock_pr_discovery.assert_called_once()
1317-
mock_oauth_discovery.assert_called_once()
1341+
mock_pr_discovery.assert_called_once_with(oauth_provider.server_url)
1342+
mock_oauth_discovery.assert_called_once_with(str(protected_resource_metadata.authorization_servers[0]))
13181343

13191344
# Verify OAuth flow was completed
1320-
mock_oauth_flow.assert_called_once()
1345+
mock_register.assert_called_once()
1346+
mock_redirect.assert_called_once()
1347+
mock_callback.assert_called_once()
1348+
mock_exchange.assert_called_once()
13211349

13221350

13231351
class TestBackwardsCompatibility:

tests/server/auth/middleware/test_bearer_auth.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pytest
99
from starlette.authentication import AuthCredentials
1010
from starlette.datastructures import Headers
11-
from starlette.exceptions import HTTPException
1211
from starlette.requests import Request
1312
from starlette.types import Message, Receive, Scope, Send
1413

@@ -288,14 +287,18 @@ async def test_no_user(self):
288287
async def receive() -> Message:
289288
return {"type": "http.request"}
290289

290+
sent_messages = []
291+
291292
async def send(message: Message) -> None:
292-
pass
293+
sent_messages.append(message)
293294

294-
with pytest.raises(HTTPException) as excinfo:
295-
await middleware(scope, receive, send)
295+
await middleware(scope, receive, send)
296296

297-
assert excinfo.value.status_code == 401
298-
assert excinfo.value.detail == "Unauthorized"
297+
# Check that a 401 response was sent
298+
assert len(sent_messages) == 2
299+
assert sent_messages[0]["type"] == "http.response.start"
300+
assert sent_messages[0]["status"] == 401
301+
assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"])
299302
assert not app.called
300303

301304
async def test_non_authenticated_user(self):
@@ -308,14 +311,18 @@ async def test_non_authenticated_user(self):
308311
async def receive() -> Message:
309312
return {"type": "http.request"}
310313

314+
sent_messages = []
315+
311316
async def send(message: Message) -> None:
312-
pass
317+
sent_messages.append(message)
313318

314-
with pytest.raises(HTTPException) as excinfo:
315-
await middleware(scope, receive, send)
319+
await middleware(scope, receive, send)
316320

317-
assert excinfo.value.status_code == 401
318-
assert excinfo.value.detail == "Unauthorized"
321+
# Check that a 401 response was sent
322+
assert len(sent_messages) == 2
323+
assert sent_messages[0]["type"] == "http.response.start"
324+
assert sent_messages[0]["status"] == 401
325+
assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"])
319326
assert not app.called
320327

321328
async def test_missing_required_scope(self, valid_access_token: AccessToken):
@@ -333,14 +340,18 @@ async def test_missing_required_scope(self, valid_access_token: AccessToken):
333340
async def receive() -> Message:
334341
return {"type": "http.request"}
335342

343+
sent_messages = []
344+
336345
async def send(message: Message) -> None:
337-
pass
346+
sent_messages.append(message)
338347

339-
with pytest.raises(HTTPException) as excinfo:
340-
await middleware(scope, receive, send)
348+
await middleware(scope, receive, send)
341349

342-
assert excinfo.value.status_code == 403
343-
assert excinfo.value.detail == "Insufficient scope"
350+
# Check that a 403 response was sent
351+
assert len(sent_messages) == 2
352+
assert sent_messages[0]["type"] == "http.response.start"
353+
assert sent_messages[0]["status"] == 403
354+
assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"])
344355
assert not app.called
345356

346357
async def test_no_auth_credentials(self, valid_access_token: AccessToken):
@@ -357,14 +368,18 @@ async def test_no_auth_credentials(self, valid_access_token: AccessToken):
357368
async def receive() -> Message:
358369
return {"type": "http.request"}
359370

371+
sent_messages = []
372+
360373
async def send(message: Message) -> None:
361-
pass
374+
sent_messages.append(message)
362375

363-
with pytest.raises(HTTPException) as excinfo:
364-
await middleware(scope, receive, send)
376+
await middleware(scope, receive, send)
365377

366-
assert excinfo.value.status_code == 403
367-
assert excinfo.value.detail == "Insufficient scope"
378+
# Check that a 403 response was sent
379+
assert len(sent_messages) == 2
380+
assert sent_messages[0]["type"] == "http.response.start"
381+
assert sent_messages[0]["status"] == 403
382+
assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"])
368383
assert not app.called
369384

370385
async def test_has_required_scopes(self, valid_access_token: AccessToken):

0 commit comments

Comments
 (0)