Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 85 additions & 25 deletions lib/MessageCrypto.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ MessageCrypto::MessageCrypto(const std::string& logCtx, bool keyGenNeeded)
ivLen_(12),
iv_(new unsigned char[ivLen_]),
logCtx_(logCtx) {
SSL_library_init();
SSL_load_error_strings();

if (!keyGenNeeded) {
mdCtx_ = EVP_MD_CTX_create();
EVP_MD_CTX_init(mdCtx_);
Expand All @@ -50,17 +47,17 @@ MessageCrypto::MessageCrypto(const std::string& logCtx, bool keyGenNeeded)

MessageCrypto::~MessageCrypto() {}

RSA* MessageCrypto::loadPublicKey(std::string& pubKeyStr) {
EVP_PKEY* MessageCrypto::loadPublicKey(std::string& pubKeyStr) {
BIO* pubBio = NULL;
RSA* rsaPub = NULL;
EVP_PKEY* rsaPub = NULL;

pubBio = BIO_new_mem_buf((char*)pubKeyStr.c_str(), -1);
if (pubBio == NULL) {
LOG_ERROR(logCtx_ << " Failed to get memory for public key");
return rsaPub;
}

rsaPub = PEM_read_bio_RSA_PUBKEY(pubBio, NULL, NULL, NULL);
rsaPub = PEM_read_bio_PUBKEY(pubBio, NULL, NULL, NULL);
if (rsaPub == NULL) {
LOG_ERROR(logCtx_ << " Failed to load public key");
}
Expand All @@ -69,17 +66,17 @@ RSA* MessageCrypto::loadPublicKey(std::string& pubKeyStr) {
return rsaPub;
}

RSA* MessageCrypto::loadPrivateKey(std::string& privateKeyStr) {
EVP_PKEY* MessageCrypto::loadPrivateKey(std::string& privateKeyStr) {
BIO* privBio = NULL;
RSA* rsaPriv = NULL;
EVP_PKEY* rsaPriv = NULL;

privBio = BIO_new_mem_buf((char*)privateKeyStr.c_str(), -1);
if (privBio == NULL) {
LOG_ERROR(logCtx_ << " Failed to get memory for private key");
return rsaPriv;
}

rsaPriv = PEM_read_bio_RSAPrivateKey(privBio, NULL, NULL, NULL);
rsaPriv = PEM_read_bio_PrivateKey(privBio, NULL, NULL, NULL);
if (rsaPriv == NULL) {
LOG_ERROR(logCtx_ << " Failed to load private key");
}
Expand All @@ -88,6 +85,59 @@ RSA* MessageCrypto::loadPrivateKey(std::string& privateKeyStr) {
return rsaPriv;
}

bool MessageCrypto::rsaDecrypt(EVP_PKEY_CTX* ctx, const std::string& in,
boost::scoped_array<unsigned char>& out, size_t& outLen) {
if (EVP_PKEY_decrypt_init(ctx) <= 0) {
LOG_ERROR(logCtx_ << "Failed to initialize decryption");
return false;
}
if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING) <= 0) {
LOG_ERROR(logCtx_ << "Failed to set RSA padding");
return false;
}
auto inStr_ = reinterpret_cast<unsigned const char*>(in.c_str());
size_t rsaSize;
if (EVP_PKEY_decrypt(ctx, NULL, &rsaSize, inStr_, in.size()) <= 0) {
LOG_ERROR(logCtx_ << "Failed to determine decrypt buffer size");
return false;
}
if (rsaSize != outLen) {
outLen = rsaSize;
out.reset(new unsigned char[outLen]);
}
if (EVP_PKEY_decrypt(ctx, out.get(), &outLen, inStr_, in.size()) <= 0) {
LOG_ERROR(logCtx_ << "Failed to decrypt.");
return false;
}
return true;
}

bool MessageCrypto::rsaEncrypt(EVP_PKEY_CTX* ctx, boost::scoped_array<unsigned char>& in, size_t inLen,
boost::scoped_array<unsigned char>& out, size_t& outLen) {
if (EVP_PKEY_encrypt_init(ctx) <= 0) {
LOG_ERROR(logCtx_ << "Failed to initialize encryption");
return false;
}
if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING) <= 0) {
LOG_ERROR(logCtx_ << "Failed to set RSA padding");
return false;
}
size_t rsaSize;
if (EVP_PKEY_encrypt(ctx, NULL, &rsaSize, in.get(), inLen) <= 0) {
LOG_ERROR(logCtx_ << "Failed to determine encrypt buffer size");
return false;
}
if (rsaSize != outLen) {
outLen = rsaSize;
out.reset(new unsigned char[rsaSize]);
}
if (EVP_PKEY_encrypt(ctx, out.get(), &outLen, in.get(), inLen) <= 0) {
LOG_ERROR(logCtx_ << "Failed to encrypt.");
return false;
}
return true;
}

bool MessageCrypto::getDigest(const std::string& keyName, const void* input, unsigned int inputLen,
unsigned char keyDigest[], unsigned int& digestLen) {
if (EVP_DigestInit_ex(mdCtx_, EVP_md5(), NULL) != 1) {
Expand Down Expand Up @@ -181,24 +231,29 @@ Result MessageCrypto::addPublicKeyCipher(const std::string& keyName, const Crypt
return result;
}

RSA* pubKey = loadPublicKey(keyInfo.getKey());
auto* pubKey = loadPublicKey(keyInfo.getKey());
if (pubKey == NULL) {
LOG_ERROR(logCtx_ << "Failed to load public key " << keyName);
return ResultCryptoError;
}
LOG_DEBUG(logCtx_ << " Public key " << keyName << " loaded successfully.");

int inSize = RSA_size(pubKey);
boost::scoped_array<unsigned char> encryptedKey(new unsigned char[inSize]);

int outSize =
RSA_public_encrypt(dataKeyLen_, dataKey_.get(), encryptedKey.get(), pubKey, RSA_PKCS1_OAEP_PADDING);

if (inSize != outSize) {
LOG_ERROR(logCtx_ << "Ciphertext is length not matching input key length for key " << keyName);
boost::scoped_array<unsigned char> encryptedKey{nullptr};
size_t encryptedKeyLen{0};
auto* ctx = EVP_PKEY_CTX_new(pubKey, NULL);
if (!ctx) {
LOG_ERROR(logCtx_ << "Failed to create EVP_PKEY_CTX for " << keyName);
EVP_PKEY_free(pubKey);
return ResultCryptoError;
}
bool encrypted = rsaEncrypt(ctx, dataKey_, dataKeyLen_, encryptedKey, encryptedKeyLen);
EVP_PKEY_CTX_free(ctx);
EVP_PKEY_free(pubKey);
if (!encrypted) {
LOG_ERROR(logCtx_ << "Failed to encrypt with " << keyName);
return ResultCryptoError;
}
std::string encryptedKeyStr(reinterpret_cast<char*>(encryptedKey.get()), inSize);
std::string encryptedKeyStr(reinterpret_cast<char*>(encryptedKey.get()), encryptedKeyLen);
std::shared_ptr<EncryptionKeyInfo> eki(new EncryptionKeyInfo());
eki->setKey(encryptedKeyStr);
eki->setMetadata(keyInfo.getMetadata());
Expand Down Expand Up @@ -353,19 +408,24 @@ bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const C
keyReader.getPrivateKey(keyName, keyMeta, keyInfo);

// Convert key from string to RSA key
RSA* privKey = loadPrivateKey(keyInfo.getKey());
auto* privKey = loadPrivateKey(keyInfo.getKey());
if (privKey == NULL) {
LOG_ERROR(logCtx_ << " Failed to load private key " << keyName);
return false;
}
LOG_DEBUG(logCtx_ << " Private key " << keyName << " loaded successfully.");

// Decrypt data key
int outSize = RSA_private_decrypt(encryptedDataKey.size(),
reinterpret_cast<unsigned const char*>(encryptedDataKey.c_str()),
dataKey_.get(), privKey, RSA_PKCS1_OAEP_PADDING);

if (outSize == -1) {
auto* ctx = EVP_PKEY_CTX_new(privKey, NULL);
if (!ctx) {
LOG_ERROR(logCtx_ << "Failed to create EVP_PKEY_CTX for " << keyName);
EVP_PKEY_free(privKey);
return false;
}
bool decrypted = rsaDecrypt(ctx, encryptedDataKey, dataKey_, dataKeyLen_);
EVP_PKEY_CTX_free(ctx);
EVP_PKEY_free(privKey);
if (!decrypted) {
LOG_ERROR(logCtx_ << "Failed to decrypt AES key for " << keyName);
return false;
}
Expand Down
10 changes: 7 additions & 3 deletions lib/MessageCrypto.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class MessageCrypto {
typedef std::unique_lock<std::mutex> Lock;
std::mutex mutex_;

int dataKeyLen_;
size_t dataKeyLen_;
boost::scoped_array<unsigned char> dataKey_;

int tagLen_;
Expand All @@ -125,8 +125,12 @@ class MessageCrypto {

EVP_MD_CTX* mdCtx_;

RSA* loadPublicKey(std::string& pubKeyStr);
RSA* loadPrivateKey(std::string& privateKeyStr);
EVP_PKEY* loadPublicKey(std::string& pubKeyStr);
EVP_PKEY* loadPrivateKey(std::string& privateKeyStr);
bool rsaDecrypt(EVP_PKEY_CTX* ctx, const std::string& in, boost::scoped_array<unsigned char>& out,
size_t& outLen);
bool rsaEncrypt(EVP_PKEY_CTX* ctx, boost::scoped_array<unsigned char>& in, size_t inLen,
boost::scoped_array<unsigned char>& out, size_t& outLen);
bool getDigest(const std::string& keyName, const void* input, unsigned int inputLen,
unsigned char keyDigest[], unsigned int& digestLen);
void removeExpiredDataKey();
Expand Down
34 changes: 28 additions & 6 deletions lib/auth/athenz/ZTSClient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ const std::string ZTSClient::getPrincipalToken() const {
const char *unsignedToken = unsignedTokenString.c_str();
unsigned char signature[BUFSIZ] = {};
unsigned char hash[SHA256_DIGEST_LENGTH] = {};
unsigned int siglen;
size_t siglen;
FILE *fp;
RSA *privateKey;
EVP_PKEY *privateKey;

if (privateKeyUri_.scheme == "data") {
if (privateKeyUri_.mediaTypeAndEncodingType != "application/x-pem-file;base64") {
Expand All @@ -249,7 +249,7 @@ const std::string ZTSClient::getPrincipalToken() const {
free(decodeStr);
return "";
}
privateKey = PEM_read_bio_RSAPrivateKey(bio, NULL, NULL, NULL);
privateKey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
BIO_free(bio);
free(decodeStr);
if (privateKey == NULL) {
Expand All @@ -263,7 +263,7 @@ const std::string ZTSClient::getPrincipalToken() const {
return "";
}

privateKey = PEM_read_RSAPrivateKey(fp, NULL, NULL, NULL);
privateKey = PEM_read_PrivateKey(fp, NULL, NULL, NULL);
fclose(fp);
if (privateKey == NULL) {
LOG_ERROR("Failed to read private key: " << privateKeyUri_.path);
Expand All @@ -275,16 +275,38 @@ const std::string ZTSClient::getPrincipalToken() const {
}

SHA256((unsigned char *)unsignedToken, unsignedTokenString.length(), hash);
RSA_sign(NID_sha256, hash, SHA256_DIGEST_LENGTH, signature, &siglen, privateKey);
auto *ctx = EVP_MD_CTX_new();
if (ctx == NULL) {
LOG_ERROR("Failed to create EVP_MD_CTX.");
return "";
}

bool sign = rsaSign(ctx, privateKey, signature, &siglen, hash);
EVP_MD_CTX_free(ctx);
if (!sign) {
LOG_ERROR("Failed to sign with " << privateKeyUri_.path);
return "";
}

std::string principalToken = unsignedTokenString + ";s=" + ybase64Encode(signature, siglen);
LOG_DEBUG("Created signed principal token: " << principalToken);

RSA_free(privateKey);
EVP_PKEY_free(privateKey);

return principalToken;
}

bool ZTSClient::rsaSign(EVP_MD_CTX *ctx, EVP_PKEY *privateKey, unsigned char *signature, size_t *siglen,
unsigned char *hash) const {
if (EVP_DigestSignInit(ctx, NULL, EVP_sha256(), NULL, privateKey) != 1) {
return false;
}
if (EVP_DigestSign(ctx, signature, siglen, hash, SHA256_DIGEST_LENGTH) != 1) {
return false;
}
return true;
}

std::mutex cacheMtx_;
const std::string ZTSClient::getRoleToken() {
RoleToken roleToken;
Expand Down
3 changes: 3 additions & 0 deletions lib/auth/athenz/ZTSClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <openssl/evp.h>
#include <pulsar/defines.h>

#include <map>
Expand Down Expand Up @@ -63,6 +64,8 @@ class PULSAR_PUBLIC ZTSClient {
static UriSt parseUri(const char* uri);
static bool checkRequiredParams(std::map<std::string, std::string>& params,
const std::vector<std::string>& requiredParams);
bool rsaSign(EVP_MD_CTX* ctx, EVP_PKEY* privateKey, unsigned char* signature, size_t* siglen,
unsigned char* hash) const;

friend class ZTSClientWrapper;
};
Expand Down