Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 94 additions & 23 deletions src/attestation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,52 +182,64 @@ impl AttestationGenerator {
&self,
input_data: [u8; 64],
) -> Result<AttestationExchangeMessage, AttestationError> {
Ok(AttestationExchangeMessage {
attestation_type: self.attestation_type,
attestation: self.generate_attestation_bytes(input_data).await?,
})
if let Some(url) = &self.attestation_provider_url {
Self::use_attestation_provider(url, self.attestation_type, input_data).await
} else {
Ok(AttestationExchangeMessage {
attestation_type: self.attestation_type,
attestation: self.generate_attestation_bytes(input_data).await?,
})
}
}

/// Generate attestation evidence bytes based on attestation type, with given input data
async fn generate_attestation_bytes(
&self,
input_data: [u8; 64],
) -> Result<Vec<u8>, AttestationError> {
if let Some(url) = &self.attestation_provider_url {
Self::use_attestation_provider(url, input_data).await
} else {
match self.attestation_type {
AttestationType::None => Ok(Vec::new()),
AttestationType::AzureTdx => {
#[cfg(feature = "azure")]
{
Ok(azure::create_azure_attestation(input_data).await?)
}
#[cfg(not(feature = "azure"))]
{
tracing::error!("Attempted to generate an azure attestation but the `azure` feature not enabled");
Err(AttestationError::AttestationTypeNotSupported)
}
match self.attestation_type {
AttestationType::None => Ok(Vec::new()),
AttestationType::AzureTdx => {
#[cfg(feature = "azure")]
{
Ok(azure::create_azure_attestation(input_data).await?)
}
#[cfg(not(feature = "azure"))]
{
tracing::error!("Attempted to generate an azure attestation but the `azure` feature not enabled");
Err(AttestationError::AttestationTypeNotSupported)
}
_ => dcap::create_dcap_attestation(input_data).await,
}
_ => dcap::create_dcap_attestation(input_data).await,
}
}

/// Generate an attestation by using an external service for the attestation generation
async fn use_attestation_provider(
url: &str,
attestation_type: AttestationType,
input_data: [u8; 64],
) -> Result<Vec<u8>, AttestationError> {
) -> Result<AttestationExchangeMessage, AttestationError> {
let url = format!("{}/attest/{}", url, hex::encode(input_data));

Ok(reqwest::get(url)
let response = reqwest::get(url)
.await
.map_err(|err| AttestationError::AttestationProvider(err.to_string()))?
.bytes()
.await
.map_err(|err| AttestationError::AttestationProvider(err.to_string()))?
.to_vec())
.to_vec();

// If the response is not already wrapped in an attestation exchange message, wrap it in
// one
if let Ok(message) = AttestationExchangeMessage::decode(&mut &response[..]) {
Ok(message)
} else {
Ok(AttestationExchangeMessage {
attestation_type,
attestation: response,
})
}
}
}

Expand Down Expand Up @@ -454,6 +466,30 @@ pub enum AttestationError {
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;

async fn spawn_test_attestation_provider_server(body: Vec<u8>) -> std::net::SocketAddr {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

tokio::spawn(async move {
if let Ok((mut socket, _)) = listener.accept().await {
let mut buf = [0u8; 1024];
let _ = socket.read(&mut buf).await;

let response = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
body.len()
);
let _ = socket.write_all(response.as_bytes()).await;
let _ = socket.write_all(&body).await;
let _ = socket.shutdown().await;
}
});

addr
}

#[tokio::test]
async fn attestation_detection_does_not_panic() {
Expand All @@ -465,4 +501,39 @@ mod tests {
async fn running_on_gcp_check_does_not_panic() {
let _ = running_on_gcp().await;
}

#[tokio::test]
async fn attestation_provider_response_is_wrapped_if_needed() {
let input_data = [0u8; 64];

let encoded_message = AttestationExchangeMessage {
attestation_type: AttestationType::None,
attestation: vec![1, 2, 3],
}
.encode();

let encoded_addr = spawn_test_attestation_provider_server(encoded_message).await;
let encoded_url = format!("http://{encoded_addr}");
let decoded = AttestationGenerator::use_attestation_provider(
&encoded_url,
AttestationType::GcpTdx,
input_data,
)
.await
.unwrap();
assert_eq!(decoded.attestation_type, AttestationType::None);
assert_eq!(decoded.attestation, vec![1, 2, 3]);

let raw_addr = spawn_test_attestation_provider_server(vec![9, 8]).await;
let raw_url = format!("http://{raw_addr}");
let wrapped = AttestationGenerator::use_attestation_provider(
&raw_url,
AttestationType::DcapTdx,
input_data,
)
.await
.unwrap();
assert_eq!(wrapped.attestation_type, AttestationType::DcapTdx);
assert_eq!(wrapped.attestation, vec![9, 8]);
}
}