Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions shell_wrapper/kahe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,6 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
return MakeFfiStatus(absl::InvalidArgumentError(
secure_aggregation::kNullPointerErrorMessage));
}

// Allocate the vector for output packed values if needed.
if (packed_values->ptr == nullptr) {
packed_values->ptr =
std::make_unique<std::vector<secure_aggregation::BigInteger>>();
}
auto curr_packed_values =
rlwe::PackMessagesFlat<secure_aggregation::Integer,
secure_aggregation::BigInteger>(
Expand All @@ -339,6 +333,11 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
}
// Pad with zeros if needed.
curr_packed_values.resize(num_packed_values, 0);
// Allocate the vector for output packed values if needed.
if (packed_values->ptr == nullptr) {
packed_values->ptr =
std::make_unique<std::vector<secure_aggregation::BigInteger>>();
}
// Append the packed values to the end of the output vector.
packed_values->ptr->insert(packed_values->ptr->end(),
curr_packed_values.begin(),
Expand Down
18 changes: 13 additions & 5 deletions shell_wrapper/kahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use single_thread_hkdf::{SeedWrapper, SingleThreadHkdfWrapper};
use status::rust_status_from_cpp;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::mem;
use std::mem::MaybeUninit;

#[derive(Debug, PartialEq, Clone)]
Expand All @@ -37,7 +38,7 @@ mod ffi {
}

pub struct BigIntVectorWrapper {
pub ptr: UniquePtr<CxxVector<BigInteger>>,
ptr: UniquePtr<CxxVector<BigInteger>>,
}

unsafe extern "C++" {
Expand Down Expand Up @@ -190,7 +191,10 @@ pub fn encrypt(
params: &KahePublicParametersWrapper,
prng: &mut SingleThreadHkdfWrapper,
) -> Result<RnsPolynomialVec, status::StatusError> {
let mut packed_values = MaybeUninit::<BigIntVectorWrapper>::zeroed();
// SAFETY: this initializes `packed_values` with packed_values.ptr == nullptr. The following
// loop ensures that we either return an error (and drop `packed_values`), or make
// packed_values.ptr point to a valid C++ vector.
let mut packed_values: BigIntVectorWrapper = unsafe { mem::zeroed() };
// SAFETY: No lifetime constraints (`PackMessagesRaw` may create a new vector of BigIntegers
// wrapped by `packed_values` which does not keep any reference to the inputs).
// `PackMessagesRaw` only appends to the C++ vector wrapped by `packed_values`,
Expand All @@ -206,18 +210,19 @@ pub fn encrypt(
packed_vector_config.base,
packed_vector_config.dimension,
packed_vector_config.num_packed_coeffs,
packed_values.as_mut_ptr(),
&mut packed_values,
)
})?;
}
// SAFETY: `packed_values` is safely initialized if we get to this point.

let mut out = MaybeUninit::<RnsPolynomialVec>::zeroed();
// SAFETY: No lifetime constraints (`Encrypt` creates a new vector of polynomials wrapped by
// `out` which does not keep any reference to the inputs). `Encrypt` reads the C++ vector
// wrapped by `packed_values`, updates the states wrapped by `prng`, and writes into the C++
// vector wrapped by `out`.
rust_status_from_cpp(unsafe {
ffi::Encrypt(&packed_values.assume_init(), secret_key, params, prng, out.as_mut_ptr())
ffi::Encrypt(&packed_values, secret_key, params, prng, out.as_mut_ptr())
})?;
// SAFETY: `out` is safely initialized if we get to this point.
Ok(unsafe { out.assume_init() })
Expand All @@ -239,6 +244,9 @@ pub fn decrypt(
ffi::Decrypt(ciphertext, secret_key, params, packed_values.as_mut_ptr())
})?;

// SAFETY: `packed_values` is safely initialized if we get to this point.
let mut packed_values = unsafe { packed_values.assume_init() };

let mut output_vectors = HashMap::<String, Vec<u64>>::new();
// Assume the packed values are stored in the same order as the configs.
for (id, packed_vector_config) in packed_vector_configs.iter() {
Expand All @@ -253,7 +261,7 @@ pub fn decrypt(
packed_vector_config.base,
packed_vector_config.dimension,
packed_vector_config.num_packed_coeffs,
packed_values.assume_init_mut(),
&mut packed_values,
&mut unpacked_values,
)
})?;
Expand Down
82 changes: 81 additions & 1 deletion willow/src/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ rust_library(
"@protobuf//rust:protobuf",
"//shell_wrapper:status",
"//willow/proto/willow:aggregation_config_rust_proto",
"//willow/src/shell:single_thread_hkdf",
"//willow/src/traits:proto_serialization_traits",
],
)
Expand All @@ -58,7 +59,6 @@ cc_library(
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
"@cxx.rs//:cxx",
"@cxx.rs//:core",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:server_accumulator_cc_proto",
Expand Down Expand Up @@ -111,3 +111,83 @@ rust_library(
"//willow/src/willow_v1:willow_v1_verifier",
],
)

rust_library(
name = "client",
srcs = ["client.rs"],
deps = [
":aggregation_config",
"@protobuf//rust:protobuf",
"@cxx.rs//:cxx",
"//shell_wrapper:status",
"//willow/proto/shell:shell_ciphertexts_rust_proto",
"//willow/proto/willow:aggregation_config_rust_proto",
"//willow/proto/willow:messages_rust_proto",
"//willow/src/shell:ahe_shell",
"//willow/src/shell:kahe_shell",
"//willow/src/shell:parameters_shell",
"//willow/src/shell:single_thread_hkdf",
"//willow/src/shell:vahe_shell",
"//willow/src/traits:ahe_traits",
"//willow/src/traits:client_traits",
"//willow/src/traits:kahe_traits",
"//willow/src/traits:messages",
"//willow/src/traits:prng_traits",
"//willow/src/traits:proto_serialization_traits",
"//willow/src/traits:vahe_traits",
"//willow/src/willow_v1:willow_v1_client",
"//willow/src/willow_v1:willow_v1_server",
],
)

rust_cxx_bridge(
name = "client_cxx",
src = "client.rs",
deps = [
":client",
":encoded_data",
],
)

cc_library(
name = "encoded_data",
srcs = ["encoded_data.cc"],
hdrs = ["encoded_data.h"],
deps = [
"@cxx.rs//:core",
"//willow/src/input_encoding:codec",
],
)

cc_library(
name = "client_cc",
srcs = ["client.cc"],
hdrs = ["client.h"],
deps = [
":client_cxx",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@cxx.rs//:core",
"//willow/proto/shell:shell_ciphertexts_cc_proto",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:messages_cc_proto",
"//willow/proto/willow:server_accumulator_cc_proto",
"//willow/src/input_encoding:codec",
],
)

cc_test(
name = "client_test_cc",
srcs = ["client_test.cc"],
deps = [
":client_cc",
":client_cxx",
"@googletest//:gtest_main",
"@cxx.rs//:core",
"//shell_wrapper:status_matchers",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:input_spec_cc_proto",
"//willow/src/input_encoding:codec",
"//willow/src/testing_utils:shell_testing_decryptor_cc",
],
)
13 changes: 13 additions & 0 deletions willow/src/api/aggregation_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ impl ToProto for AggregationConfig {
}
}

impl AggregationConfig {
/// Computes context bytes by hashing the session ID in the config.
pub fn compute_context_bytes(&self) -> Result<Vec<u8>, StatusError> {
let context_seed = single_thread_hkdf::compute_hkdf(
self.session_id.as_bytes(),
b"",
b"AggregationConfig.context_string",
single_thread_hkdf::seed_length(),
)?;
Ok(context_seed.as_bytes().to_vec())
}
}

#[cfg(test)]
mod tests {
use crate::AggregationConfig;
Expand Down
78 changes: 78 additions & 0 deletions willow/src/api/client.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "willow/src/api/client.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "include/cxx.h"
#include "willow/proto/shell/ciphertexts.pb.h"
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/server_accumulator.pb.h"
#include "willow/src/api/client.rs.h"
#include "willow/src/input_encoding/codec.h"

namespace secure_aggregation {

absl::StatusOr<willow::ClientMessage> GenerateClientContribution(
const willow::AggregationConfigProto& aggregation_config,
const willow::EncodedData& encoded_data,
const willow::ShellAhePublicKey& key, const std::string& nonce) {
// Initialize client.
std::string config_str = aggregation_config.SerializeAsString();
auto config_ptr = std::make_unique<std::string>(std::move(config_str));
secure_aggregation::WillowShellClient* client_ptr = nullptr;
std::unique_ptr<std::string> status_message;
int status_code =
initialize_client(std::move(config_ptr), &client_ptr, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
auto client = client_into_box(client_ptr);

// Prepare arguments.
EncodedDataWrapper encoded_data_wrapper(encoded_data);
std::string key_str = key.SerializeAsString();
auto key_ptr = std::make_unique<std::string>(std::move(key_str));
rust::Slice<const uint8_t> nonce_slice{
reinterpret_cast<const uint8_t*>(nonce.data()), nonce.size()};
rust::Vec<uint8_t> result_bytes;
std::unique_ptr<std::string> status_message_gen;

// Encrypt data.
int status_code_gen =
generate_contribution(client, encoded_data_wrapper, std::move(key_ptr),
nonce_slice, &result_bytes, &status_message_gen);
if (status_code_gen != 0) {
return absl::Status(absl::StatusCode(status_code_gen), *status_message_gen);
}

// Parse string to ClientMessage.
willow::ClientMessage client_message;
std::string result_str(reinterpret_cast<const char*>(result_bytes.data()),
result_bytes.size());
if (!client_message.ParseFromString(result_str)) {
return absl::InternalError(
"Failed to parse ClientMessage from Rust output.");
}

return client_message;
}

} // namespace secure_aggregation
40 changes: 40 additions & 0 deletions willow/src/api/client.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef SECURE_AGGREGATION_WILLOW_SRC_API_CLIENT_H_
#define SECURE_AGGREGATION_WILLOW_SRC_API_CLIENT_H_

#include <string>

#include "absl/status/statusor.h"
#include "willow/proto/shell/ciphertexts.pb.h"
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/messages.pb.h"
#include "willow/proto/willow/server_accumulator.pb.h"
#include "willow/src/input_encoding/codec.h"

namespace secure_aggregation {

// Generates a client contribution by encrypting the encoded data with the
// provided AHE public key.
absl::StatusOr<willow::ClientMessage> GenerateClientContribution(
const willow::AggregationConfigProto& aggregation_config,
const willow::EncodedData& encoded_data,
const willow::ShellAhePublicKey& key, const std::string& nonce);

} // namespace secure_aggregation

#endif // SECURE_AGGREGATION_WILLOW_SRC_API_CLIENT_H_
Loading