Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions shell_wrapper/kahe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,6 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
return MakeFfiStatus(absl::InvalidArgumentError(
secure_aggregation::kNullPointerErrorMessage));
}

// Allocate the vector for output packed values if needed.
if (packed_values->ptr == nullptr) {
packed_values->ptr =
std::make_unique<std::vector<secure_aggregation::BigInteger>>();
}
auto curr_packed_values =
rlwe::PackMessagesFlat<secure_aggregation::Integer,
secure_aggregation::BigInteger>(
Expand All @@ -339,6 +333,11 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
}
// Pad with zeros if needed.
curr_packed_values.resize(num_packed_values, 0);
// Allocate the vector for output packed values if needed.
if (packed_values->ptr == nullptr) {
packed_values->ptr =
std::make_unique<std::vector<secure_aggregation::BigInteger>>();
}
// Append the packed values to the end of the output vector.
packed_values->ptr->insert(packed_values->ptr->end(),
curr_packed_values.begin(),
Expand Down
18 changes: 13 additions & 5 deletions shell_wrapper/kahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use single_thread_hkdf::{SeedWrapper, SingleThreadHkdfWrapper};
use status::rust_status_from_cpp;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::mem;
use std::mem::MaybeUninit;

#[derive(Debug, PartialEq, Clone)]
Expand All @@ -37,7 +38,7 @@ mod ffi {
}

pub struct BigIntVectorWrapper {
pub ptr: UniquePtr<CxxVector<BigInteger>>,
ptr: UniquePtr<CxxVector<BigInteger>>,
}

unsafe extern "C++" {
Expand Down Expand Up @@ -190,7 +191,10 @@ pub fn encrypt(
params: &KahePublicParametersWrapper,
prng: &mut SingleThreadHkdfWrapper,
) -> Result<RnsPolynomialVec, status::StatusError> {
let mut packed_values = MaybeUninit::<BigIntVectorWrapper>::zeroed();
// SAFETY: this initializes `packed_values` with packed_values.ptr == nullptr. The following
// loop ensures that we either return an error (and drop `packed_values`), or make
// packed_values.ptr point to a valid C++ vector.
let mut packed_values: BigIntVectorWrapper = unsafe { mem::zeroed() };
// SAFETY: No lifetime constraints (`PackMessagesRaw` may create a new vector of BigIntegers
// wrapped by `packed_values` which does not keep any reference to the inputs).
// `PackMessagesRaw` only appends to the C++ vector wrapped by `packed_values`,
Expand All @@ -206,18 +210,19 @@ pub fn encrypt(
packed_vector_config.base,
packed_vector_config.dimension,
packed_vector_config.num_packed_coeffs,
packed_values.as_mut_ptr(),
&mut packed_values,
)
})?;
}
// SAFETY: `packed_values` is safely initialized if we get to this point.

let mut out = MaybeUninit::<RnsPolynomialVec>::zeroed();
// SAFETY: No lifetime constraints (`Encrypt` creates a new vector of polynomials wrapped by
// `out` which does not keep any reference to the inputs). `Encrypt` reads the C++ vector
// wrapped by `packed_values`, updates the states wrapped by `prng`, and writes into the C++
// vector wrapped by `out`.
rust_status_from_cpp(unsafe {
ffi::Encrypt(&packed_values.assume_init(), secret_key, params, prng, out.as_mut_ptr())
ffi::Encrypt(&packed_values, secret_key, params, prng, out.as_mut_ptr())
})?;
// SAFETY: `out` is safely initialized if we get to this point.
Ok(unsafe { out.assume_init() })
Expand All @@ -239,6 +244,9 @@ pub fn decrypt(
ffi::Decrypt(ciphertext, secret_key, params, packed_values.as_mut_ptr())
})?;

// SAFETY: `packed_values` is safely initialized if we get to this point.
let mut packed_values = unsafe { packed_values.assume_init() };

let mut output_vectors = HashMap::<String, Vec<u64>>::new();
// Assume the packed values are stored in the same order as the configs.
for (id, packed_vector_config) in packed_vector_configs.iter() {
Expand All @@ -253,7 +261,7 @@ pub fn decrypt(
packed_vector_config.base,
packed_vector_config.dimension,
packed_vector_config.num_packed_coeffs,
packed_values.assume_init_mut(),
&mut packed_values,
&mut unpacked_values,
)
})?;
Expand Down
22 changes: 22 additions & 0 deletions willow/src/testing_utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,25 @@ rust_test(
"@crate_index//:googletest",
],
)

rust_library(
name = "shell_testing_decryptor",
testonly = 1,
srcs = [
"shell_testing_decryptor.rs",
],
deps = [
":shell_testing_parameters",
"//shell_wrapper:status",
"//willow/src/api:aggregation_config",
"//willow/src/shell:kahe_shell",
"//willow/src/shell:parameters_shell",
"//willow/src/shell:single_thread_hkdf",
"//willow/src/shell:vahe_shell",
"//willow/src/traits:ahe_traits",
"//willow/src/traits:kahe_traits",
"//willow/src/traits:messages",
"//willow/src/traits:prng_traits",
"//willow/src/traits:vahe_traits",
],
)
91 changes: 91 additions & 0 deletions willow/src/testing_utils/shell_testing_decryptor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

use aggregation_config::AggregationConfig;
use ahe_traits::{AheBase, AheKeygen, PartialDec};
use kahe_shell::ShellKahe;
use kahe_traits::{KaheBase, KaheDecrypt, TrySecretKeyFrom};
use messages::ClientMessage;
use parameters_shell::create_shell_configs;
use prng_traits::SecurePrng;
use single_thread_hkdf::SingleThreadHkdfPrng;
use status::{StatusError, StatusErrorCode};
use vahe_shell::ShellVahe;
use vahe_traits::Recover;

/// Basic implementation of a single decryptor that uses Shell operations directly. Useful for
/// testing Shell clients, by checking that encrypted messages can be decrypted properly.
pub struct ShellTestingDecryptor {
kahe: ShellKahe,
vahe: ShellVahe,
prng: SingleThreadHkdfPrng,
secret_key: Option<<ShellVahe as AheBase>::SecretKeyShare>,
}

impl ShellTestingDecryptor {
/// Creates a new ShellTestingDecryptor, using the given context string to seed KAHE and AHE
/// public parameters.
pub fn new(
aggregation_config: &AggregationConfig,
context_string: &[u8],
) -> Result<ShellTestingDecryptor, StatusError> {
let (kahe_config, ahe_config) = create_shell_configs(aggregation_config)?;
let kahe = ShellKahe::new(kahe_config, context_string)?;
let vahe = ShellVahe::new(ahe_config, context_string)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let prng = SingleThreadHkdfPrng::create(&seed)?;
Ok(ShellTestingDecryptor { kahe, vahe, prng, secret_key: None })
}

/// Generates a new AHE public key, and stores the corresponding secret key.
pub fn generate_public_key(
&mut self,
) -> Result<<ShellVahe as AheBase>::PublicKey, StatusError> {
let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng)?;
self.secret_key = Some(sk_share);
let public_key = self.vahe.aggregate_public_key_shares(&[pk_share])?;
Ok(public_key)
}

/// Decrypts a client message using the stored AHE secret key, by recovering the KAHE key from
/// the AHE ciphertext and then decrypting the KAHE ciphertext. Does not verify the client proof
/// contained in the message.
pub fn decrypt(
&mut self,
client_message: &ClientMessage<ShellKahe, ShellVahe>,
) -> Result<<ShellKahe as KaheBase>::Plaintext, StatusError> {
let decryption_request =
self.vahe.get_partial_dec_ciphertext(&client_message.ahe_ciphertext)?;
let rest_of_ciphertext =
self.vahe.get_recover_ciphertext(&client_message.ahe_ciphertext)?;
match &self.secret_key {
None => Err(StatusError::new_with_current_location(
StatusErrorCode::InvalidArgument,
"No secret key available",
)),
Some(sk_share) => {
let partial_decryption =
self.vahe.partial_decrypt(&decryption_request, sk_share, &mut self.prng)?;
let decrypted_kahe_key =
self.vahe.recover(&partial_decryption, &rest_of_ciphertext, None)?;
let decrypted_kahe_key = self.kahe.try_secret_key_from(decrypted_kahe_key)?;
let decrypted_plaintext =
self.kahe.decrypt(&client_message.kahe_ciphertext, &decrypted_kahe_key)?;
Ok(decrypted_plaintext)
}
}
}
}
2 changes: 2 additions & 0 deletions willow/src/willow_v1/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ rust_test(
"@crate_index//:googletest",
"//willow/src/api:aggregation_config",
"//willow/src/shell:kahe_shell",
"//willow/src/shell:parameters_shell",
"//willow/src/shell:single_thread_hkdf",
"//willow/src/shell:vahe_shell",
"//willow/src/testing_utils",
"//willow/src/testing_utils:shell_testing_decryptor",
"//willow/src/testing_utils:shell_testing_parameters",
"//willow/src/traits:prng_traits",
],
Expand Down
69 changes: 25 additions & 44 deletions willow/src/willow_v1/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
use client_traits::SecureAggregationClient;
use kahe_traits::{HasKahe, KaheBase, KaheEncrypt, KaheKeygen, TrySecretKeyInto};
use messages::{ClientMessage, DecryptorPublicKey};
use prng_traits::SecurePrng;
use vahe_traits::{HasVahe, VaheBase, VerifiableEncrypt};

/// Lightweight client directly exposing KAHE/VAHE types.
Expand Down Expand Up @@ -84,18 +83,17 @@ mod test {
use super::*;

use aggregation_config::AggregationConfig;
use ahe_traits::{AheBase, AheKeygen, PartialDec};
use ahe_traits::AheBase;
use googletest::prelude::container_eq;
use googletest::{gtest, verify_eq, verify_that};
use kahe_shell::ShellKahe;
use kahe_traits::{KaheDecrypt, TrySecretKeyFrom};
use parameters_shell::create_shell_configs;
use prng_traits::SecurePrng;
use shell_testing_parameters::{make_ahe_config, make_kahe_config};
use shell_testing_decryptor::ShellTestingDecryptor;
use single_thread_hkdf::SingleThreadHkdfPrng;
use std::collections::HashMap;
use testing_utils::generate_random_nonce;
use vahe_shell::ShellVahe;
use vahe_traits::Recover;

const CONTEXT_STRING: &[u8] = b"test_context_string";

Expand All @@ -111,36 +109,27 @@ mod test {
};

// Create a client.
let kahe = ShellKahe::new(make_kahe_config(&aggregation_config), CONTEXT_STRING).unwrap();
let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING).unwrap();
let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?;
let client_seed = SingleThreadHkdfPrng::generate_seed()?;
let prng = SingleThreadHkdfPrng::create(&client_seed)?;
let mut client = WillowV1Client { kahe, vahe, prng };

// Generate AHE keys.
let kahe = ShellKahe::new(make_kahe_config(&aggregation_config), CONTEXT_STRING).unwrap();
let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING).unwrap();
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
let (sk_share, pk_share, _) = vahe.key_gen(&mut prng)?;
let public_key = vahe.aggregate_public_key_shares(&[pk_share])?;
let mut testing_decryptor =
ShellTestingDecryptor::new(&aggregation_config, CONTEXT_STRING)?;
let public_key = testing_decryptor.generate_public_key()?;

// Create client message.
let input_values = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1];
let client_plaintext = HashMap::from([(default_id.as_str(), input_values.as_slice())]);
let nonce = generate_random_nonce();
let client_message =
client.create_client_message(&client_plaintext, &public_key, &nonce).unwrap();
client.create_client_message(&client_plaintext, &public_key, &nonce)?;

// Decrypt client message.
let decryption_request = vahe.get_partial_dec_ciphertext(&client_message.ahe_ciphertext)?;
let rest_of_ciphertext = vahe.get_recover_ciphertext(&client_message.ahe_ciphertext)?;
let partial_decryption = vahe.partial_decrypt(&decryption_request, &sk_share, &mut prng)?;
let decrypted_kahe_key = vahe.recover(&partial_decryption, &rest_of_ciphertext, None)?;
let decrypted_kahe_key = kahe.try_secret_key_from(decrypted_kahe_key)?;
let decrypted_plaintext =
kahe.decrypt(&client_message.kahe_ciphertext, &decrypted_kahe_key)?;

let decrypted_plaintext = testing_decryptor.decrypt(&client_message)?;
verify_that!(decrypted_plaintext.keys().collect::<Vec<_>>(), container_eq([&default_id]))?;
let client_plaintext_length = client_plaintext.get(default_id.as_str()).unwrap().len();
verify_eq!(
Expand All @@ -161,26 +150,25 @@ mod test {
};

// Create a client.
let kahe = ShellKahe::new(make_kahe_config(&aggregation_config), CONTEXT_STRING).unwrap();
let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING).unwrap();
let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?;
let client1_seed = SingleThreadHkdfPrng::generate_seed()?;
let prng = SingleThreadHkdfPrng::create(&client1_seed)?;
let mut client1 = WillowV1Client { kahe, vahe, prng };

// Create a second client.
let kahe = ShellKahe::new(make_kahe_config(&aggregation_config), CONTEXT_STRING).unwrap();
let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING).unwrap();
let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let vahe = ShellVahe::new(ahe_config, CONTEXT_STRING)?;
let client2_seed = SingleThreadHkdfPrng::generate_seed()?;
let prng = SingleThreadHkdfPrng::create(&client2_seed)?;
let mut client2 = WillowV1Client { kahe, vahe, prng };

// Generate AHE keys.
let kahe = ShellKahe::new(make_kahe_config(&aggregation_config), CONTEXT_STRING).unwrap();
let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING).unwrap();
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
let (sk_share, pk_share, _) = vahe.key_gen(&mut prng)?;
let public_key = vahe.aggregate_public_key_shares(&[pk_share])?;
let mut testing_decryptor =
ShellTestingDecryptor::new(&aggregation_config, CONTEXT_STRING)?;
let public_key = testing_decryptor.generate_public_key()?;

// Create client messages.
let input_values1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1];
Expand All @@ -190,30 +178,23 @@ mod test {
let expected_output = vec![2, 3, 5, 7, 10, 14, 10, 9, 11, 11, 14, 8, 6, 9, 1];
let nonce1 = generate_random_nonce();
let mut client_message =
client1.create_client_message(&client1_plaintext, &public_key, &nonce1).unwrap();
client1.create_client_message(&client1_plaintext, &public_key, &nonce1)?;
let nonce2 = generate_random_nonce();
let extra_message =
client2.create_client_message(&client2_plaintext, &public_key, &nonce2).unwrap();
client2.create_client_message(&client2_plaintext, &public_key, &nonce2)?;

// Add extra message to the first client message.
kahe.add_ciphertexts_in_place(
client2.kahe.add_ciphertexts_in_place(
&extra_message.kahe_ciphertext,
&mut client_message.kahe_ciphertext,
)?;
vahe.add_ciphertexts_in_place(
client2.vahe.add_ciphertexts_in_place(
&extra_message.ahe_ciphertext,
&mut client_message.ahe_ciphertext,
)?;

// Decrypt client message.
let decryption_request = vahe.get_partial_dec_ciphertext(&client_message.ahe_ciphertext)?;
let rest_of_ciphertext = vahe.get_recover_ciphertext(&client_message.ahe_ciphertext)?;
let partial_decryption = vahe.partial_decrypt(&decryption_request, &sk_share, &mut prng)?;
let decrypted_kahe_key = vahe.recover(&partial_decryption, &rest_of_ciphertext, None)?;
let decrypted_kahe_key = kahe.try_secret_key_from(decrypted_kahe_key)?;
let decrypted_plaintext =
kahe.decrypt(&client_message.kahe_ciphertext, &decrypted_kahe_key)?;

let decrypted_plaintext = testing_decryptor.decrypt(&client_message)?;
verify_that!(decrypted_plaintext.keys().collect::<Vec<_>>(), container_eq([&default_id]))?;
let client_plaintext_length = client1_plaintext.get(default_id.as_str()).unwrap().len();
verify_eq!(
Expand Down