Skip to content

Commit 53c03b3

Browse files
Merge pull request #10 from bitovi/pr-feedback
PR feedback
2 parents 55f18b2 + daa374b commit 53c03b3

11 files changed

Lines changed: 1565 additions & 1164 deletions

File tree

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
.venv
22
__pycache__
3-
_certs

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Some examples require extra dependencies. See each sample's directory for specif
6060
* [custom_decorator](custom_decorator) - Custom decorator to auto-heartbeat a long-running activity.
6161
* [dsl](dsl) - DSL workflow that executes steps defined in a YAML file.
6262
* [encryption](encryption) - Apply end-to-end encryption for all input/output.
63+
* [encryption_jwt](encryption_jwt) - Apply end-to-end encryption for all input/output using a KMS and per-namespace JWT-based auth.
6364
* [gevent_async](gevent_async) - Combine gevent and Temporal.
6465
* [langchain](langchain) - Orchestrate workflows for LangChain.
6566
* [open_telemetry](open_telemetry) - Trace workflows with OpenTelemetry.

encryption_jwt/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
_certs

encryption_jwt/README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ The Codec Server uses the [Operations API](https://docs.temporal.io/ops) to get
1111

1212
## Install
1313

14-
For this sample, the optional `encryption` and `bedrock` dependency groups must be included. To include, run:
14+
For this sample, the optional `encryption_jwt` and `bedrock` dependency groups must be included. To include, run:
1515

1616
```sh
17-
poetry install --with encryption,bedrock
17+
poetry install --with encryption_jwt,bedrock
1818
```
1919

2020
## Setup
@@ -31,17 +31,17 @@ Alternately replace the key management portion with your own implementation.
3131
### Self-signed certificates
3232

3333
The codec server will need to use HTTPS, self-signed certificates will work in the development
34-
environment. Run the following command in a `_certs` directory that's a subdirectory of the
35-
repository root, it will create certificate files that are good for 10 years.
34+
environment. Run the following command in a `_certs` directory that's a subdirectory of this one.
35+
It will create certificate files that are good for 10 years.
3636

3737
```sh
3838
openssl req -x509 -newkey rsa:4096 -sha256 -days 3650 -nodes -keyout localhost.key -out localhost.pem -subj "/CN=localhost"
3939
```
4040

4141
In the projects you can access the files using the following relative paths.
4242

43-
- `../_certs/localhost.pem`
44-
- `../_certs/localhost.key`
43+
- `./_certs/localhost.pem`
44+
- `./_certs/localhost.key`
4545

4646
## Run
4747

encryption_jwt/codec.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
from typing import Iterable, List
2+
23
from temporalio.api.common.v1 import Payload
34
from temporalio.converter import PayloadCodec
5+
46
from encryption_jwt.encryptor import KMSEncryptor
57

68

79
class EncryptionCodec(PayloadCodec):
8-
910
def __init__(self, namespace: str):
1011
self._encryptor = KMSEncryptor(namespace)
1112

1213
async def encode(self, payloads: Iterable[Payload]) -> List[Payload]:
1314
# We blindly encode all payloads with the key and set the metadata with the key that was
1415
# used (base64 encoded).
1516

16-
def encrypt_payload(p: Payload):
17-
data, key = self._encryptor.encrypt(p.SerializeToString())
17+
async def encrypt_payload(p: Payload):
18+
data, key = await self._encryptor.encrypt(p.SerializeToString())
1819
return Payload(
1920
metadata={
2021
"encoding": b"binary/encrypted",
@@ -23,12 +24,14 @@ def encrypt_payload(p: Payload):
2324
data=data,
2425
)
2526

26-
return list(map(encrypt_payload, payloads))
27+
# return list(map(encrypt_payload, payloads))
28+
return [await encrypt_payload(payload) for payload in payloads]
2729

2830
async def decode(self, payloads: Iterable[Payload]) -> List[Payload]:
29-
def decrypt_payload(p: Payload):
31+
async def decrypt_payload(p: Payload):
3032
data_key_encrypted_base64 = p.metadata.get("data_key_encrypted", b"")
31-
data = self._encryptor.decrypt(data_key_encrypted_base64, p.data)
33+
data = await self._encryptor.decrypt(data_key_encrypted_base64, p.data)
3234
return Payload.FromString(data)
3335

34-
return list(map(decrypt_payload, payloads))
36+
# return list(map(decrypt_payload, payloads))
37+
return [await decrypt_payload(payload) for payload in payloads]

encryption_jwt/codec_server.py

Lines changed: 89 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
1+
import logging
12
import os
23
import ssl
3-
import logging
4+
45
import jwt
5-
import grpc
6+
import requests
67
from aiohttp import hdrs, web
7-
8-
from temporalio.api.common.v1 import Payload, Payloads
9-
from temporalio.api.cloud.cloudservice.v1 import request_response_pb2, service_pb2_grpc
108
from google.protobuf import json_format
9+
from jwt.algorithms import RSAAlgorithm
10+
from temporalio.api.cloud.cloudservice.v1 import GetUsersRequest
11+
from temporalio.api.common.v1 import Payloads
12+
from temporalio.client import CloudOperationsClient
13+
1114
from encryption_jwt.codec import EncryptionCodec
1215

13-
AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["admin"]
16+
AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["owner", "admin"]
1417
AUTHORIZED_NAMESPACE_ACCESS_ROLES = ["read", "write", "admin"]
1518

19+
TEMPORAL_CLIENT_CLOUD_API_VERSION = "2024-05-13-00"
20+
1621
temporal_ops_address = "saas-api.tmprl.cloud:443"
1722
if os.environ.get("TEMPORAL_OPS_ADDRESS"):
1823
temporal_ops_address = os.environ.get("TEMPORAL_OPS_ADDRESS")
@@ -42,52 +47,90 @@ async def cors_options(req: web.Request) -> web.Response:
4247

4348
return resp
4449

45-
def decryption_authorized(email: str, namespace: str) -> bool:
46-
credentials = grpc.composite_channel_credentials(grpc.ssl_channel_credentials(
47-
), grpc.access_token_call_credentials(os.environ.get("TEMPORAL_API_KEY")))
48-
49-
with grpc.secure_channel(temporal_ops_address, credentials) as channel:
50-
client = service_pb2_grpc.CloudServiceStub(channel)
51-
request = request_response_pb2.GetUsersRequest()
52-
53-
response = client.GetUsers(request, metadata=(
54-
("temporal-cloud-api-version", os.environ.get("TEMPORAL_OPS_API_VERSION")),))
55-
56-
authorized = False
57-
for user in response.users:
58-
if user.spec.email.lower() == email.lower():
59-
if user.spec.access.account_access.role in AUTHORIZED_ACCOUNT_ACCESS_ROLES:
60-
authorized = True
61-
else:
62-
if namespace in user.spec.access.namespace_accesses:
63-
if user.spec.access.namespace_accesses[namespace].permission in AUTHORIZED_NAMESPACE_ACCESS_ROLES:
64-
authorized = True
65-
66-
return authorized
50+
async def decryption_authorized(email: str, namespace: str) -> bool:
51+
client = await CloudOperationsClient.connect(
52+
api_key=os.environ.get("TEMPORAL_API_KEY"),
53+
version=TEMPORAL_CLIENT_CLOUD_API_VERSION,
54+
)
55+
56+
response = await client.cloud_service.get_users(
57+
GetUsersRequest(namespace=namespace)
58+
)
59+
60+
for user in response.users:
61+
if user.spec.email.lower() == email.lower():
62+
if (
63+
user.spec.access.account_access.role
64+
in AUTHORIZED_ACCOUNT_ACCESS_ROLES
65+
):
66+
return True
67+
else:
68+
if namespace in user.spec.access.namespace_accesses:
69+
if (
70+
user.spec.access.namespace_accesses[namespace].permission
71+
in AUTHORIZED_NAMESPACE_ACCESS_ROLES
72+
):
73+
return True
74+
75+
return False
6776

6877
def make_handler(fn: str):
6978
async def handler(req: web.Request):
70-
# Read payloads as JSON
71-
assert req.content_type == "application/json"
72-
payloads = json_format.Parse(await req.read(), Payloads())
73-
74-
# Extract the email from the JWT.
75-
auth_header = req.headers.get("Authorization")
7679
namespace = req.headers.get("x-namespace")
80+
auth_header = req.headers.get("Authorization")
7781
_bearer, encoded = auth_header.split(" ")
78-
decoded = jwt.decode(encoded, options={"verify_signature": False})
7982

80-
# Use the email to determine if the payload should be decrypted.
81-
authorized = decryption_authorized(decoded["https://saas-api.tmprl.cloud/user/email"], namespace)
83+
# Extract the kid from the Auth header
84+
jwt_dict = jwt.get_unverified_header(encoded)
85+
kid = jwt_dict["kid"]
86+
algorithm = jwt_dict["alg"]
87+
88+
# Fetch Temporal Cloud JWKS
89+
jwks_url = "https://login.tmprl.cloud/.well-known/jwks.json"
90+
jwks = requests.get(jwks_url).json()
91+
92+
# Extract Temporal Cloud's public key
93+
public_key = None
94+
for key in jwks["keys"]:
95+
if key["kid"] == kid:
96+
# Convert JWKS key to PEM format
97+
public_key = RSAAlgorithm.from_jwk(key)
98+
break
99+
100+
if public_key is None:
101+
raise ValueError("Public key not found in JWKS")
102+
103+
# Decode the jwt, verifying against Temporal Cloud's public key
104+
decoded = jwt.decode(
105+
encoded,
106+
public_key,
107+
algorithms=[algorithm],
108+
audience=[
109+
"https://saas-api.tmprl.cloud",
110+
"https://prod-tmprl.us.auth0.com/userinfo",
111+
],
112+
)
113+
114+
# Use the email to determine if the user is authorized to decrypt the payload
115+
authorized = await decryption_authorized(
116+
decoded["https://saas-api.tmprl.cloud/user/email"], namespace
117+
)
118+
82119
if authorized:
120+
# Read payloads as JSON
121+
assert req.content_type == "application/json"
122+
payloads = json_format.Parse(await req.read(), Payloads())
83123
encryptionCodec = EncryptionCodec(namespace)
84-
payloads = Payloads(payloads=await getattr(encryptionCodec, fn)(payloads.payloads))
124+
payloads = Payloads(
125+
payloads=await getattr(encryptionCodec, fn)(payloads.payloads)
126+
)
85127

86128
# Apply CORS and return JSON
87129
resp = await cors_options(req)
88130
resp.content_type = "application/json"
89131
resp.text = json_format.MessageToJson(payloads)
90132
return resp
133+
91134
return handler
92135

93136
# Build app
@@ -97,8 +140,8 @@ async def handler(req: web.Request):
97140
logger = logging.getLogger(__name__)
98141
app.add_routes(
99142
[
100-
web.post("/encode", make_handler('encode')),
101-
web.post("/decode", make_handler('decode')),
143+
web.post("/encode", make_handler("encode")),
144+
web.post("/decode", make_handler("decode")),
102145
web.options("/decode", cors_options),
103146
]
104147
)
@@ -112,8 +155,10 @@ async def handler(req: web.Request):
112155
if os.environ.get("SSL_PEM") and os.environ.get("SSL_KEY"):
113156
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
114157
ssl_context.check_hostname = False
115-
ssl_context.load_cert_chain(os.environ.get(
116-
"SSL_PEM"), os.environ.get("SSL_KEY"))
158+
ssl_context.load_cert_chain(
159+
os.environ.get("SSL_PEM"), os.environ.get("SSL_KEY")
160+
)
117161

118-
web.run_app(build_codec_server(), host="0.0.0.0",
119-
port=8081, ssl_context=ssl_context)
162+
web.run_app(
163+
build_codec_server(), host="0.0.0.0", port=8081, ssl_context=ssl_context
164+
)

encryption_jwt/encryptor.py

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,37 @@
1-
import os
21
import base64
32
import logging
4-
from temporalio import workflow
5-
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
3+
import os
4+
65
from botocore.exceptions import ClientError
6+
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
7+
from temporalio import workflow
78

89
with workflow.unsafe.imports_passed_through():
9-
import boto3
10+
import aioboto3
1011

1112

1213
class KMSEncryptor:
1314
"""Encrypts and decrypts using keys from AWS KMS."""
1415

1516
def __init__(self, namespace: str):
1617
self._namespace = namespace
17-
self._kms_client = None
18+
self._boto_session = None
1819

1920
@property
20-
def kms_client(self):
21+
def boto_session(self):
2122
"""Get a KMS client from boto3."""
22-
if not self._kms_client:
23-
self._kms_client = boto3.client("kms")
23+
if not self._boto_session:
24+
session = aioboto3.Session()
25+
self._boto_session = session
2426

25-
return self._kms_client
27+
return self._boto_session
2628

27-
def encrypt(self, data: bytes) -> tuple[bytes, bytes]:
29+
async def encrypt(self, data: bytes) -> tuple[bytes, bytes]:
2830
"""Encrypt data using a key from KMS."""
2931
# The keys are rotated automatically by KMS, so fetch a new key to encrypt the data.
30-
data_key_encrypted, data_key_plaintext = self.__create_data_key(self._namespace)
32+
data_key_encrypted, data_key_plaintext = await self.__create_data_key(
33+
self._namespace
34+
)
3135

3236
if data_key_encrypted is None:
3337
raise ValueError("No data key!")
@@ -38,38 +42,42 @@ def encrypt(self, data: bytes) -> tuple[bytes, bytes]:
3842
data_key_encrypted
3943
)
4044

41-
def decrypt(self, data_key_encrypted_base64, data: bytes) -> bytes:
45+
async def decrypt(self, data_key_encrypted_base64, data: bytes) -> bytes:
4246
"""Encrypt data using a key from KMS."""
4347
data_key_encrypted = base64.b64decode(data_key_encrypted_base64)
44-
data_key_plaintext = self.__decrypt_data_key(data_key_encrypted)
48+
data_key_plaintext = await self.__decrypt_data_key(data_key_encrypted)
4549
encryptor = AESGCM(data_key_plaintext)
4650
return encryptor.decrypt(data[:12], data[12:], None)
4751

48-
def __create_data_key(self, namespace: str):
52+
async def __create_data_key(self, namespace: str):
4953
"""Get a set of keys from AWS KMS that can be used to encrypt data."""
5054

5155
# Create data key
52-
alias_name = 'alias/' + namespace.replace('.', '_')
53-
response = self.kms_client.describe_key(KeyId=alias_name)
54-
cmk_id = response['KeyMetadata']['Arn']
55-
key_spec = "AES_256"
56-
try:
57-
response = self.kms_client.generate_data_key(KeyId=cmk_id, KeySpec=key_spec)
58-
except ClientError as e:
59-
logging.error(e)
60-
return None, None
61-
62-
# Return the encrypted and plaintext data key
63-
return response["CiphertextBlob"], response["Plaintext"]
64-
65-
def __decrypt_data_key(self, data_key_encrypted):
56+
alias_name = "alias/" + namespace.replace(".", "_")
57+
async with self.boto_session.client("kms") as kms_client:
58+
response = await kms_client.describe_key(KeyId=alias_name)
59+
cmk_id = response["KeyMetadata"]["Arn"]
60+
key_spec = "AES_256"
61+
try:
62+
response = await kms_client.generate_data_key(
63+
KeyId=cmk_id, KeySpec=key_spec
64+
)
65+
except ClientError as e:
66+
logging.error(e)
67+
return None, None
68+
69+
# Return the encrypted and plaintext data key
70+
return response["CiphertextBlob"], response["Plaintext"]
71+
72+
async def __decrypt_data_key(self, data_key_encrypted):
6673
"""Use AWS KMS to exchange an encrypted key for its plaintext value."""
6774

68-
# Decrypt the data key
69-
try:
70-
response = self.kms_client.decrypt(CiphertextBlob=data_key_encrypted)
71-
except ClientError as e:
72-
logging.error(e)
73-
return None
75+
async with self.boto_session.client("kms") as kms_client:
76+
# Decrypt the data key
77+
try:
78+
response = await kms_client.decrypt(CiphertextBlob=data_key_encrypted)
79+
except ClientError as e:
80+
logging.error(e)
81+
return None
7482

75-
return response["Plaintext"]
83+
return response["Plaintext"]

0 commit comments

Comments
 (0)