Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/auto-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.11"
python-version: "3.14"

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:

strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]

name: "Run tests for python ${{ matrix.python-version }}"
steps:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ It pairs well with the Stytch [Web SDK](https://www.npmjs.com/package/@stytch/va

## Requirements

The Stytch Python library supports Python 3.8+
The Stytch Python library supports Python 3.10+

## Installation

Expand Down
4 changes: 2 additions & 2 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
build==1.2.2.post1

# For type checking and testing
black==24.3.0
mypy==0.991
black==26.3.1
mypy==1.19.1
types-requests==2.28.11.5
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@
"Intended Audience :: Developers",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
python_requires=">=3.8",
packages=find_packages(
Expand All @@ -55,7 +54,7 @@
install_requires=[
"aiohttp>=3.8.3",
"requests>=2.7.0",
"pydantic>=1.10.2",
"pydantic>=2.0",
"pyjwt[crypto]>=2.9.0",
],
)
28 changes: 21 additions & 7 deletions stytch/b2b/api/rbac_organizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,31 +179,45 @@ def validate_org_policy(project_policy: B2BPolicy, org_policy: OrgPolicy) -> Non
for role in org_policy.roles:
org_role_id = role.role_id
if org_role_id in org_roles:
raise Exception(f"Duplicate role {org_role_id} in Organization RBAC policy")
raise Exception(
f"Duplicate role {org_role_id} in Organization RBAC policy"
)
org_roles.add(org_role_id)

if org_role_id in project_roles:
raise Exception(f"Role {org_role_id} already defined in Project RBAC policy")
raise Exception(
f"Role {org_role_id} already defined in Project RBAC policy"
)

for permission in role.permissions:
resource_id = permission.resource_id
if not resource_id in project_resources:
raise Exception(f"Resource {resource_id} not defined in Project RBAC policy")
raise Exception(
f"Resource {resource_id} not defined in Project RBAC policy"
)

if len(permission.actions) == 0:
raise Exception(f"No actions defined for role {org_role_id}, resource {resource_id}")
raise Exception(
f"No actions defined for role {org_role_id}, resource {resource_id}"
)
if len(permission.actions) == 1 and "*" == permission.actions[0]:
continue
if len(permission.actions) > 1 and "*" in permission.actions:
raise Exception("Wildcard actions must be the only action defined for a role and resource")
raise Exception(
"Wildcard actions must be the only action defined for a role and resource"
)

project_resource = project_resources[resource_id]
for action in permission.actions:
if action.strip() == "":
raise Exception(f"Empty action on resource {resource_id} is not permitted")
raise Exception(
f"Empty action on resource {resource_id} is not permitted"
)

if not action in project_resource.actions:
raise Exception(f"Unknown action {action} defined on resource {resource_id}")
raise Exception(
f"Unknown action {action} defined on resource {resource_id}"
)

return

Expand Down
58 changes: 57 additions & 1 deletion stytch/b2b/api/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,62 @@ def _authenticate_jwt_local_common(
roles_claim=roles_claim,
)

async def _authenticate_jwt_local_common_async(
self,
session_jwt: str,
max_token_age_seconds: Optional[int] = None,
leeway: int = 0,
) -> Optional[LocalJWTResponse]:
_session_claim = "https://stytch.com/session"
_organization_claim = "https://stytch.com/organization"
generic_claims = await jwt_helpers.authenticate_jwt_local_async(
project_id=self.project_id,
jwks_client=self.jwks_client,
jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
leeway=leeway,
base_url=self.api_base.base_url,
)
if generic_claims is None:
return None

claim = generic_claims.untyped_claims[_session_claim]
custom_claims = {
k: v
for k, v in generic_claims.untyped_claims.items()
if k not in [_session_claim, _organization_claim]
}

# For JWTs that include it, prefer the inner expires_at claim.
expires_at = claim.get("expires_at", generic_claims.reserved_claims["exp"])

# Claim related to unpacking organization-specific fields
org_claim = generic_claims.untyped_claims[_organization_claim]

# Claim related to RBAC roles
roles_claim = claim.get("roles")
if roles_claim is not None:
if not isinstance(roles_claim, list) or not all(
isinstance(x, str) for x in roles_claim
):
raise ValueError("Invalid roles claim. Expected a list of strings.")

return LocalJWTResponse(
member_session=MemberSession(
authentication_factors=claim["authentication_factors"],
expires_at=expires_at,
last_accessed_at=claim["last_accessed_at"],
member_session_id=claim["id"],
started_at=claim["started_at"],
organization_id=org_claim["organization_id"],
member_id=generic_claims.reserved_claims["sub"],
custom_claims=custom_claims,
roles=roles_claim or [],
organization_slug=org_claim["slug"],
),
roles_claim=roles_claim,
)

def authenticate_jwt_local(
self,
session_jwt: str,
Expand Down Expand Up @@ -1008,7 +1064,7 @@ async def authenticate_jwt_local_async(
leeway: int = 0,
authorization_check: Optional[AuthorizationCheck] = None,
) -> Optional[MemberSession]:
local_resp = self._authenticate_jwt_local_common(
local_resp = await self._authenticate_jwt_local_common_async(
session_jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
leeway=leeway,
Expand Down
Empty file.
136 changes: 136 additions & 0 deletions stytch/b2b/api/test/test_sessions_jwt_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#!/usr/bin/env python3

import asyncio
import time
import unittest
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch

from stytch.b2b.api.sessions import Sessions
from stytch.b2b.models.sessions import AuthorizationCheck
from stytch.shared.jwt_helpers import GenericClaims

FAKE_JWT = "fake.jwt.token"
FAKE_PROJECT_ID = "project-test-abc123"
_SESSION_CLAIM = "https://stytch.com/session"
_ORG_CLAIM = "https://stytch.com/organization"

FAKE_GENERIC_CLAIMS = GenericClaims(
reserved_claims={"sub": "member-test-123", "exp": 9999999999},
untyped_claims={
_SESSION_CLAIM: {
"id": "session-test-123",
"authentication_factors": [],
"last_accessed_at": "2026-01-01T00:00:00Z",
"started_at": "2026-01-01T00:00:00Z",
"expires_at": "2026-01-01T01:00:00Z",
"roles": ["stytch_member"],
},
_ORG_CLAIM: {
"organization_id": "org-test-123",
"slug": "test-org",
},
},
)


def _make_sessions() -> Sessions:
mock_api_base = MagicMock()
mock_api_base.base_url = "https://test.stytch.com/"
return Sessions(
api_base=mock_api_base,
sync_client=MagicMock(),
async_client=MagicMock(),
jwks_client=MagicMock(),
project_id=FAKE_PROJECT_ID,
policy_cache=MagicMock(),
)


class TestB2BAuthenticateJWTLocalAsync(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.sessions = _make_sessions()
self.auth_check = AuthorizationCheck(
organization_id="org-test-123",
resource_id="documents",
action="read",
)

def test_is_coroutine_function(self) -> None:
self.assertTrue(
asyncio.iscoroutinefunction(self.sessions.authenticate_jwt_local_async)
)

@patch(
"stytch.b2b.api.sessions.jwt_helpers.authenticate_jwt_local_async",
new_callable=AsyncMock,
)
async def test_returns_none_for_invalid_jwt(self, mock_jwt) -> None:
mock_jwt.return_value = None
result = await self.sessions.authenticate_jwt_local_async(session_jwt=FAKE_JWT)
self.assertIsNone(result)

@patch(
"stytch.b2b.api.sessions.jwt_helpers.authenticate_jwt_local_async",
new_callable=AsyncMock,
)
async def test_returns_member_session_for_valid_jwt_without_auth_check(
self, mock_jwt
) -> None:
mock_jwt.return_value = FAKE_GENERIC_CLAIMS
result = await self.sessions.authenticate_jwt_local_async(session_jwt=FAKE_JWT)
self.assertIsNotNone(result)
if result is not None:
self.assertEqual(result.member_session_id, "session-test-123")
self.assertEqual(result.member_id, "member-test-123")
self.assertEqual(result.organization_id, "org-test-123")

@patch("stytch.b2b.api.sessions.rbac_local.perform_authorization_check")
@patch(
"stytch.b2b.api.sessions.jwt_helpers.authenticate_jwt_local_async",
new_callable=AsyncMock,
)
async def test_uses_get_with_org_async_not_sync_for_authorization_check(
self, mock_jwt, _mock_rbac
) -> None:
mock_jwt.return_value = FAKE_GENERIC_CLAIMS
mock_policy = MagicMock()
policy_cache = cast(MagicMock, self.sessions.policy_cache)
policy_cache.get_with_org_async = AsyncMock(return_value=mock_policy)

await self.sessions.authenticate_jwt_local_async(
session_jwt=FAKE_JWT, authorization_check=self.auth_check
)

policy_cache.get_with_org_async.assert_awaited_once_with("org-test-123")
policy_cache.get_with_org.assert_not_called()

async def test_is_non_blocking_jwt_verification(self) -> None:
DELAY = 0.1
N = 5

async def slow_authenticate_jwt_local_async(**kwargs) -> GenericClaims:
await asyncio.sleep(DELAY)
return FAKE_GENERIC_CLAIMS

with patch(
"stytch.b2b.api.sessions.jwt_helpers.authenticate_jwt_local_async",
side_effect=slow_authenticate_jwt_local_async,
):
start = time.monotonic()
results = await asyncio.gather(
*[
self.sessions.authenticate_jwt_local_async(session_jwt=FAKE_JWT)
for _ in range(N)
]
)
elapsed = time.monotonic() - start

# All N calls should interleave at the await point, completing in ~DELAY total
# (not N * DELAY as would happen if get_signing_key_from_jwt blocked the event loop)
self.assertLess(elapsed, DELAY * 2)
self.assertEqual(len(results), N)


if __name__ == "__main__":
unittest.main()
55 changes: 52 additions & 3 deletions stytch/consumer/api/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,16 +622,16 @@ async def authenticate_jwt_async(
zero or use the authenticate method instead.
"""
# Return the local_result if available, otherwise call the Stytch API
local_token = self.authenticate_jwt_local(
local_session = await self.authenticate_jwt_local_async(
session_jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
authorization_check=authorization_check,
)
if local_token is not None:
if local_session is not None:
return AuthenticateJWTLocalResponse.from_json(
status_code=200,
json={
"session": local_token,
"session": local_session,
"session_jwt": session_jwt,
"status_code": 200,
"request_id": "",
Expand Down Expand Up @@ -707,4 +707,53 @@ def authenticate_jwt_local(
roles=claim["roles"],
)

async def authenticate_jwt_local_async(
self,
session_jwt: str,
max_token_age_seconds: Optional[int] = None,
leeway: int = 0,
authorization_check: Optional[AuthorizationCheck] = None,
) -> Optional[Session]:
_session_claim = "https://stytch.com/session"
generic_claims = await jwt_helpers.authenticate_jwt_local_async(
project_id=self.project_id,
jwks_client=self.jwks_client,
jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
leeway=leeway,
base_url=self.api_base.base_url,
)
if generic_claims is None:
return None

claim = generic_claims.untyped_claims[_session_claim]
custom_claims = {
k: v
for k, v in generic_claims.untyped_claims.items()
if k != _session_claim
}

# For JWTs that include it, prefer the inner expires_at claim.
expires_at = claim.get("expires_at", generic_claims.reserved_claims["exp"])

if authorization_check is not None:
_session_claim = "https://stytch.com/session"
rbac_local.perform_consumer_authorization_check(
policy=await self.policy_cache.get_async(),
subject_roles=claim["roles"],
authorization_check=authorization_check,
)

return Session(
attributes=claim["attributes"],
authentication_factors=claim["authentication_factors"],
expires_at=expires_at,
last_accessed_at=claim["last_accessed_at"],
session_id=claim["id"],
started_at=claim["started_at"],
user_id=generic_claims.reserved_claims["sub"],
custom_claims=custom_claims,
roles=claim["roles"],
)

# ENDMANUAL(authenticate_jwt_local)
Loading
Loading