Skip to content

Commit d2ed7e4

Browse files
committed
add test
1 parent f9282cf commit d2ed7e4

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed

tests/client/test_auth.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,116 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth(
862862
# Verify exactly one request was yielded (no double-sending)
863863
assert request_yields == 1, f"Expected 1 request yield, got {request_yields}"
864864

865+
@pytest.mark.anyio
866+
async def test_token_exchange_accepts_201_status(
867+
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage
868+
):
869+
"""Test that token exchange accepts both 200 and 201 status codes."""
870+
# Ensure no tokens are stored
871+
oauth_provider.context.current_tokens = None
872+
oauth_provider.context.token_expiry_time = None
873+
oauth_provider._initialized = True
874+
875+
# Create a test request
876+
test_request = httpx.Request("GET", "https://api.example.com/mcp")
877+
878+
# Mock the auth flow
879+
auth_flow = oauth_provider.async_auth_flow(test_request)
880+
881+
# First request should be the original request without auth header
882+
request = await auth_flow.__anext__()
883+
assert "Authorization" not in request.headers
884+
885+
# Send a 401 response to trigger the OAuth flow
886+
response = httpx.Response(
887+
401,
888+
headers={
889+
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
890+
},
891+
request=test_request,
892+
)
893+
894+
# Next request should be to discover protected resource metadata
895+
discovery_request = await auth_flow.asend(response)
896+
assert discovery_request.method == "GET"
897+
assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
898+
899+
# Send a successful discovery response with minimal protected resource metadata
900+
discovery_response = httpx.Response(
901+
200,
902+
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
903+
request=discovery_request,
904+
)
905+
906+
# Next request should be to discover OAuth metadata
907+
oauth_metadata_request = await auth_flow.asend(discovery_response)
908+
assert oauth_metadata_request.method == "GET"
909+
assert str(oauth_metadata_request.url).startswith("https://auth.example.com/")
910+
assert "mcp-protocol-version" in oauth_metadata_request.headers
911+
912+
# Send a successful OAuth metadata response
913+
oauth_metadata_response = httpx.Response(
914+
200,
915+
content=(
916+
b'{"issuer": "https://auth.example.com", '
917+
b'"authorization_endpoint": "https://auth.example.com/authorize", '
918+
b'"token_endpoint": "https://auth.example.com/token", '
919+
b'"registration_endpoint": "https://auth.example.com/register"}'
920+
),
921+
request=oauth_metadata_request,
922+
)
923+
924+
# Next request should be to register client
925+
registration_request = await auth_flow.asend(oauth_metadata_response)
926+
assert registration_request.method == "POST"
927+
assert str(registration_request.url) == "https://auth.example.com/register"
928+
929+
# Send a successful registration response with 201 status
930+
registration_response = httpx.Response(
931+
201,
932+
content=b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}',
933+
request=registration_request,
934+
)
935+
936+
# Mock the authorization process
937+
oauth_provider._perform_authorization_code_grant = mock.AsyncMock(
938+
return_value=("test_auth_code", "test_code_verifier")
939+
)
940+
941+
# Next request should be to exchange token
942+
token_request = await auth_flow.asend(registration_response)
943+
assert token_request.method == "POST"
944+
assert str(token_request.url) == "https://auth.example.com/token"
945+
assert "code=test_auth_code" in token_request.content.decode()
946+
947+
# Send a successful token response with 201 status code (test both 200 and 201 are accepted)
948+
token_response = httpx.Response(
949+
201,
950+
content=(
951+
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
952+
b'"refresh_token": "new_refresh_token"}'
953+
),
954+
request=token_request,
955+
)
956+
957+
# Final request should be the original request with auth header
958+
final_request = await auth_flow.asend(token_response)
959+
assert final_request.headers["Authorization"] == "Bearer new_access_token"
960+
assert final_request.method == "GET"
961+
assert str(final_request.url) == "https://api.example.com/mcp"
962+
963+
# Send final success response to properly close the generator
964+
final_response = httpx.Response(200, request=final_request)
965+
try:
966+
await auth_flow.asend(final_response)
967+
except StopAsyncIteration:
968+
pass # Expected - generator should complete
969+
970+
# Verify tokens were stored
971+
assert oauth_provider.context.current_tokens is not None
972+
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
973+
assert oauth_provider.context.token_expiry_time is not None
974+
865975
@pytest.mark.anyio
866976
async def test_403_insufficient_scope_updates_scope_from_header(
867977
self,

0 commit comments

Comments
 (0)