Skip to content

Commit aa61306

Browse files
committed
Merge branch 'master' of https://github.com/rohe/pyjwkest
2 parents 640e9b3 + a5549af commit aa61306

File tree

5 files changed

+102
-14
lines changed

5 files changed

+102
-14
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
requests
2-
pycrypto >= 2.2
2+
pycryptodomex>=3.4.2

src/jwkest/jwe.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -489,15 +489,20 @@ def encrypt(self, key, iv="", cek="", **kwargs):
489489
else:
490490
raise ParameterError("Zip has unknown value: %s" % self["zip"])
491491

492+
kwarg_cek = cek or None
493+
492494
_enc = self["enc"]
493495
cek, iv = self._generate_key_and_iv(_enc, cek, iv)
496+
self["cek"] = cek
494497

495498
logger.debug("cek: %s, iv: %s" % ([c for c in cek], [c for c in iv]))
496499

497500
_encrypt = RSAEncrypter(self.with_digest).encrypt
498501

499502
_alg = self["alg"]
500-
if _alg == "RSA-OAEP":
503+
if kwarg_cek:
504+
jwe_enc_key = ''
505+
elif _alg == "RSA-OAEP":
501506
jwe_enc_key = _encrypt(cek, key, 'pkcs1_oaep_padding')
502507
elif _alg == "RSA1_5":
503508
jwe_enc_key = _encrypt(cek, key)
@@ -511,7 +516,7 @@ def encrypt(self, key, iv="", cek="", **kwargs):
511516
ctxt, tag, key = self.enc_setup(_enc, _msg, enc_header, cek, iv)
512517
return jwe.pack(parts=[jwe_enc_key, iv, ctxt, tag])
513518

514-
def decrypt(self, token, key):
519+
def decrypt(self, token, key, cek=None):
515520
""" Decrypts a JWT
516521
517522
:param token: The JWT
@@ -529,13 +534,16 @@ def decrypt(self, token, key):
529534
_decrypt = RSAEncrypter(self.with_digest).decrypt
530535

531536
_alg = jwe.headers["alg"]
532-
if _alg == "RSA-OAEP":
537+
if cek:
538+
pass
539+
elif _alg == "RSA-OAEP":
533540
cek = _decrypt(jek, key, 'pkcs1_oaep_padding')
534541
elif _alg == "RSA1_5":
535542
cek = _decrypt(jek, key)
536543
else:
537544
raise NotSupportedAlgorithm(_alg)
538545

546+
self["cek"] = cek
539547
enc = jwe.headers["enc"]
540548
try:
541549
assert enc in SUPPORTED["enc"]
@@ -687,7 +695,7 @@ def encrypt(self, key, iv="", cek="", **kwargs):
687695
return jwe.pack(parts=[kwargs['encrypted_key'], iv, ctxt, tag])
688696
return jwe.pack(parts=[iv, ctxt, tag])
689697

690-
def decrypt(self, token=None, key=None):
698+
def decrypt(self, token=None, key=None, **kwargs):
691699

692700
if not self.cek:
693701
raise Exception("Content Encryption Key is Not Yet Set")
@@ -747,7 +755,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):
747755
:return: Encrypted message
748756
"""
749757

750-
encrypted_key = cek = iv = None
758+
# encrypted_key = cek = iv = None
751759
_alg = self["alg"]
752760

753761
# Find Usable Keys
@@ -801,6 +809,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):
801809

802810
try:
803811
token = encrypter.encrypt(_key, **kwargs)
812+
self["cek"] = encrypter.cek if 'cek' in encrypter else None
804813
except TypeError as err:
805814
raise err
806815
else:
@@ -811,7 +820,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):
811820
logger.error("Could not find any suitable encryption key")
812821
raise NoSuitableEncryptionKey()
813822

814-
def decrypt(self, token=None, keys=None, alg=None):
823+
def decrypt(self, token=None, keys=None, alg=None, cek=None):
815824
if token:
816825
jwe = JWEnc().unpack(token)
817826
# header, ek, eiv, ctxt, tag = token.split(b".")
@@ -829,7 +838,7 @@ def decrypt(self, token=None, keys=None, alg=None):
829838
else:
830839
keys = self._pick_keys(self._get_keys(), use="enc", alg=_alg)
831840

832-
if not keys:
841+
if not keys and not cek:
833842
raise NoSuitableDecryptionKey(_alg)
834843

835844
if _alg in ["RSA-OAEP", "RSA1_5"]:
@@ -847,10 +856,21 @@ def decrypt(self, token=None, keys=None, alg=None):
847856
else:
848857
raise NotSupportedAlgorithm
849858

859+
if cek:
860+
try:
861+
msg = decrypter.decrypt(as_bytes(token), None, cek=cek)
862+
self["cek"] = decrypter.cek if 'cek' in decrypter else None
863+
except (KeyError, DecryptionFailed):
864+
pass
865+
else:
866+
logger.debug("Decrypted message using exiting CEK")
867+
return msg
868+
850869
for key in keys:
851870
_key = key.encryption_key(alg=_alg, private=False)
852871
try:
853872
msg = decrypter.decrypt(as_bytes(token), _key)
873+
self["cek"] = decrypter.cek if 'cek' in decrypter else None
854874
except (KeyError, DecryptionFailed):
855875
pass
856876
else:

src/jwkest/jwk.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from jwkest import JWKESTException
2323
from jwkest import b64d
2424
from jwkest import b64e
25+
from jwkest import UnknownAlgorithm
2526
from jwkest.ecc import NISTEllipticCurve
2627
from jwkest.jwt import b2s_conv
2728

@@ -87,18 +88,18 @@ def sha512_digest(msg):
8788
# =============================================================================
8889

8990

90-
def import_rsa_key_from_file(filename):
91-
return RSA.importKey(open(filename, 'r').read())
91+
def import_rsa_key_from_file(filename, passphrase=None):
92+
return RSA.importKey(open(filename, 'r').read(), passphrase=passphrase)
9293

9394

94-
def import_rsa_key(key):
95+
def import_rsa_key(key, passphrase=None):
9596
"""
9697
Extract an RSA key from a PEM-encoded certificate
97-
9898
:param key: RSA key encoded in standard form
99+
:param passphrase: Password to open the certificate (Optional)
99100
:return: RSA key instance
100101
"""
101-
return importKey(key)
102+
return importKey(key, passphrase=passphrase)
102103

103104

104105
def der2rsa(der):
@@ -189,6 +190,39 @@ def x509_rsa_load(txt):
189190
return [("rsa", import_rsa_key(txt))]
190191

191192

193+
def key_from_jwk_dict(jwk_dict, private=True):
194+
"""Load JWK from dictionary"""
195+
if jwk_dict['kty'] == 'EC':
196+
if private:
197+
return ECKey(kid=jwk_dict['kid'],
198+
crv=jwk_dict['crv'],
199+
x=jwk_dict['x'],
200+
y=jwk_dict['y'],
201+
d=jwk_dict['d'])
202+
else:
203+
return ECKey(kid=jwk_dict['kid'],
204+
crv=jwk_dict['crv'],
205+
x=jwk_dict['x'],
206+
y=jwk_dict['y'])
207+
elif jwk_dict['kty'] == 'RSA':
208+
if private:
209+
return RSAKey(kid=jwk_dict['kid'],
210+
n=jwk_dict['n'],
211+
e=jwk_dict['e'],
212+
d=jwk_dict['d'],
213+
p=jwk_dict['p'],
214+
q=jwk_dict['q'])
215+
else:
216+
return RSAKey(kid=jwk_dict['kid'],
217+
n=jwk_dict['n'],
218+
e=jwk_dict['e'])
219+
elif jwk_dict['kty'] == 'oct':
220+
return SYMKey(kid=jwk_dict['kid'],
221+
k=jwk_dict['k'])
222+
else:
223+
raise UnknownAlgorithm
224+
225+
192226
class Key(object):
193227
"""
194228
Basic JSON Web key class

src/jwkest/jws.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,10 @@ def sign(self, msg, key):
161161

162162
def verify(self, msg, sig, key):
163163
h = bytes_to_long(self.digest.new(msg).digest())
164-
return self._sign.verify(h, sig, key)
164+
if self._sign.verify(h, sig, key):
165+
return True
166+
else:
167+
raise BadSignature()
165168

166169

167170
class PSSSigner(Signer):

tests/test_4_jwe.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,37 @@ def full_path(local_file):
193193
rsa = RSA.importKey(open(KEY, 'r').read())
194194
plain = b'Now is the time for all good men to come to the aid of their country.'
195195

196+
def test_cek_reuse_encryption_rsaes_rsa15():
197+
198+
_rsa = JWE_RSA(plain, alg="RSA1_5", enc="A128CBC-HS256")
199+
jwt = _rsa.encrypt(rsa)
200+
dec = JWE_RSA()
201+
msg = dec.decrypt(jwt, rsa)
202+
203+
assert msg == plain
204+
205+
_rsa2 = JWE_RSA(plain, alg="RSA1_5", enc="A128CBC-HS256")
206+
jwt = _rsa2.encrypt(None, cek=dec["cek"])
207+
dec2 = JWE_RSA()
208+
msg = dec2.decrypt(jwt, None, cek=_rsa["cek"])
209+
210+
assert msg == plain
211+
212+
def test_cek_reuse_encryption_rsaes_rsa_oaep():
213+
214+
_rsa = JWE_RSA(plain, alg="RSA-OAEP", enc="A256GCM")
215+
jwt = _rsa.encrypt(rsa)
216+
dec = JWE_RSA()
217+
msg = dec.decrypt(jwt, rsa)
218+
219+
assert msg == plain
220+
221+
_rsa2 = JWE_RSA(plain, alg="RSA-OAEP", enc="A256GCM")
222+
jwt = _rsa2.encrypt(None, cek=dec["cek"])
223+
dec2 = JWE_RSA()
224+
msg = dec2.decrypt(jwt, None, cek=_rsa["cek"])
225+
226+
assert msg == plain
196227

197228
def test_rsa_encrypt_decrypt_rsa_cbc():
198229
_rsa = JWE_RSA(plain, alg="RSA1_5", enc="A128CBC-HS256")

0 commit comments

Comments
 (0)