Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import asyncio
import logging
import jwt
import threading
from typing import Any
from dataclasses import dataclass

from jwt import PyJWKClient, PyJWK, decode, get_unverified_header

Expand All @@ -13,15 +15,72 @@
logger = logging.getLogger(__name__)


@dataclass
class _JwkClientCacheEntry:
    """Pairs a cached ``PyJWKClient`` with the lock used to serialize calls into it."""

    # Client constructed for one specific JWKS URI.
    jwk_client: PyJWKClient
    # Held by worker threads for the duration of each call on ``jwk_client``.
    lock: threading.Lock


class _JwkClientManager:
    """Manage one cached ``PyJWKClient`` per JWKS URI and resolve signing keys off the event loop."""

    _cache: dict[str, _JwkClientCacheEntry]

    def __init__(self):
        self._cache = {}

    def _get_jwk_client(self, jwks_uri: str) -> _JwkClientCacheEntry:
        """Return the cache entry for ``jwks_uri``, creating it on first use.

        :param jwks_uri: The JWKS endpoint the client fetches keys from.
        :return: The entry holding the client and its lock; repeated calls with
            the same URI return the same object.
        """
        entry = self._cache.get(jwks_uri)
        if entry is None:
            entry = _JwkClientCacheEntry(PyJWKClient(jwks_uri), threading.Lock())
            self._cache[jwks_uri] = entry
        return entry

    async def get_signing_key(self, jwks_uri: str, header: dict[str, Any]) -> PyJWK:
        """Retrieve the signing key matching the token header's ``kid``.

        :param jwks_uri: The JWKS endpoint to resolve the key against.
        :param header: The (unverified) JWT header; must contain ``"kid"``.
        :return: The signing key returned by ``PyJWKClient.get_signing_key``.
        :raises KeyError: If ``header`` has no ``"kid"`` entry.
        """
        # Read the key id before doing anything else: a header without "kid"
        # fails immediately, without creating a cache entry, occupying a
        # worker thread, or queueing behind the per-URI lock.
        kid = header["kid"]

        entry = self._get_jwk_client(jwks_uri)

        # PyJWKClient.get_signing_key is synchronous, so it runs in a worker
        # thread to keep the event loop free for other coroutines. The lock
        # serializes calls per URI so concurrent callers do not use the same
        # client at once; the original author's intent is also to avoid
        # duplicate requests to the JWKS endpoint.
        def _fetch_under_lock() -> PyJWK:
            with entry.lock:
                return entry.jwk_client.get_signing_key(kid)

        return await asyncio.to_thread(_fetch_under_lock)


class JwtTokenValidator:
"""Utility class for validating JWT tokens using the PyJWT library and JWKs from a specified URI."""

_jwk_client_manager = _JwkClientManager()

def __init__(self, configuration: AgentAuthConfiguration):
    """Create a validator that uses ``configuration`` for its settings.

    :param configuration: The AgentAuthConfiguration instance to keep on this
        validator as ``self.configuration``.
    """
    self.configuration = configuration

async def validate_token(self, token: str) -> ClaimsIdentity:
"""Validates a JWT token.

:param token: The JWT token to validate.
:return: A ClaimsIdentity object containing the token's claims if validation is successful.
:raises ValueError: If the token is invalid or if the audience claim is not valid
"""

logger.debug("Validating JWT token.")
key = await self._get_public_key_or_secret(token)
decoded_token = jwt.decode(
decoded_token = decode(
token,
key=key,
algorithms=["RS256"],
Expand All @@ -37,20 +96,21 @@ async def validate_token(self, token: str) -> ClaimsIdentity:
return ClaimsIdentity(decoded_token, True, security_token=token)

def get_anonymous_claims(self) -> ClaimsIdentity:
    """Build a ClaimsIdentity with empty claims and authentication type "Anonymous"."""
    logger.debug("Returning anonymous claims identity.")
    anonymous_identity = ClaimsIdentity({}, False, authentication_type="Anonymous")
    return anonymous_identity

async def _get_public_key_or_secret(self, token: str) -> PyJWK:
    """Retrieve the signing key needed to validate ``token``.

    The header and payload are read without signature verification, and only
    to choose the JWKS endpoint and key id; nothing read here is trusted.

    :param token: The encoded JWT whose signing key should be looked up.
    :return: The signing key resolved through the shared JWK client manager.
    :raises KeyError: If the token header has no ``"kid"`` entry.
    """
    header = get_unverified_header(token)
    unverified_payload: dict = decode(token, options={"verify_signature": False})

    # Bot Framework tokens use a fixed key endpoint; any other issuer is
    # resolved against the configured tenant's discovery endpoint.
    jwks_uri = (
        "https://login.botframework.com/v1/.well-known/keys"
        if unverified_payload.get("iss") == "https://api.botframework.com"
        else f"https://login.microsoftonline.com/{self.configuration.TENANT_ID}/discovery/v2.0/keys"
    )

    # Go through the class-level manager so the PyJWKClient for this URI is
    # reused across calls instead of being rebuilt for every token.
    return await self._jwk_client_manager.get_signing_key(jwks_uri, header)
1 change: 1 addition & 0 deletions tests/copilotstudio_client/test_copilot_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from aiohttp import ClientSession, ClientError
from urllib.parse import urlparse


@pytest.mark.asyncio
async def test_copilot_client_error(mocker):
# Define the connection settings
Expand Down
Empty file.
128 changes: 128 additions & 0 deletions tests/hosting_core/authorization/test_jwk_client_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import asyncio
import threading
import time

import pytest
from jwt import PyJWKClient

from microsoft_agents.hosting.core.authorization.jwt_token_validator import (
_JwkClientManager,
)


async def _wait_until_set(event: threading.Event, timeout: float = 1.0) -> None:
start = time.monotonic()
while not event.is_set():
if time.monotonic() - start > timeout:
raise AssertionError("Timed out waiting for threading event.")
await asyncio.sleep(0.01)


class TestJwkClientManager:
    """Tests for ``_JwkClientManager``: entry caching, delegation to PyJWKClient, and per-URI locking."""

    def test_get_jwk_client_reuses_cache_for_same_uri(self):
        """Looking up the same URI twice returns the identical cache entry."""
        mgr = _JwkClientManager()
        uri = "https://issuer.example.com/keys"

        entry_a = mgr._get_jwk_client(uri)
        entry_b = mgr._get_jwk_client(uri)

        assert entry_a is entry_b
        assert len(mgr._cache) == 1

    def test_get_jwk_client_creates_distinct_entries_for_distinct_uris(self):
        """Different URIs get separate entries, each with its own lock."""
        mgr = _JwkClientManager()

        entry_a = mgr._get_jwk_client("https://issuer-a.example.com/keys")
        entry_b = mgr._get_jwk_client("https://issuer-b.example.com/keys")

        assert entry_a is not entry_b
        assert entry_a.lock is not entry_b.lock
        assert len(mgr._cache) == 2

    @pytest.mark.asyncio
    async def test_get_signing_key_calls_pyjwkclient_with_header_kid(self, monkeypatch):
        """The header's ``kid`` is forwarded and the client's result is returned as-is."""
        mgr = _JwkClientManager()
        uri = "https://issuer.example.com/keys"
        sentinel_key = object()
        recorded_kids = []

        def stub_get_signing_key(self, kid):
            recorded_kids.append(kid)
            return sentinel_key

        # PyJWKClient.get_signing_key is the only member replaced.
        monkeypatch.setattr(PyJWKClient, "get_signing_key", stub_get_signing_key)

        result = await mgr.get_signing_key(uri, {"kid": "kid-123"})

        assert result is sentinel_key
        assert recorded_kids == ["kid-123"]

    @pytest.mark.asyncio
    async def test_get_signing_key_reuses_same_client_for_same_uri(self, monkeypatch):
        """Consecutive lookups against one URI hit the same PyJWKClient instance."""
        mgr = _JwkClientManager()
        uri = "https://issuer.example.com/keys"
        seen_client_ids = []

        def stub_get_signing_key(self, kid):
            seen_client_ids.append(id(self))
            return {"kid": kid}

        # PyJWKClient.get_signing_key is the only member replaced.
        monkeypatch.setattr(PyJWKClient, "get_signing_key", stub_get_signing_key)

        for kid in ("kid-a", "kid-b"):
            await mgr.get_signing_key(uri, {"kid": kid})

        assert seen_client_ids[0] == seen_client_ids[1]
        assert len(mgr._cache) == 1

    @pytest.mark.asyncio
    async def test_get_signing_key_serializes_concurrent_calls_per_uri(
        self, monkeypatch
    ):
        """A second call for the same URI cannot enter the client while the first is inside."""
        mgr = _JwkClientManager()
        uri = "https://issuer.example.com/keys"

        first_entered = threading.Event()
        second_entered = threading.Event()
        release_first = threading.Event()

        def stub_get_signing_key(self, kid):
            if kid == "kid-2":
                second_entered.set()
            elif kid == "kid-1":
                first_entered.set()
                # Park inside the client until the test lets this call finish.
                released = release_first.wait(timeout=2)
                if not released:
                    raise TimeoutError("First call was not released in time.")
            return {"kid": kid}

        # PyJWKClient.get_signing_key is the only member replaced.
        monkeypatch.setattr(PyJWKClient, "get_signing_key", stub_get_signing_key)

        first_task = asyncio.create_task(mgr.get_signing_key(uri, {"kid": "kid-1"}))
        await _wait_until_set(first_entered)

        second_task = asyncio.create_task(mgr.get_signing_key(uri, {"kid": "kid-2"}))

        # Give the second call time to start; with a working per-URI lock it
        # must still be waiting outside get_signing_key.
        await asyncio.sleep(0.05)
        assert not second_entered.is_set()

        release_first.set()
        first_result, second_result = await asyncio.gather(first_task, second_task)

        assert first_result["kid"] == "kid-1"
        assert second_result["kid"] == "kid-2"
        assert second_entered.is_set()

    @pytest.mark.asyncio
    async def test_get_signing_key_raises_key_error_when_header_has_no_kid(self):
        """A header lacking ``kid`` surfaces as ``KeyError`` to the caller."""
        mgr = _JwkClientManager()

        with pytest.raises(KeyError):
            await mgr.get_signing_key("https://issuer.example.com/keys", {})