Skip to content

Commit 6dd33ed

Browse files
committed
refactor: pull more oauth helpers out
1 parent 6901553 commit 6dd33ed

File tree

3 files changed

+124
-35
lines changed

3 files changed

+124
-35
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,8 @@
44
Implements authorization code flow with PKCE and automatic token refresh.
55
"""
66

7-
import base64
8-
import hashlib
97
import logging
108
import secrets
11-
import string
129
import time
1310
from collections.abc import AsyncGenerator, Awaitable, Callable
1411
from dataclasses import dataclass, field
@@ -32,6 +29,7 @@
3229
handle_auth_metadata_response,
3330
handle_protected_resource_response,
3431
handle_registration_response,
32+
handle_token_response_scopes,
3533
)
3634
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
3735
from mcp.shared.auth import (
@@ -41,7 +39,12 @@
4139
OAuthToken,
4240
ProtectedResourceMetadata,
4341
)
44-
from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url
42+
from mcp.shared.auth_utils import (
43+
calculate_token_expiry,
44+
check_resource_allowed,
45+
generate_pkce_parameters,
46+
resource_url_from_server_url,
47+
)
4548

4649
logger = logging.getLogger(__name__)
4750

@@ -54,10 +57,8 @@ class PKCEParameters(BaseModel):
5457

5558
@classmethod
5659
def generate(cls) -> "PKCEParameters":
57-
"""Generate new PKCE parameters."""
58-
code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128))
59-
digest = hashlib.sha256(code_verifier.encode()).digest()
60-
code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")
60+
"""Generate new PKCE parameters using shared util function."""
61+
code_verifier, code_challenge = generate_pkce_parameters(verifier_length=128)
6162
return cls(code_verifier=code_verifier, code_challenge=code_challenge)
6263

6364

@@ -114,11 +115,8 @@ def get_authorization_base_url(self, server_url: str) -> str:
114115
return f"{parsed.scheme}://{parsed.netloc}"
115116

116117
def update_token_expiry(self, token: OAuthToken) -> None:
117-
"""Update token expiry time."""
118-
if token.expires_in:
119-
self.token_expiry_time = time.time() + token.expires_in
120-
else:
121-
self.token_expiry_time = None
118+
"""Update token expiry time using shared util function."""
119+
self.token_expiry_time = calculate_token_expiry(token.expires_in)
122120

123121
def is_token_valid(self) -> bool:
124122
"""Check if current token is valid."""
@@ -364,26 +362,20 @@ async def _handle_token_response(self, response: httpx.Response) -> None:
364362
"""Handle token exchange response."""
365363
if response.status_code != 200:
366364
body = await response.aread()
367-
body = body.decode("utf-8")
368-
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}")
369-
370-
try:
371-
content = await response.aread()
372-
token_response = OAuthToken.model_validate_json(content)
373-
374-
# Validate scopes
375-
if token_response.scope and self.context.client_metadata.scope:
376-
requested_scopes = set(self.context.client_metadata.scope.split())
377-
returned_scopes = set(token_response.scope.split())
378-
unauthorized_scopes = returned_scopes - requested_scopes
379-
if unauthorized_scopes:
380-
raise OAuthTokenError(f"Server granted unauthorized scopes: {unauthorized_scopes}")
365+
body_text = body.decode("utf-8")
366+
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}")
367+
368+
# Parse and validate response with scope validation
369+
token_response = await handle_token_response_scopes(
370+
response,
371+
self.context.client_metadata,
372+
validate_scope=True,
373+
)
381374

382-
self.context.current_tokens = token_response
383-
self.context.update_token_expiry(token_response)
384-
await self.context.storage.set_tokens(token_response)
385-
except ValidationError as e:
386-
raise OAuthTokenError(f"Invalid token response: {e}")
375+
# Store tokens in context
376+
self.context.current_tokens = token_response
377+
self.context.update_token_expiry(token_response)
378+
await self.context.storage.set_tokens(token_response)
387379

388380
async def _refresh_token(self) -> httpx.Request:
389381
"""Build token refresh request."""

src/mcp/client/auth/utils.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@
55
from httpx import Request, Response
66
from pydantic import ValidationError
77

8-
from mcp.client.auth import OAuthRegistrationError
8+
from mcp.client.auth import OAuthRegistrationError, OAuthTokenError
99
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
10-
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, ProtectedResourceMetadata
10+
from mcp.shared.auth import (
11+
OAuthClientInformationFull,
12+
OAuthClientMetadata,
13+
OAuthMetadata,
14+
OAuthToken,
15+
ProtectedResourceMetadata,
16+
)
1117
from mcp.types import LATEST_PROTOCOL_VERSION
1218

1319
logger = logging.getLogger(__name__)
@@ -220,6 +226,45 @@ async def handle_registration_response(response: Response) -> OAuthClientInforma
220226
raise OAuthRegistrationError(f"Invalid registration response: {e}")
221227

222228

229+
async def handle_token_response_scopes(
230+
response: Response,
231+
client_metadata: OAuthClientMetadata,
232+
validate_scope: bool = True,
233+
) -> OAuthToken:
234+
"""Parse and validate token response with optional scope validation.
235+
236+
Parses token response JSON and validates scopes to prevent scope escalation
237+
if requested. Callers should check response.status_code before calling.
238+
239+
Args:
240+
response: HTTP response from token endpoint (status already checked by caller)
241+
client_metadata: Client metadata containing requested scopes (if any)
242+
validate_scope: Whether to validate scopes (default True). Set False for refresh.
243+
244+
Returns:
245+
Validated OAuthToken model
246+
247+
Raises:
248+
OAuthTokenError: If response JSON is invalid or contains unauthorized scopes
249+
"""
250+
try:
251+
content = await response.aread()
252+
token_response = OAuthToken.model_validate_json(content)
253+
254+
# Validate scopes to prevent scope escalation
255+
# Only validate during initial token exchange, not during refresh
256+
if validate_scope and token_response.scope and client_metadata.scope:
257+
requested_scopes = set(client_metadata.scope.split())
258+
returned_scopes = set(token_response.scope.split())
259+
unauthorized_scopes = returned_scopes - requested_scopes
260+
if unauthorized_scopes:
261+
raise OAuthTokenError(f"Server granted unauthorized scopes: {unauthorized_scopes}")
262+
263+
return token_response
264+
except ValidationError as e:
265+
raise OAuthTokenError(f"Invalid token response: {e}")
266+
267+
223268
# async def prm_discovery(
224269
# server_url: str,
225270
# www_auth_resource_metadata_url: str | None,

src/mcp/shared/auth_utils.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707)."""
1+
"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636)."""
22

3+
import base64
4+
import hashlib
5+
import secrets
6+
import string
7+
import time
38
from urllib.parse import urlparse, urlsplit, urlunsplit
49

510
from pydantic import AnyUrl, HttpUrl
@@ -67,3 +72,50 @@ def check_resource_allowed(requested_resource: str, configured_resource: str) ->
6772
configured_path += "/"
6873

6974
return requested_path.startswith(configured_path)
75+
76+
77+
def generate_pkce_parameters(verifier_length: int = 128) -> tuple[str, str]:
78+
"""Generate PKCE verifier and challenge per RFC 7636.
79+
80+
Generates cryptographically secure code_verifier and code_challenge
81+
for OAuth 2.0 PKCE (Proof Key for Code Exchange).
82+
83+
Args:
84+
verifier_length: Length of code_verifier (43-128 chars per RFC 7636, default 128)
85+
86+
Returns:
87+
Tuple of (code_verifier, code_challenge)
88+
89+
Raises:
90+
ValueError: If verifier_length is not between 43 and 128
91+
"""
92+
if not 43 <= verifier_length <= 128:
93+
raise ValueError("verifier_length must be between 43 and 128 per RFC 7636")
94+
95+
# Generate code_verifier using unreserved characters per RFC 7636 Section 4.1
96+
# unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
97+
code_verifier = "".join(
98+
secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(verifier_length)
99+
)
100+
101+
# Generate code_challenge using S256 method per RFC 7636 Section 4.2
102+
# code_challenge = BASE64URL(SHA256(ASCII(code_verifier)))
103+
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
104+
code_challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
105+
106+
return code_verifier, code_challenge
107+
108+
109+
def calculate_token_expiry(expires_in: int | str | None) -> float | None:
110+
"""Calculate token expiry timestamp from expires_in seconds.
111+
112+
Args:
113+
expires_in: Seconds until token expiration (may be string from some servers)
114+
115+
Returns:
116+
Unix timestamp when token expires, or None if no expiry specified
117+
"""
118+
if expires_in is None:
119+
return None
120+
# Defensive: handle servers that return expires_in as string
121+
return time.time() + int(expires_in)

0 commit comments

Comments
 (0)