Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 95 additions & 42 deletions lib/auth/athenz/ZTSClient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,50 +69,64 @@ namespace pulsar {
const static std::string DEFAULT_PRINCIPAL_HEADER = "Athenz-Principal-Auth";
const static std::string DEFAULT_ROLE_HEADER = "Athenz-Role-Auth";
const static int REQUEST_TIMEOUT = 30000;
const static int DEFAULT_TOKEN_EXPIRATION_TIME_SEC = 3600;
const static int MIN_TOKEN_EXPIRATION_TIME_SEC = 900;
const static int PRINCIPAL_TOKEN_EXPIRATION_TIME_SEC = 3600;
const static int ROLE_TOKEN_EXPIRATION_MIN_TIME_SEC = 7200;
const static int ROLE_TOKEN_EXPIRATION_MAX_TIME_SEC = 86400;
const static int MAX_HTTP_REDIRECTS = 20;
const static long long FETCH_EPSILON = 60; // if cache expires in 60 seconds, get it from ZTS
const static std::string requiredParams[] = {"tenantDomain", "tenantService", "providerDomain", "privateKey",
"ztsUrl"};

std::map<std::string, RoleToken> ZTSClient::roleTokenCache_;
const static std::string TENANT_DOMAIN = "tenantDomain";
const static std::string TENANT_SERVICE = "tenantService";
const static std::string PROVIDER_DOMAIN = "providerDomain";
const static std::string PRIVATE_KEY = "privateKey";
const static std::string ZTS_URL = "ztsUrl";
const static std::string KEY_ID = "keyId";
const static std::string PRINCIPAL_HEADER = "principalHeader";
const static std::string ROLE_HEADER = "roleHeader";
const static std::string X509_CERT_CHAIN = "x509CertChain";
const static std::string CA_CERT = "caCert";

ZTSClient::ZTSClient(std::map<std::string, std::string> &params) {
// required parameter check
bool valid = true;
for (int i = 0; i < sizeof(requiredParams) / sizeof(std::string); i++) {
if (params.find(requiredParams[i]) == params.end()) {
valid = false;
LOG_ERROR(requiredParams[i] << " parameter is required");
}
std::vector<std::string> requiredParams;
requiredParams.push_back(PROVIDER_DOMAIN);
requiredParams.push_back(PRIVATE_KEY);
requiredParams.push_back(ZTS_URL);
if (params.find(X509_CERT_CHAIN) != params.end()) {
// use Copper Argos
enableX509CertChain_ = true;
} else {
requiredParams.push_back(TENANT_DOMAIN);
requiredParams.push_back(TENANT_SERVICE);
}

if (!valid) {
if (!checkRequiredParams(params, requiredParams)) {
LOG_ERROR("Some parameters are missing")
return;
}

// set required value
tenantDomain_ = params[requiredParams[0]];
tenantService_ = params[requiredParams[1]];
providerDomain_ = params[requiredParams[2]];
privateKeyUri_ = parseUri(params[requiredParams[3]].c_str());
ztsUrl_ = params[requiredParams[4]];
providerDomain_ = params[PROVIDER_DOMAIN];
privateKeyUri_ = parseUri(params[PRIVATE_KEY].c_str());
ztsUrl_ = params[ZTS_URL];

// set optional value
keyId_ = params.find("keyId") == params.end() ? "0" : params["keyId"];
principalHeader_ =
params.find("principalHeader") == params.end() ? DEFAULT_PRINCIPAL_HEADER : params["principalHeader"];
roleHeader_ = params.find("roleHeader") == params.end() ? DEFAULT_ROLE_HEADER : params["roleHeader"];
tokenExpirationTime_ = DEFAULT_TOKEN_EXPIRATION_TIME_SEC;
if (params.find("tokenExpirationTime") != params.end()) {
tokenExpirationTime_ = std::stoi(params["tokenExpirationTime"]);
if (tokenExpirationTime_ < MIN_TOKEN_EXPIRATION_TIME_SEC) {
LOG_WARN(tokenExpirationTime_ << " is too small as a token expiration time. "
<< MIN_TOKEN_EXPIRATION_TIME_SEC << " is set instead of it.");
tokenExpirationTime_ = MIN_TOKEN_EXPIRATION_TIME_SEC;
}
roleHeader_ = params.find(ROLE_HEADER) == params.end() ? DEFAULT_ROLE_HEADER : params[ROLE_HEADER];
if (params.find(CA_CERT) != params.end()) {
caCert_ = parseUri(params[CA_CERT].c_str());
}

if (enableX509CertChain_) {
// set required value
x509CertChain_ = parseUri(params[X509_CERT_CHAIN].c_str());
} else {
// set required value
tenantDomain_ = params[TENANT_DOMAIN];
tenantService_ = params[TENANT_SERVICE];

// set optional value
keyId_ = params.find(KEY_ID) == params.end() ? "0" : params[KEY_ID];
principalHeader_ = params.find(PRINCIPAL_HEADER) == params.end() ? DEFAULT_PRINCIPAL_HEADER
: params[PRINCIPAL_HEADER];
}

if (*(--ztsUrl_.end()) == '/') {
Expand Down Expand Up @@ -205,7 +219,7 @@ const std::string ZTSClient::getPrincipalToken() const {
unsignedTokenString += ";h=" + std::string(host);
unsignedTokenString += ";a=" + getSalt();
unsignedTokenString += ";t=" + std::to_string(t);
unsignedTokenString += ";e=" + std::to_string(t + tokenExpirationTime_);
unsignedTokenString += ";e=" + std::to_string(t + PRINCIPAL_TOKEN_EXPIRATION_TIME_SEC);
unsignedTokenString += ";k=" + keyId_;

LOG_DEBUG("Created unsigned principal token: " << unsignedTokenString);
Expand Down Expand Up @@ -258,7 +272,7 @@ const std::string ZTSClient::getPrincipalToken() const {
return "";
}
} else {
LOG_ERROR("Unsupported URI Scheme: " << privateKeyUri_.scheme);
LOG_ERROR("URI scheme not supported in privateKey: " << privateKeyUri_.scheme);
return "";
}

Expand All @@ -278,15 +292,14 @@ static size_t curlWriteCallback(void *contents, size_t size, size_t nmemb, void
return size * nmemb;
}

static std::mutex cacheMtx_;
const std::string ZTSClient::getRoleToken() const {
std::mutex cacheMtx_;
const std::string ZTSClient::getRoleToken() {
RoleToken roleToken;
std::string cacheKey = "p=" + tenantDomain_ + "." + tenantService_ + ";d=" + providerDomain_;

// locked block
{
std::lock_guard<std::mutex> lock(cacheMtx_);
roleToken = roleTokenCache_[cacheKey];
roleToken = roleTokenCache_;
}

if (!roleToken.token.empty() && roleToken.expiryTime > (long long)time(NULL) + FETCH_EPSILON) {
Expand All @@ -295,6 +308,8 @@ const std::string ZTSClient::getRoleToken() const {
}

std::string completeUrl = ztsUrl_ + "/zts/v1/domain/" + providerDomain_ + "/token";
completeUrl += "?minExpiryTime=" + std::to_string(ROLE_TOKEN_EXPIRATION_MIN_TIME_SEC);
completeUrl += "&maxExpiryTime=" + std::to_string(ROLE_TOKEN_EXPIRATION_MAX_TIME_SEC);

CURL *handle;
CURLcode res;
Expand Down Expand Up @@ -326,10 +341,31 @@ const std::string ZTSClient::getRoleToken() const {
// Fail if HTTP return code >= 400
curl_easy_setopt(handle, CURLOPT_FAILONERROR, 1L);

if (!caCert_.scheme.empty()) {
if (caCert_.scheme == "file") {
curl_easy_setopt(handle, CURLOPT_CAINFO, caCert_.path.c_str());
} else {
LOG_ERROR("URI scheme not supported in caCert: " << caCert_.scheme);
}
}

struct curl_slist *list = NULL;
std::string httpHeader = principalHeader_ + ": " + getPrincipalToken();
list = curl_slist_append(list, httpHeader.c_str());
curl_easy_setopt(handle, CURLOPT_HTTPHEADER, list);
if (enableX509CertChain_) {
if (x509CertChain_.scheme == "file") {
curl_easy_setopt(handle, CURLOPT_SSLCERT, x509CertChain_.path.c_str());
} else {
LOG_ERROR("URI scheme not supported in x509CertChain: " << x509CertChain_.scheme);
}
if (privateKeyUri_.scheme == "file") {
curl_easy_setopt(handle, CURLOPT_SSLKEY, privateKeyUri_.path.c_str());
} else {
LOG_ERROR("URI scheme not supported in privateKey: " << privateKeyUri_.scheme);
}
} else {
std::string httpHeader = principalHeader_ + ": " + getPrincipalToken();
list = curl_slist_append(list, httpHeader.c_str());
curl_easy_setopt(handle, CURLOPT_HTTPHEADER, list);
}

// Make get call to server
res = curl_easy_perform(handle);
Expand Down Expand Up @@ -357,7 +393,7 @@ const std::string ZTSClient::getRoleToken() const {
roleToken.token = root.get<std::string>("token");
roleToken.expiryTime = root.get<uint32_t>("expiryTime");
std::lock_guard<std::mutex> lock(cacheMtx_);
roleTokenCache_[cacheKey] = roleToken;
roleTokenCache_ = roleToken;
LOG_DEBUG("Got role token " << roleToken.token)
} else {
LOG_ERROR("Response failed for url " << completeUrl << ". response Code " << response_code)
Expand All @@ -374,8 +410,8 @@ const std::string ZTSClient::getRoleToken() const {

const std::string ZTSClient::getHeader() const { return roleHeader_; }

PrivateKeyUri ZTSClient::parseUri(const char *uri) {
PrivateKeyUri uriSt;
UriSt ZTSClient::parseUri(const char *uri) {
UriSt uriSt;
// scheme mediatype[;base64] path file
static const PULSAR_REGEX_NAMESPACE::regex expression(
R"(^(?:([A-Za-z]+):)(?:([/\w\-]+;\w+),([=\w]+))?(?:\/\/)?([^?#]+)?)");
Expand All @@ -385,7 +421,24 @@ PrivateKeyUri ZTSClient::parseUri(const char *uri) {
uriSt.mediaTypeAndEncodingType = groups.str(2);
uriSt.data = groups.str(3);
uriSt.path = groups.str(4);
} else {
// consider a file path specified instead of a URI
uriSt.scheme = "file";
uriSt.path = std::string(uri);
}
return uriSt;
}

bool ZTSClient::checkRequiredParams(std::map<std::string, std::string> &params,
const std::vector<std::string> &requiredParams) {
bool valid = true;
for (int i = 0; i < requiredParams.size(); i++) {
if (params.find(requiredParams[i]) == params.end()) {
valid = false;
LOG_ERROR(requiredParams[i] << " parameter is required");
}
}

return valid;
}
} // namespace pulsar
17 changes: 11 additions & 6 deletions lib/auth/athenz/ZTSClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <map>
#include <string>
#include <vector>

namespace pulsar {

Expand All @@ -28,7 +29,7 @@ struct RoleToken {
long long expiryTime;
};

struct PrivateKeyUri {
struct UriSt {
std::string scheme;
std::string mediaTypeAndEncodingType;
std::string data;
Expand All @@ -38,26 +39,30 @@ struct PrivateKeyUri {
class PULSAR_PUBLIC ZTSClient {
public:
ZTSClient(std::map<std::string, std::string>& params);
const std::string getRoleToken() const;
const std::string getRoleToken();
const std::string getHeader() const;
~ZTSClient();

private:
std::string tenantDomain_;
std::string tenantService_;
std::string providerDomain_;
PrivateKeyUri privateKeyUri_;
UriSt privateKeyUri_;
std::string ztsUrl_;
std::string keyId_;
UriSt x509CertChain_;
UriSt caCert_;
std::string principalHeader_;
std::string roleHeader_;
int tokenExpirationTime_;
static std::map<std::string, RoleToken> roleTokenCache_;
RoleToken roleTokenCache_;
bool enableX509CertChain_ = false;
static std::string getSalt();
static std::string ybase64Encode(const unsigned char* input, int length);
static char* base64Decode(const char* input);
const std::string getPrincipalToken() const;
static PrivateKeyUri parseUri(const char* uri);
static UriSt parseUri(const char* uri);
static bool checkRequiredParams(std::map<std::string, std::string>& params,
const std::vector<std::string>& requiredParams);

friend class ZTSClientWrapper;
};
Expand Down
31 changes: 20 additions & 11 deletions tests/ZTSClientTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,54 +26,63 @@ namespace pulsar {

class ZTSClientWrapper {
public:
static PrivateKeyUri parseUri(const char* uri) { return ZTSClient::parseUri(uri); }
static UriSt parseUri(const char* uri) { return ZTSClient::parseUri(uri); }
};
} // namespace pulsar

TEST(ZTSClientTest, testZTSClient) {
{
PrivateKeyUri uri = ZTSClientWrapper::parseUri("file:/path/to/private.key");
UriSt uri = ZTSClientWrapper::parseUri("file:/path/to/private.key");
ASSERT_EQ("file", uri.scheme);
ASSERT_EQ("/path/to/private.key", uri.path);
ASSERT_EQ("", uri.mediaTypeAndEncodingType);
ASSERT_EQ("", uri.data);
}

{
PrivateKeyUri uri = ZTSClientWrapper::parseUri("file:///path/to/private.key");
UriSt uri = ZTSClientWrapper::parseUri("file:///path/to/private.key");
ASSERT_EQ("file", uri.scheme);
ASSERT_EQ("/path/to/private.key", uri.path);
ASSERT_EQ("", uri.mediaTypeAndEncodingType);
ASSERT_EQ("", uri.data);
}

{
PrivateKeyUri uri = ZTSClientWrapper::parseUri("file:./path/to/private.key");
UriSt uri = ZTSClientWrapper::parseUri("file:./path/to/private.key");
ASSERT_EQ("file", uri.scheme);
ASSERT_EQ("./path/to/private.key", uri.path);
ASSERT_EQ("", uri.mediaTypeAndEncodingType);
ASSERT_EQ("", uri.data);
}

{
PrivateKeyUri uri = ZTSClientWrapper::parseUri("file://./path/to/private.key");
UriSt uri = ZTSClientWrapper::parseUri("file://./path/to/private.key");
ASSERT_EQ("file", uri.scheme);
ASSERT_EQ("./path/to/private.key", uri.path);
ASSERT_EQ("", uri.mediaTypeAndEncodingType);
ASSERT_EQ("", uri.data);
}

{
PrivateKeyUri uri = ZTSClientWrapper::parseUri("data:application/x-pem-file;base64,SGVsbG8gV29ybGQK");
UriSt uri = ZTSClientWrapper::parseUri("data:application/x-pem-file;base64,SGVsbG8gV29ybGQK");
ASSERT_EQ("data", uri.scheme);
ASSERT_EQ("", uri.path);
ASSERT_EQ("application/x-pem-file;base64", uri.mediaTypeAndEncodingType);
ASSERT_EQ("SGVsbG8gV29ybGQK", uri.data);
}

{
PrivateKeyUri uri = ZTSClientWrapper::parseUri("");
ASSERT_EQ("", uri.scheme);
UriSt uri = ZTSClientWrapper::parseUri("");
ASSERT_EQ("file", uri.scheme);
ASSERT_EQ("", uri.path);
ASSERT_EQ("", uri.mediaTypeAndEncodingType);
ASSERT_EQ("", uri.data);
}

{
PrivateKeyUri uri = ZTSClientWrapper::parseUri("/path/to/private.key");
ASSERT_EQ("", uri.scheme);
ASSERT_EQ("", uri.path);
UriSt uri = ZTSClientWrapper::parseUri("/path/to/private.key");
ASSERT_EQ("file", uri.scheme);
ASSERT_EQ("/path/to/private.key", uri.path);
ASSERT_EQ("", uri.mediaTypeAndEncodingType);
ASSERT_EQ("", uri.data);
}
Expand Down