Skip to content

Commit 4a59ba4

Browse files
committed
Fix OAuth refresh resource handling
1 parent f475344 commit 4a59ba4

2 files changed

Lines changed: 34 additions & 9 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def get_resource_url(self) -> str:
151151

152152
# If PRM provides a resource that's a valid parent, use it
153153
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
154-
prm_resource = str(self.protected_resource_metadata.resource)
154+
prm_resource = str(self.protected_resource_metadata.resource).rstrip("/")
155155
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
156156
resource = prm_resource
157157

@@ -442,10 +442,6 @@ async def _refresh_token(self) -> httpx.Request:
442442
"client_id": self.context.client_info.client_id,
443443
}
444444

445-
# Only include resource param if conditions are met
446-
if self.context.should_include_resource_param(self.context.protocol_version):
447-
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
448-
449445
# Prepare authentication based on preferred method
450446
headers = {"Content-Type": "application/x-www-form-urlencoded"}
451447
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)

tests/client/test_auth.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ class TestProtectedResourceMetadata:
745745

746746
@pytest.mark.anyio
747747
async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider):
748-
"""Test resource parameter is included for protocol version >= 2025-06-18."""
748+
"""Test resource parameter is included on authorization-code token requests."""
749749
# Set protocol version to 2025-06-18
750750
oauth_provider.context.protocol_version = "2025-06-18"
751751
oauth_provider.context.client_info = OAuthClientInformationFull(
@@ -762,15 +762,15 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_
762762
expected_resource = quote(oauth_provider.context.get_resource_url(), safe="")
763763
assert f"resource={expected_resource}" in content
764764

765-
# Test in refresh token
765+
# Refresh grants should not include resource; some providers reject it.
766766
oauth_provider.context.current_tokens = OAuthToken(
767767
access_token="test_access",
768768
token_type="Bearer",
769769
refresh_token="test_refresh",
770770
)
771771
refresh_request = await oauth_provider._refresh_token()
772772
refresh_content = refresh_request.content.decode()
773-
assert "resource=" in refresh_content
773+
assert "resource=" not in refresh_content
774774

775775
@pytest.mark.anyio
776776
async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider):
@@ -800,7 +800,7 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro
800800

801801
@pytest.mark.anyio
802802
async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider):
803-
"""Test resource parameter is always included when protected resource metadata exists."""
803+
"""Test PRM makes resource available for authorization-code token requests."""
804804
# Set old protocol version but with protected resource metadata
805805
oauth_provider.context.protocol_version = "2025-03-26"
806806
oauth_provider.context.protected_resource_metadata = ProtectedResourceMetadata(
@@ -818,6 +818,16 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
818818
content = request.content.decode()
819819
assert "resource=" in content
820820

821+
# Refresh grants should not include resource even when PRM is present.
822+
oauth_provider.context.current_tokens = OAuthToken(
823+
access_token="test_access",
824+
token_type="Bearer",
825+
refresh_token="test_refresh",
826+
)
827+
refresh_request = await oauth_provider._refresh_token()
828+
refresh_content = refresh_request.content.decode()
829+
assert "resource=" not in refresh_content
830+
821831

822832
@pytest.mark.anyio
823833
async def test_validate_resource_rejects_mismatched_resource(
@@ -949,6 +959,25 @@ async def test_get_resource_url_uses_canonical_when_prm_mismatches(
949959
assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp")
950960

951961

962+
@pytest.mark.anyio
963+
async def test_get_resource_url_strips_bare_domain_prm_trailing_slash(
964+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
965+
) -> None:
966+
"""Bare-domain PRM resources should not inherit AnyHttpUrl's trailing slash."""
967+
provider = OAuthClientProvider(
968+
server_url="https://api.example.com",
969+
client_metadata=client_metadata,
970+
storage=mock_storage,
971+
)
972+
provider._initialized = True
973+
provider.context.protected_resource_metadata = ProtectedResourceMetadata(
974+
resource=AnyHttpUrl("https://api.example.com"),
975+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
976+
)
977+
978+
assert provider.context.get_resource_url() == snapshot("https://api.example.com")
979+
980+
952981
class TestRegistrationResponse:
953982
"""Test client registration response handling."""
954983

0 commit comments

Comments
 (0)