@@ -111,6 +111,8 @@ class OAuthContext:
111111 # Discovery state for fallback support
112112 discovery_base_url : str | None = None
113113 discovery_pathname : str | None = None
114+ # Optional expected issuer for access tokens (JWT iss claim)
115+ expected_issuer : str | None = None
114116
115117 def get_authorization_base_url (self , server_url : str ) -> str :
116118 """Extract base URL by removing path component."""
@@ -126,12 +128,64 @@ def update_token_expiry(self, token: OAuthToken) -> None:
126128
127129 def is_token_valid (self ) -> bool :
128130 """Check if current token is valid."""
129- return bool (
131+ # Basic existence and expiry checks
132+ basic_valid = bool (
130133 self .current_tokens
131134 and self .current_tokens .access_token
132135 and (not self .token_expiry_time or time .time () <= self .token_expiry_time )
133136 )
134137
138+ if not basic_valid :
139+ return False
140+
141+ # If no expected issuer is configured, behave as before
142+ if not getattr (self , "expected_issuer" , None ):
143+ return True
144+
145+ # If expected_issuer is set, ensure token issuer matches
146+ try :
147+ return self ._token_issuer_matches (self .current_tokens .access_token )
148+ except Exception :
149+ # On any parsing issue, treat token as invalid
150+ logger .exception ("Failed to validate token issuer" )
151+ return False
152+
153+ def _token_issuer_matches (self , token : str ) -> bool :
154+ """Decode a JWT access token (no signature verification) and compare its 'iss' claim.
155+
156+ This performs a safe, minimal check: split the token, base64-decode the payload,
157+ parse JSON, and compare the 'iss' field to self.expected_issuer. Returns False
158+ if the token is malformed or the claim is missing/mismatched.
159+ """
160+ # JWTs are in the form header.payload.signature
161+ parts = token .split ("." )
162+ if len (parts ) < 2 :
163+ return False
164+
165+ payload_b64 = parts [1 ]
166+
167+ # Add padding for base64 if necessary
168+ padding = "=" * (- len (payload_b64 ) % 4 )
169+ payload_b64 += padding
170+
171+ try :
172+ payload_bytes = base64 .urlsafe_b64decode (payload_b64 .encode ())
173+ except Exception :
174+ return False
175+
176+ try :
177+ import json
178+
179+ payload = json .loads (payload_bytes )
180+ except Exception :
181+ return False
182+
183+ iss = payload .get ("iss" )
184+ if not iss :
185+ return False
186+
187+ return iss == self .expected_issuer
188+
135189 def can_refresh_token (self ) -> bool :
136190 """Check if token can be refreshed."""
137191 return bool (self .current_tokens and self .current_tokens .refresh_token and self .client_info )
0 commit comments