Skip to content

Commit 3aef782

Browse files
authored
Merge pull request #38 from sacha-development-stuff/codex/fix-code-coverage-failure-mz8j4k
Add coverage for OAuth provider edge cases and token exchange handler
2 parents 6d7cc85 + 3996fc7 commit 3aef782

File tree

2 files changed

+293
-0
lines changed

2 files changed

+293
-0
lines changed

tests/unit/client/test_oauth2_providers.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,24 @@ def test_apply_client_auth_prefers_post_when_supported() -> None:
294294
assert "Authorization" not in headers
295295

296296

297+
def test_apply_client_auth_defaults_when_metadata_omits_supported_methods() -> None:
298+
storage = InMemoryStorage()
299+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
300+
provider = ClientCredentialsProvider("https://api.example.com/service", metadata, storage)
301+
provider._metadata = OAuthMetadata.model_validate(
302+
{**_metadata_json(), "token_endpoint_auth_methods_supported": ["none"]}
303+
)
304+
305+
token_data: dict[str, str] = {}
306+
headers: dict[str, str] = {}
307+
client_info = OAuthClientInformationFull(client_id="client", client_secret="secret")
308+
309+
provider._apply_client_auth(token_data, headers, client_info)
310+
311+
assert token_data == {"client_id": "client", "client_secret": "secret"}
312+
assert headers == {}
313+
314+
297315
@pytest.mark.anyio
298316
async def test_client_credentials_request_token_with_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
299317
storage = InMemoryStorage()
@@ -359,6 +377,26 @@ async def test_client_credentials_get_or_register_client(monkeypatch: pytest.Mon
359377
assert storage.client_info is client_info
360378

361379

380+
@pytest.mark.anyio
381+
async def test_client_credentials_get_or_register_client_skips_request_when_not_needed() -> None:
382+
storage = InMemoryStorage()
383+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
384+
provider = ClientCredentialsProvider("https://api.example.com/service", metadata, storage)
385+
386+
def fake_create_registration_request(
387+
self: ClientCredentialsProvider, metadata: OAuthMetadata | None
388+
) -> httpx.Request | None:
389+
self._client_info = OAuthClientInformationFull(client_id="existing-client")
390+
return None
391+
392+
provider._metadata = OAuthMetadata.model_validate(_metadata_json())
393+
provider._create_registration_request = MethodType(fake_create_registration_request, provider)
394+
395+
client_info = await provider._get_or_register_client()
396+
397+
assert client_info.client_id == "existing-client"
398+
399+
362400
@pytest.mark.anyio
363401
async def test_client_credentials_request_token_handles_invalid_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
364402
storage = InMemoryStorage()
@@ -514,6 +552,32 @@ async def test_client_credentials_async_auth_flow_with_cached_token() -> None:
514552
await flow.asend(response)
515553

516554

555+
@pytest.mark.anyio
556+
async def test_client_credentials_async_auth_flow_without_access_token_header(monkeypatch: pytest.MonkeyPatch) -> None:
557+
storage = InMemoryStorage()
558+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
559+
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
560+
561+
async def fake_initialize() -> None:
562+
provider._current_tokens = None
563+
564+
async def fake_ensure_token() -> None:
565+
provider._current_tokens = None
566+
567+
provider.initialize = fake_initialize # type: ignore[assignment]
568+
provider.ensure_token = fake_ensure_token # type: ignore[assignment]
569+
570+
request = httpx.Request("GET", "https://api.example.com/resource")
571+
flow = provider.async_auth_flow(request)
572+
573+
prepared_request = await anext(flow)
574+
assert "Authorization" not in prepared_request.headers
575+
576+
response = httpx.Response(200, request=prepared_request)
577+
with pytest.raises(StopAsyncIteration):
578+
await flow.asend(response)
579+
580+
517581
@pytest.mark.anyio
518582
async def test_token_exchange_request_token(monkeypatch: pytest.MonkeyPatch) -> None:
519583
storage = InMemoryStorage()
@@ -604,6 +668,43 @@ async def test_token_exchange_request_token_handles_invalid_metadata(monkeypatch
604668
actor_supplier.assert_awaited_once()
605669

606670

671+
@pytest.mark.anyio
672+
async def test_token_exchange_request_token_excludes_resource_when_unset(monkeypatch: pytest.MonkeyPatch) -> None:
673+
storage = InMemoryStorage()
674+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
675+
676+
subject_supplier = AsyncMock(return_value="subject-token")
677+
678+
provider = TokenExchangeProvider(
679+
"https://api.example.com/service",
680+
client_metadata,
681+
storage,
682+
subject_token_supplier=subject_supplier,
683+
)
684+
685+
provider._metadata = OAuthMetadata.model_validate(_metadata_json())
686+
provider._client_info = OAuthClientInformationFull(client_id="client", client_secret="secret")
687+
provider.resource = None
688+
689+
class RecordingAsyncClient(DummyAsyncClient):
690+
def __init__(self) -> None:
691+
super().__init__(post_responses=[_make_response(200, json_data=_token_json())])
692+
self.last_data: dict[str, str] | None = None
693+
694+
async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str]) -> httpx.Response:
695+
self.last_data = data
696+
return await super().post(url, data=data, headers=headers)
697+
698+
clients = [RecordingAsyncClient()]
699+
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))
700+
701+
await provider._request_token()
702+
703+
recorded_client = clients[0]
704+
assert recorded_client.last_data is not None
705+
assert "resource" not in recorded_client.last_data
706+
707+
607708
@pytest.mark.anyio
608709
async def test_token_exchange_request_token_raises_on_failure(monkeypatch: pytest.MonkeyPatch) -> None:
609710
storage = InMemoryStorage()
@@ -800,6 +901,64 @@ async def test_token_exchange_async_auth_flow_with_cached_token() -> None:
800901
await flow.asend(response)
801902

802903

904+
@pytest.mark.anyio
905+
async def test_token_exchange_async_auth_flow_without_access_token_header(monkeypatch: pytest.MonkeyPatch) -> None:
906+
storage = InMemoryStorage()
907+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
908+
909+
provider = TokenExchangeProvider(
910+
"https://api.example.com/service",
911+
client_metadata,
912+
storage,
913+
subject_token_supplier=AsyncMock(return_value="subject-token"),
914+
)
915+
916+
async def fake_initialize() -> None:
917+
provider._current_tokens = None
918+
919+
async def fake_ensure_token() -> None:
920+
provider._current_tokens = None
921+
922+
provider.initialize = fake_initialize # type: ignore[assignment]
923+
provider.ensure_token = fake_ensure_token # type: ignore[assignment]
924+
925+
request = httpx.Request("GET", "https://api.example.com/resource")
926+
flow = provider.async_auth_flow(request)
927+
928+
prepared_request = await anext(flow)
929+
assert "Authorization" not in prepared_request.headers
930+
931+
response = httpx.Response(200, request=prepared_request)
932+
with pytest.raises(StopAsyncIteration):
933+
await flow.asend(response)
934+
935+
936+
@pytest.mark.anyio
937+
async def test_token_exchange_get_or_register_client_skips_request_when_not_needed() -> None:
938+
storage = InMemoryStorage()
939+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
940+
941+
provider = TokenExchangeProvider(
942+
"https://api.example.com/service",
943+
metadata,
944+
storage,
945+
subject_token_supplier=AsyncMock(return_value="subject-token"),
946+
)
947+
948+
def fake_create_registration_request(
949+
self: TokenExchangeProvider, metadata: OAuthMetadata | None
950+
) -> httpx.Request | None:
951+
self._client_info = OAuthClientInformationFull(client_id="existing-client")
952+
return None
953+
954+
provider._metadata = OAuthMetadata.model_validate(_metadata_json())
955+
provider._create_registration_request = MethodType(fake_create_registration_request, provider)
956+
957+
client_info = await provider._get_or_register_client()
958+
959+
assert client_info.client_id == "existing-client"
960+
961+
803962
@pytest.mark.anyio
804963
async def test_token_exchange_ensure_token_returns_when_valid() -> None:
805964
storage = InMemoryStorage()
@@ -919,3 +1078,67 @@ async def fake_handle_token(self: OAuthClientProvider, response: httpx.Response)
9191078
final_response = httpx.Response(200, request=retry_request)
9201079
with pytest.raises(StopAsyncIteration):
9211080
await flow.asend(final_response)
1081+
1082+
1083+
@pytest.mark.anyio
1084+
async def test_oauth_client_provider_metadata_discovery_skips_when_no_urls(monkeypatch: pytest.MonkeyPatch) -> None:
1085+
storage = InMemoryStorage()
1086+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
1087+
provider = OAuthClientProvider("https://api.example.com/service", metadata, storage)
1088+
provider._initialized = True
1089+
1090+
client = OAuthClientInformationFull(client_id="client", client_secret="secret")
1091+
provider._metadata = OAuthMetadata.model_validate(_metadata_json())
1092+
provider._client_info = client
1093+
provider.context.client_info = client
1094+
1095+
def fake_build_resource_urls(self: OAuthClientProvider, response: httpx.Response) -> list[str]:
1096+
return ["https://resource.example.com/.well-known"]
1097+
1098+
async def fake_handle_resource(self: OAuthClientProvider, response: httpx.Response) -> bool:
1099+
self.context.auth_server_url = "https://auth.example.com"
1100+
return True
1101+
1102+
def fake_get_discovery_urls(self: OAuthClientProvider, url: str) -> list[str]:
1103+
assert url == "https://auth.example.com"
1104+
return []
1105+
1106+
async def fake_perform_authorization(self: OAuthClientProvider) -> httpx.Request:
1107+
return httpx.Request("POST", "https://auth.example.com/token")
1108+
1109+
async def fake_handle_token(self: OAuthClientProvider, response: httpx.Response) -> None:
1110+
token = OAuthToken(access_token="flow-token", scope="alpha")
1111+
self.context.current_tokens = token
1112+
await self.context.storage.set_tokens(token)
1113+
1114+
provider._select_scopes = MethodType(lambda self, response: None, provider)
1115+
monkeypatch.setattr(provider, "_build_protected_resource_discovery_urls", MethodType(fake_build_resource_urls, provider))
1116+
monkeypatch.setattr(provider, "_handle_protected_resource_response", MethodType(fake_handle_resource, provider))
1117+
monkeypatch.setattr(provider, "_get_discovery_urls", MethodType(fake_get_discovery_urls, provider))
1118+
monkeypatch.setattr(provider, "_perform_authorization", MethodType(fake_perform_authorization, provider))
1119+
monkeypatch.setattr(provider, "_handle_token_response", MethodType(fake_handle_token, provider))
1120+
1121+
request = httpx.Request("GET", "https://api.example.com/resource")
1122+
flow = provider.async_auth_flow(request)
1123+
1124+
prepared_request = await anext(flow)
1125+
assert "Authorization" not in prepared_request.headers
1126+
1127+
headers = {
1128+
"WWW-Authenticate": 'Bearer resource_metadata="https://resource.example.com/.well-known"'
1129+
}
1130+
first_response = httpx.Response(401, headers=headers, request=prepared_request)
1131+
1132+
discovery_request = await flow.asend(first_response)
1133+
discovery_response = httpx.Response(200, request=discovery_request)
1134+
1135+
token_request = await flow.asend(discovery_response)
1136+
assert isinstance(token_request, httpx.Request)
1137+
1138+
token_response = httpx.Response(200, request=token_request)
1139+
retry_request = await flow.asend(token_response)
1140+
assert retry_request.headers["Authorization"] == "Bearer flow-token"
1141+
1142+
final_response = httpx.Response(200, request=retry_request)
1143+
with pytest.raises(StopAsyncIteration):
1144+
await flow.asend(final_response)

tests/unit/server/auth/test_token_handler.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
AuthorizationCodeRequest,
1313
ClientCredentialsRequest,
1414
TokenErrorResponse,
15+
TokenExchangeRequest,
1516
TokenHandler,
1617
TokenSuccessResponse,
1718
)
@@ -52,6 +53,33 @@ async def exchange_client_credentials(self, client_info: object, scopes: list[st
5253
raise TokenError(error="invalid_client", error_description="bad credentials")
5354

5455

56+
class TokenExchangeProviderStub:
57+
def __init__(self) -> None:
58+
self.last_call: dict[str, Any] | None = None
59+
60+
async def exchange_token(
61+
self,
62+
client_info: object,
63+
subject_token: str,
64+
subject_token_type: str,
65+
actor_token: str | None,
66+
actor_token_type: str | None,
67+
scopes: list[str],
68+
audience: str | None,
69+
resource: str | None,
70+
) -> OAuthToken:
71+
self.last_call = {
72+
"subject_token": subject_token,
73+
"subject_token_type": subject_token_type,
74+
"actor_token": actor_token,
75+
"actor_token_type": actor_token_type,
76+
"scopes": scopes,
77+
"audience": audience,
78+
"resource": resource,
79+
}
80+
return OAuthToken(access_token="exchanged-token")
81+
82+
5583
class RefreshTokenProvider:
5684
def __init__(self) -> None:
5785
self.refresh_token = SimpleNamespace(
@@ -156,3 +184,45 @@ async def test_handle_route_refresh_token_branch() -> None:
156184
assert isinstance(body, bytes | bytearray | memoryview)
157185
payload = json.loads(bytes(body).decode())
158186
assert payload["access_token"] == "refreshed-token"
187+
188+
189+
@pytest.mark.anyio
190+
async def test_handle_route_token_exchange_branch() -> None:
191+
provider = TokenExchangeProviderStub()
192+
client_info = OAuthClientInformationFull(
193+
client_id="client",
194+
grant_types=["token_exchange"],
195+
scope="alpha beta",
196+
)
197+
handler = TokenHandler(
198+
provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider),
199+
client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)),
200+
)
201+
202+
request_data = {
203+
"grant_type": "token_exchange",
204+
"subject_token": "subject-token",
205+
"subject_token_type": "access_token",
206+
"actor_token": "actor-token",
207+
"actor_token_type": "jwt",
208+
"scope": "alpha beta",
209+
"audience": "https://audience.example.com",
210+
"resource": "https://resource.example.com",
211+
"client_id": "client",
212+
"client_secret": "secret",
213+
}
214+
215+
response = await handler.handle(cast(Request, DummyRequest(request_data)))
216+
217+
assert response.status_code == 200
218+
payload = json.loads(bytes(response.body).decode())
219+
assert payload["access_token"] == "exchanged-token"
220+
assert provider.last_call == {
221+
"subject_token": "subject-token",
222+
"subject_token_type": "access_token",
223+
"actor_token": "actor-token",
224+
"actor_token_type": "jwt",
225+
"scopes": ["alpha", "beta"],
226+
"audience": "https://audience.example.com",
227+
"resource": "https://resource.example.com",
228+
}

0 commit comments

Comments
 (0)