Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 65 additions & 22 deletions pgdog/src/backend/auth/azure_workload_identity.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,43 @@
use crate::backend::{pool::Address, Error};
use std::time::SystemTime;

use azure_core::credentials::TokenCredential;
use azure_identity::WorkloadIdentityCredential;

use super::token_cache;
use crate::backend::{pool::Address, Error};

pub async fn token(addr: &Address) -> Result<String, Error> {
#[cfg(test)]
if let Some(token) = test_token_override() {
return Ok(token);
}

token_cache::get_or_fetch(addr, fetch_token).await
}

async fn fetch_token(addr: Address) -> Result<(String, SystemTime), Error> {
let credential = WorkloadIdentityCredential::new(None).map_err(|error| {
Error::AzureWorkloadIdentityToken(format!(
"failed to build workload identity credential for {}@{}:{}: {}",
addr.user, addr.host, addr.port, error
))
})?;

credential
let access_token = credential
.get_token(
&["https://ossrdbms-aad.database.windows.net/.default"],
None,
)
.await
.map(|token| token.token.secret().to_string())
.map_err(|error| {
Error::AzureWorkloadIdentityToken(format!(
"failed to get Azure AD token for {}@{}:{}: {}",
addr.user, addr.host, addr.port, error
))
})
})?;

let expires_at = SystemTime::from(access_token.expires_on);
Ok((access_token.token.secret().to_string(), expires_at))
}

#[cfg(test)]
Expand All @@ -46,13 +56,15 @@ static TEST_TOKEN_OVERRIDE: once_cell::sync::Lazy<parking_lot::Mutex<Option<Stri

#[cfg(test)]
mod tests {
use crate::backend::pool::Address;
use crate::config::ServerAuth;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use pgdog_config::Role;
use std::env;
use std::time::{Duration, SystemTime};

use super::*;
use crate::backend::pool::Address;
use crate::config::ServerAuth;
use token_cache::{CacheKey, CachedToken};

struct EnvVarGuard {
key: &'static str,
Expand All @@ -69,22 +81,15 @@ mod tests {

impl Drop for EnvVarGuard {
fn drop(&mut self) {
if let Some(previous) = self.previous.take() {
env::set_var(self.key, previous);
} else {
env::remove_var(self.key);
match self.previous.take() {
Some(v) => env::set_var(self.key, v),
None => env::remove_var(self.key),
}
}
}

#[tokio::test]
#[ignore = "requires AKS environment with Workload Identity injection"]
async fn test_token_contains_expected_query_fields() {
let _azure_client_id = EnvVarGuard::set("AZURE_CLIENT_ID", "EXAMPLE");
let _azure_tenant_id = EnvVarGuard::set("AZURE_TENANT_ID", "EXAMPLE");
let _azure_token_file_path = EnvVarGuard::set("AZURE_FEDERATED_TOKEN_FILE", "/tmp/example");

let addr = Address {
fn make_addr() -> Address {
Address {
host: "my-awesome-db.postgres.database.azure.com".into(),
port: 5432,
database_name: "postgres".into(),
Expand All @@ -94,17 +99,55 @@ mod tests {
server_auth: ServerAuth::AzureWorkloadIdentity,
server_iam_region: None,
configured_role: Role::Auto,
};
}
}

#[test]
fn token_override_bypasses_cache() {
set_test_token_override(Some("override-token".into()));
let result = tokio::runtime::Runtime::new()
.unwrap()
.block_on(token(&make_addr()))
.unwrap();
assert_eq!(result, "override-token");
set_test_token_override(None);
}

let b64_token = token(&addr).await.unwrap();
#[test]
fn cache_returns_same_token_on_second_call() {
let addr = make_addr();
let key = CacheKey::from(&addr);
let sentinel = "cached-sentinel-token".to_string();
token_cache::insert_test_token(
key,
CachedToken::new(
sentinel.clone(),
SystemTime::now() + Duration::from_secs(3600),
),
);

let result = tokio::runtime::Runtime::new()
.unwrap()
.block_on(token(&addr))
.unwrap();

assert_eq!(result, sentinel);
}

#[tokio::test]
#[ignore = "requires AKS environment with Workload Identity injection"]
async fn test_token_contains_expected_query_fields() {
let _azure_client_id = EnvVarGuard::set("AZURE_CLIENT_ID", "EXAMPLE");
let _azure_tenant_id = EnvVarGuard::set("AZURE_TENANT_ID", "EXAMPLE");
let _azure_token_file_path = EnvVarGuard::set("AZURE_FEDERATED_TOKEN_FILE", "/tmp/example");

// Use functional chaining to extract and decode
let b64_token = token(&make_addr()).await.unwrap();
let token = b64_token
.split('.')
.nth(1)
.map(|payload| URL_SAFE_NO_PAD.decode(payload))
.transpose()
.expect("Invalid JWT format") // Converts Option<Result<T, E>> to Result<Option<T>, E>
.expect("Invalid JWT format")
.and_then(|bytes| String::from_utf8(bytes).ok())
.expect("Failed to parse JWT payload as valid UTF-8 JSON");

Expand Down
1 change: 1 addition & 0 deletions pgdog/src/backend/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod azure_workload_identity;
pub mod rds_iam;
pub mod token_cache;
90 changes: 73 additions & 17 deletions pgdog/src/backend/auth/rds_iam.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
use std::time::{Duration, SystemTime};

use aws_config::{BehaviorVersion, Region};
use aws_sdk_rds::auth_token::{AuthTokenGenerator, Config as AuthTokenConfig};

use super::token_cache;
use crate::backend::{pool::Address, Error};

pub async fn token(addr: &Address) -> Result<String, Error> {
#[cfg(test)]
if let Some(token) = test_token_override() {
return Ok(token);
}

token_cache::get_or_fetch(addr, fetch_token).await
}

fn infer_region_from_rds_host(host: &str) -> Option<String> {
let host = host.to_ascii_lowercase();
let labels = host.split('.').collect::<Vec<_>>();
Expand Down Expand Up @@ -43,13 +55,8 @@ fn resolve_region(addr: &Address) -> Result<String, Error> {
})
}

pub async fn token(addr: &Address) -> Result<String, Error> {
#[cfg(test)]
if let Some(token) = test_token_override() {
return Ok(token);
}

let region = resolve_region(addr)?;
async fn fetch_token(addr: Address) -> Result<(String, SystemTime), Error> {
let region = resolve_region(&addr)?;
let sdk_config = aws_config::load_defaults(BehaviorVersion::latest()).await;

let config = AuthTokenConfig::builder()
Expand All @@ -65,7 +72,7 @@ pub async fn token(addr: &Address) -> Result<String, Error> {
))
})?;

AuthTokenGenerator::new(config)
let token = AuthTokenGenerator::new(config)
.auth_token(&sdk_config)
.await
.map(|token| token.to_string())
Expand All @@ -74,7 +81,11 @@ pub async fn token(addr: &Address) -> Result<String, Error> {
"failed to generate RDS IAM token for {}@{}:{} in region {}: {}",
addr.user, addr.host, addr.port, region, error
))
})
})?;

// RDS IAM tokens are valid for 15 minutes
let expires_at = SystemTime::now() + Duration::from_secs(900);
Ok((token, expires_at))
}

#[cfg(test)]
Expand All @@ -93,14 +104,14 @@ static TEST_TOKEN_OVERRIDE: once_cell::sync::Lazy<parking_lot::Mutex<Option<Stri

#[cfg(test)]
mod tests {
use std::env;

use pgdog_config::Role;
use std::env;
use std::time::{Duration, SystemTime};

use super::*;
use crate::backend::pool::Address;
use crate::config::ServerAuth;

use super::*;
use token_cache::{CacheKey, CachedToken};

struct EnvVarGuard {
key: &'static str,
Expand All @@ -117,14 +128,27 @@ mod tests {

impl Drop for EnvVarGuard {
fn drop(&mut self) {
if let Some(previous) = self.previous.take() {
env::set_var(self.key, previous);
} else {
env::remove_var(self.key);
match self.previous.take() {
Some(v) => env::set_var(self.key, v),
None => env::remove_var(self.key),
}
}
}

fn make_addr() -> Address {
Address {
host: "db.cluster-abc123.us-east-1.rds.amazonaws.com".into(),
port: 5432,
database_name: "postgres".into(),
user: "db_user".into(),
passwords: vec![String::new()],
database_number: 0,
server_auth: ServerAuth::RdsIam,
server_iam_region: Some("us-east-1".into()),
configured_role: Role::Auto,
}
}

#[test]
fn test_infer_region_commercial_endpoint() {
let region = infer_region_from_rds_host("db.cluster-abc123.us-east-1.rds.amazonaws.com");
Expand All @@ -144,6 +168,38 @@ mod tests {
assert!(region.is_none());
}

#[test]
fn token_override_bypasses_cache() {
set_test_token_override(Some("override-token".into()));
let result = tokio::runtime::Runtime::new()
.unwrap()
.block_on(token(&make_addr()))
.unwrap();
assert_eq!(result, "override-token");
set_test_token_override(None);
}

#[test]
fn cache_returns_same_token_on_second_call() {
let addr = make_addr();
let key = CacheKey::from(&addr);
let sentinel = "cached-sentinel-token".to_string();
token_cache::insert_test_token(
key,
CachedToken::new(
sentinel.clone(),
SystemTime::now() + Duration::from_secs(3600),
),
);

let result = tokio::runtime::Runtime::new()
.unwrap()
.block_on(token(&addr))
.unwrap();

assert_eq!(result, sentinel);
}

#[tokio::test]
async fn test_token_contains_expected_query_fields() {
let _access_key = EnvVarGuard::set("AWS_ACCESS_KEY_ID", "AKIDEXAMPLE");
Expand Down
Loading
Loading