Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions services/azure-storage/src/provide_credential/client_secret.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ use std::time::Duration;
pub struct ClientSecretCredentialProvider {
tenant_id: Option<String>,
client_id: Option<String>,
client_secret: Option<String>,
authority_host: Option<String>,
}

impl ClientSecretCredentialProvider {
Expand All @@ -50,6 +52,18 @@ impl ClientSecretCredentialProvider {
self.client_id = Some(client_id.into());
self
}

/// Set the client secret.
pub fn with_client_secret(mut self, client_secret: impl Into<String>) -> Self {
self.client_secret = Some(client_secret.into());
self
}

/// Set the authority host.
pub fn with_authority_host(mut self, authority_host: impl Into<String>) -> Self {
self.authority_host = Some(authority_host.into());
self
}
}
impl ProvideCredential for ClientSecretCredentialProvider {
type Credential = Credential;
Expand All @@ -67,18 +81,28 @@ impl ProvideCredential for ClientSecretCredentialProvider {
_ => return Ok(None),
};

let client_id = match envs.get("AZURE_CLIENT_ID") {
let client_id = match self
.client_id
.as_ref()
.or_else(|| envs.get("AZURE_CLIENT_ID"))
{
Some(id) if !id.is_empty() => id,
_ => return Ok(None),
};

let client_secret = match envs.get("AZURE_CLIENT_SECRET") {
let client_secret = match self
.client_secret
.as_ref()
.or_else(|| envs.get("AZURE_CLIENT_SECRET"))
{
Some(secret) if !secret.is_empty() => secret,
_ => return Ok(None),
};

let authority_host = envs
.get("AZURE_AUTHORITY_HOST")
let authority_host = self
.authority_host
.as_ref()
.or_else(|| envs.get("AZURE_AUTHORITY_HOST"))
.filter(|h| !h.is_empty())
.map(|s| s.as_str())
.unwrap_or("https://login.microsoftonline.com");
Expand Down
45 changes: 40 additions & 5 deletions services/azure-storage/src/provide_credential/workload_identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ use std::time::Duration;
#[derive(Debug, Default, Clone)]
pub struct WorkloadIdentityCredentialProvider {
tenant_id: Option<String>,
client_id: Option<String>,
federated_token_file: Option<String>,
authority_host: Option<String>,
}

impl WorkloadIdentityCredentialProvider {
Expand All @@ -43,6 +46,24 @@ impl WorkloadIdentityCredentialProvider {
self.tenant_id = Some(tenant_id.into());
self
}

/// Set the client ID.
pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
self.client_id = Some(client_id.into());
self
}

/// Set the federated token file path.
pub fn with_federated_token_file(mut self, path: impl Into<String>) -> Self {
self.federated_token_file = Some(path.into());
self
}

/// Set the authority host.
pub fn with_authority_host(mut self, authority_host: impl Into<String>) -> Self {
self.authority_host = Some(authority_host.into());
self
}
}
impl ProvideCredential for WorkloadIdentityCredentialProvider {
type Credential = Credential;
Expand All @@ -51,23 +72,37 @@ impl ProvideCredential for WorkloadIdentityCredentialProvider {
let envs = ctx.env_vars();

// Check if all required parameters are available from environment
let tenant_id = match envs.get("AZURE_TENANT_ID") {
let tenant_id = match self
.tenant_id
.as_ref()
.or_else(|| envs.get("AZURE_TENANT_ID"))
{
Some(id) if !id.is_empty() => id,
_ => return Ok(None),
};

let client_id = match envs.get("AZURE_CLIENT_ID") {
let client_id = match self
.client_id
.as_ref()
.or_else(|| envs.get("AZURE_CLIENT_ID"))
{
Some(id) if !id.is_empty() => id,
_ => return Ok(None),
};

let federated_token_file = match envs.get("AZURE_FEDERATED_TOKEN_FILE") {
let federated_token_file = match self
.federated_token_file
.as_ref()
.or_else(|| envs.get("AZURE_FEDERATED_TOKEN_FILE"))
{
Some(file) if !file.is_empty() => file,
_ => return Ok(None),
};

let authority_host = envs
.get("AZURE_AUTHORITY_HOST")
let authority_host = self
.authority_host
.as_ref()
.or_else(|| envs.get("AZURE_AUTHORITY_HOST"))
.filter(|h| !h.is_empty())
.map(|s| s.as_str())
.unwrap_or("https://login.microsoftonline.com");
Expand Down
Loading