1+ import logging
12import os
23import ssl
3- import logging
4+
45import jwt
5- import grpc
6+ import requests
67from aiohttp import hdrs , web
7-
8- from temporalio .api .common .v1 import Payload , Payloads
9- from temporalio .api .cloud .cloudservice .v1 import request_response_pb2 , service_pb2_grpc
108from google .protobuf import json_format
9+ from jwt .algorithms import RSAAlgorithm
10+ from temporalio .api .cloud .cloudservice .v1 import GetUsersRequest
11+ from temporalio .api .common .v1 import Payloads
12+ from temporalio .client import CloudOperationsClient
13+
1114from encryption_jwt .codec import EncryptionCodec
1215
13- AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["admin" ]
16+ AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["owner" , " admin" ]
1417AUTHORIZED_NAMESPACE_ACCESS_ROLES = ["read" , "write" , "admin" ]
1518
19+ TEMPORAL_CLIENT_CLOUD_API_VERSION = "2024-05-13-00"
20+
1621temporal_ops_address = "saas-api.tmprl.cloud:443"
1722if os .environ .get ("TEMPORAL_OPS_ADDRESS" ):
1823 temporal_ops_address = os .environ .get ("TEMPORAL_OPS_ADDRESS" )
@@ -42,52 +47,90 @@ async def cors_options(req: web.Request) -> web.Response:
4247
4348 return resp
4449
45- def decryption_authorized (email : str , namespace : str ) -> bool :
46- credentials = grpc .composite_channel_credentials (grpc .ssl_channel_credentials (
47- ), grpc .access_token_call_credentials (os .environ .get ("TEMPORAL_API_KEY" )))
48-
49- with grpc .secure_channel (temporal_ops_address , credentials ) as channel :
50- client = service_pb2_grpc .CloudServiceStub (channel )
51- request = request_response_pb2 .GetUsersRequest ()
52-
53- response = client .GetUsers (request , metadata = (
54- ("temporal-cloud-api-version" , os .environ .get ("TEMPORAL_OPS_API_VERSION" )),))
55-
56- authorized = False
57- for user in response .users :
58- if user .spec .email .lower () == email .lower ():
59- if user .spec .access .account_access .role in AUTHORIZED_ACCOUNT_ACCESS_ROLES :
60- authorized = True
61- else :
62- if namespace in user .spec .access .namespace_accesses :
63- if user .spec .access .namespace_accesses [namespace ].permission in AUTHORIZED_NAMESPACE_ACCESS_ROLES :
64- authorized = True
65-
66- return authorized
50+ async def decryption_authorized (email : str , namespace : str ) -> bool :
51+ client = await CloudOperationsClient .connect (
52+ api_key = os .environ .get ("TEMPORAL_API_KEY" ),
53+ version = TEMPORAL_CLIENT_CLOUD_API_VERSION ,
54+ )
55+
56+ response = await client .cloud_service .get_users (
57+ GetUsersRequest (namespace = namespace )
58+ )
59+
60+ for user in response .users :
61+ if user .spec .email .lower () == email .lower ():
62+ if (
63+ user .spec .access .account_access .role
64+ in AUTHORIZED_ACCOUNT_ACCESS_ROLES
65+ ):
66+ return True
67+ else :
68+ if namespace in user .spec .access .namespace_accesses :
69+ if (
70+ user .spec .access .namespace_accesses [namespace ].permission
71+ in AUTHORIZED_NAMESPACE_ACCESS_ROLES
72+ ):
73+ return True
74+
75+ return False
6776
6877 def make_handler (fn : str ):
6978 async def handler (req : web .Request ):
70- # Read payloads as JSON
71- assert req .content_type == "application/json"
72- payloads = json_format .Parse (await req .read (), Payloads ())
73-
74- # Extract the email from the JWT.
75- auth_header = req .headers .get ("Authorization" )
7679 namespace = req .headers .get ("x-namespace" )
80+ auth_header = req .headers .get ("Authorization" )
7781 _bearer , encoded = auth_header .split (" " )
78- decoded = jwt .decode (encoded , options = {"verify_signature" : False })
7982
80- # Use the email to determine if the payload should be decrypted.
81- authorized = decryption_authorized (decoded ["https://saas-api.tmprl.cloud/user/email" ], namespace )
83+ # Extract the kid from the Auth header
84+ jwt_dict = jwt .get_unverified_header (encoded )
85+ kid = jwt_dict ["kid" ]
86+ algorithm = jwt_dict ["alg" ]
87+
88+ # Fetch Temporal Cloud JWKS
89+ jwks_url = "https://login.tmprl.cloud/.well-known/jwks.json"
90+ jwks = requests .get (jwks_url ).json ()
91+
92+ # Extract Temporal Cloud's public key
93+ public_key = None
94+ for key in jwks ["keys" ]:
95+ if key ["kid" ] == kid :
96+ # Convert JWKS key to PEM format
97+ public_key = RSAAlgorithm .from_jwk (key )
98+ break
99+
100+ if public_key is None :
101+ raise ValueError ("Public key not found in JWKS" )
102+
103+ # Decode the jwt, verifying against Temporal Cloud's public key
104+ decoded = jwt .decode (
105+ encoded ,
106+ public_key ,
107+ algorithms = [algorithm ],
108+ audience = [
109+ "https://saas-api.tmprl.cloud" ,
110+ "https://prod-tmprl.us.auth0.com/userinfo" ,
111+ ],
112+ )
113+
114+ # Use the email to determine if the user is authorized to decrypt the payload
115+ authorized = await decryption_authorized (
116+ decoded ["https://saas-api.tmprl.cloud/user/email" ], namespace
117+ )
118+
82119 if authorized :
120+ # Read payloads as JSON
121+ assert req .content_type == "application/json"
122+ payloads = json_format .Parse (await req .read (), Payloads ())
83123 encryptionCodec = EncryptionCodec (namespace )
84- payloads = Payloads (payloads = await getattr (encryptionCodec , fn )(payloads .payloads ))
124+ payloads = Payloads (
125+ payloads = await getattr (encryptionCodec , fn )(payloads .payloads )
126+ )
85127
86128 # Apply CORS and return JSON
87129 resp = await cors_options (req )
88130 resp .content_type = "application/json"
89131 resp .text = json_format .MessageToJson (payloads )
90132 return resp
133+
91134 return handler
92135
93136 # Build app
@@ -97,8 +140,8 @@ async def handler(req: web.Request):
97140 logger = logging .getLogger (__name__ )
98141 app .add_routes (
99142 [
100- web .post ("/encode" , make_handler (' encode' )),
101- web .post ("/decode" , make_handler (' decode' )),
143+ web .post ("/encode" , make_handler (" encode" )),
144+ web .post ("/decode" , make_handler (" decode" )),
102145 web .options ("/decode" , cors_options ),
103146 ]
104147 )
@@ -112,8 +155,10 @@ async def handler(req: web.Request):
112155 if os .environ .get ("SSL_PEM" ) and os .environ .get ("SSL_KEY" ):
113156 ssl_context = ssl .create_default_context (ssl .Purpose .CLIENT_AUTH )
114157 ssl_context .check_hostname = False
115- ssl_context .load_cert_chain (os .environ .get (
116- "SSL_PEM" ), os .environ .get ("SSL_KEY" ))
158+ ssl_context .load_cert_chain (
159+ os .environ .get ("SSL_PEM" ), os .environ .get ("SSL_KEY" )
160+ )
117161
118- web .run_app (build_codec_server (), host = "0.0.0.0" ,
119- port = 8081 , ssl_context = ssl_context )
162+ web .run_app (
163+ build_codec_server (), host = "0.0.0.0" , port = 8081 , ssl_context = ssl_context
164+ )
0 commit comments