1313import time
1414from collections .abc import AsyncGenerator , Awaitable , Callable
1515from dataclasses import dataclass , field
16- from typing import Protocol
16+ from typing import Any , Protocol
1717from urllib .parse import urlencode , urljoin , urlparse
18+ from uuid import uuid4
1819
1920import anyio
2021import httpx
22+ import jwt
2123from pydantic import BaseModel , Field , ValidationError
2224
2325from mcp .client .streamable_http import MCP_PROTOCOL_VERSION
@@ -61,6 +63,23 @@ def generate(cls) -> "PKCEParameters":
6163 return cls (code_verifier = code_verifier , code_challenge = code_challenge )
6264
6365
66+ class JWTParameters (BaseModel ):
67+ """JWT parameters."""
68+
69+ assertion : str | None = Field (
70+ default = None ,
71+ description = "JWT assertion for JWT authentication. "
72+ "Will be used instead of generating a new assertion if provided." ,
73+ )
74+
75+ issuer : str | None = Field (default = None , description = "Issuer for JWT assertions." )
76+ subject : str | None = Field (default = None , description = "Subject identifier for JWT assertions." )
77+ claims : dict [str , Any ] | None = Field (default = None , description = "Additional claims for JWT assertions." )
78+ jwt_signing_algorithm : str | None = Field (default = "RS256" , description = "Algorithm for signing JWT assertions." )
79+ jwt_signing_key : str | None = Field (default = None , description = "Private key for JWT signing." )
80+ jwt_lifetime_seconds : int = Field (default = 300 , description = "Lifetime of generated JWT in seconds." )
81+
82+
6483class TokenStorage (Protocol ):
6584 """Protocol for token storage implementations."""
6685
@@ -91,6 +110,7 @@ class OAuthContext:
91110 redirect_handler : Callable [[str ], Awaitable [None ]] | None
92111 callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None
93112 timeout : float = 300.0
113+ jwt_parameters : JWTParameters | None = None
94114
95115 # Discovered metadata
96116 protected_resource_metadata : ProtectedResourceMetadata | None = None
@@ -192,6 +212,7 @@ def __init__(
192212 redirect_handler : Callable [[str ], Awaitable [None ]] | None = None ,
193213 callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None = None ,
194214 timeout : float = 300.0 ,
215+ jwt_parameters : JWTParameters | None = None ,
195216 ):
196217 """Initialize OAuth2 authentication."""
197218 self .context = OAuthContext (
@@ -201,6 +222,7 @@ def __init__(
201222 redirect_handler = redirect_handler ,
202223 callback_handler = callback_handler ,
203224 timeout = timeout ,
225+ jwt_parameters = jwt_parameters ,
204226 )
205227 self ._initialized = False
206228
@@ -314,6 +336,9 @@ async def _perform_authorization(self) -> httpx.Request:
314336 if "client_credentials" in self .context .client_metadata .grant_types :
315337 token_request = await self ._exchange_token_client_credentials ()
316338 return token_request
339+ elif "urn:ietf:params:oauth:grant-type:jwt-bearer" in self .context .client_metadata .grant_types :
340+ token_request = await self ._exchange_token_jwt_bearer ()
341+ return token_request
317342 else :
318343 auth_code , code_verifier = await self ._perform_authorization_code_grant ()
319344 token_request = await self ._exchange_token_authorization_code (auth_code , code_verifier )
@@ -372,19 +397,22 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
372397 # Return auth code and code verifier for token exchange
373398 return auth_code , pkce_params .code_verifier
374399
400+ def _get_token_endpoint (self ) -> str :
401+ if self .context .oauth_metadata and self .context .oauth_metadata .token_endpoint :
402+ token_url = str (self .context .oauth_metadata .token_endpoint )
403+ else :
404+ auth_base_url = self .context .get_authorization_base_url (self .context .server_url )
405+ token_url = urljoin (auth_base_url , "/token" )
406+ return token_url
407+
375408 async def _exchange_token_authorization_code (self , auth_code : str , code_verifier : str ) -> httpx .Request :
376409 """Build token exchange request for authorization_code flow."""
377410 if self .context .client_metadata .redirect_uris is None :
378411 raise OAuthFlowError ("No redirect URIs provided for authorization code grant" )
379412 if not self .context .client_info :
380413 raise OAuthFlowError ("Missing client info" )
381414
382- if self .context .oauth_metadata and self .context .oauth_metadata .token_endpoint :
383- token_url = str (self .context .oauth_metadata .token_endpoint )
384- else :
385- auth_base_url = self .context .get_authorization_base_url (self .context .server_url )
386- token_url = urljoin (auth_base_url , "/token" )
387-
415+ token_url = self ._get_token_endpoint ()
388416 token_data = {
389417 "grant_type" : "authorization_code" ,
390418 "code" : auth_code ,
@@ -409,19 +437,17 @@ async def _exchange_token_client_credentials(self) -> httpx.Request:
409437 if not self .context .client_info :
410438 raise OAuthFlowError ("Missing client info" )
411439
412- if self .context .oauth_metadata and self .context .oauth_metadata .token_endpoint :
413- token_url = str (self .context .oauth_metadata .token_endpoint )
414- else :
415- auth_base_url = self .context .get_authorization_base_url (self .context .server_url )
416- token_url = urljoin (auth_base_url , "/token" )
417-
440+ token_url = self ._get_token_endpoint ()
418441 token_data = {
419442 "grant_type" : "client_credentials" ,
420- "resource" : self .context .get_resource_url (), # RFC 8707
421443 }
422444
423445 headers = {"Content-Type" : "application/x-www-form-urlencoded" }
424446
447+ # Only include resource param if conditions are met
448+ if self .context .should_include_resource_param (self .context .protocol_version ):
449+ token_data ["resource" ] = self .context .get_resource_url () # RFC 8707
450+
425451 if self .context .client_metadata .scope :
426452 token_data ["scope" ] = self .context .client_metadata .scope
427453
@@ -442,6 +468,57 @@ async def _exchange_token_client_credentials(self) -> httpx.Request:
442468
443469 return httpx .Request ("POST" , token_url , data = token_data , headers = headers )
444470
471+ async def _exchange_token_jwt_bearer (self ) -> httpx .Request :
472+ """Build token exchange request for JWT bearer grant."""
473+ if not self .context .client_info :
474+ raise OAuthFlowError ("Missing client info" )
475+ if not self .context .jwt_parameters :
476+ raise OAuthFlowError ("Missing JWT parameters" )
477+
478+ token_url = self ._get_token_endpoint ()
479+
480+ if self .context .jwt_parameters .assertion is not None :
481+ assertion = self .context .jwt_parameters .assertion
482+ else :
483+ if not self .context .jwt_parameters .jwt_signing_key :
484+ raise OAuthFlowError ("Missing signing key for JWT bearer grant" )
485+ if not self .context .jwt_parameters .issuer :
486+ raise OAuthFlowError ("Missing issuer for JWT bearer grant" )
487+ if not self .context .jwt_parameters .subject :
488+ raise OAuthFlowError ("Missing subject for JWT bearer grant" )
489+
490+ now = int (time .time ())
491+ claims = {
492+ "iss" : self .context .jwt_parameters .issuer ,
493+ "sub" : self .context .jwt_parameters .subject ,
494+ "aud" : token_url ,
495+ "exp" : now + self .context .jwt_parameters .jwt_lifetime_seconds ,
496+ "iat" : now ,
497+ "jti" : str (uuid4 ()),
498+ }
499+ claims .update (self .context .jwt_parameters .claims or {})
500+
501+ assertion = jwt .encode (
502+ claims ,
503+ self .context .jwt_parameters .jwt_signing_key ,
504+ algorithm = self .context .jwt_parameters .jwt_signing_algorithm or "RS256" ,
505+ )
506+
507+ token_data = {
508+ "grant_type" : "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
509+ "assertion" : assertion ,
510+ }
511+
512+ if self .context .should_include_resource_param (self .context .protocol_version ):
513+ token_data ["resource" ] = self .context .get_resource_url ()
514+
515+ if self .context .client_metadata .scope :
516+ token_data ["scope" ] = self .context .client_metadata .scope
517+
518+ return httpx .Request (
519+ "POST" , token_url , data = token_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
520+ )
521+
445522 async def _handle_token_response (self , response : httpx .Response ) -> None :
446523 """Handle token exchange response."""
447524 if response .status_code != 200 :
0 commit comments