Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 10 additions & 29 deletions src/cuzk/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,17 @@ use crate::cuzk::gpu::{
get_adapter, get_device, read_from_gpu,
};
use crate::cuzk::shader_manager::ShaderManager;
use crate::cuzk::utils::compute_p;
use crate::cuzk::utils::to_biguint_le;
use crate::{points_to_bytes, scalars_to_bytes};

use super::utils::bytes_to_field;
use super::utils::calc_bitwidth;
use super::utils::{MiscParams, compute_misc_params};
use ff::Field;

/// Calculate the number of words in the field characteristic
pub fn calc_num_words(word_size: usize) -> usize {
let p_bit_length = calc_bitwidth(&P);
let mut num_words = p_bit_length / word_size;
while num_words * word_size < p_bit_length {
num_words += 1;
}
num_words
}
use ff::{Field, PrimeField};

/// 13-bit limbs.
pub const WORD_SIZE: usize = 13;

/// Field characteristic
pub static P: Lazy<BigUint> = Lazy::new(|| {
BigUint::from_str_radix(
"21888242871839275222246405745257275088696311157297823662689037894645226208583",
10,
)
.expect("Invalid modulus")
});

/// Miscellaneous parameters
pub static PARAMS: Lazy<MiscParams> = Lazy::new(|| compute_misc_params(&P, WORD_SIZE));

fn pad_scalars<C: CurveAffine>(scalars: &[C::Scalar]) -> Vec<C::Scalar> {
let n = scalars.len();
let l = n.next_power_of_two();
Expand Down Expand Up @@ -73,19 +51,21 @@ fn pad_points<C: CurveAffine>(points: &[C]) -> Vec<C> {
* 2022: https://eprint.iacr.org/2022/1321.pdf
*/
pub async fn compute_msm<C: CurveAffine>(points: &[C], scalars: &[C::Scalar]) -> C::Curve {
let p = compute_p::<C>();
let params = compute_misc_params(&p, WORD_SIZE);
let padded_scalars = pad_scalars::<C>(scalars);
let padded_points = pad_points::<C>(points);
let input_size = padded_scalars.len();
let chunk_size = if input_size >= 65536 { 16 } else { 4 };
let num_columns = 1 << chunk_size;
let num_rows = input_size.div_ceil(num_columns);
let num_subtasks = 256_usize.div_ceil(chunk_size);
let num_words = PARAMS.num_words;
let num_words = params.num_words;

let point_bytes = points_to_bytes(&padded_points);
let scalar_bytes = scalars_to_bytes(&padded_scalars);

let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size);
let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size, &params);

let adapter = get_adapter().await;
let (device, queue) = get_device(&adapter).await;
Expand Down Expand Up @@ -350,12 +330,13 @@ pub async fn compute_msm<C: CurveAffine>(points: &[C], scalars: &[C::Scalar]) ->
device.destroy();

let mut points = vec![];
let r_inv = params.clone().rinv;

let g_points_x = bytemuck::cast_slice::<u8, u32>(&data[0])
.chunks(num_words)
.map(|x| {
let x_biguint_montgomery = to_biguint_le(x, num_words, WORD_SIZE as u32);
let x_biguint = x_biguint_montgomery * &PARAMS.rinv % P.clone();
let x_biguint = x_biguint_montgomery * &r_inv % p.clone();

bytes_to_field(&x_biguint.to_bytes_le())
})
Expand All @@ -364,7 +345,7 @@ pub async fn compute_msm<C: CurveAffine>(points: &[C], scalars: &[C::Scalar]) ->
.chunks(num_words)
.map(|y| {
let y_biguint_montgomery = to_biguint_le(y, num_words, WORD_SIZE as u32);
let y_biguint = y_biguint_montgomery * &PARAMS.rinv % P.clone();
let y_biguint = y_biguint_montgomery * &r_inv % p.clone();

bytes_to_field(&y_biguint.to_bytes_le())
})
Expand All @@ -373,7 +354,7 @@ pub async fn compute_msm<C: CurveAffine>(points: &[C], scalars: &[C::Scalar]) ->
.chunks(num_words)
.map(|z| {
let z_biguint_montgomery = to_biguint_le(z, num_words, WORD_SIZE as u32);
let z_biguint = z_biguint_montgomery * &PARAMS.rinv % P.clone();
let z_biguint = z_biguint_montgomery * &r_inv % p.clone();

bytes_to_field(&z_biguint.to_bytes_le())
})
Expand Down
25 changes: 12 additions & 13 deletions src/cuzk/shader_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@ pub static TEST_FIELD_SHADER: Lazy<String> =
pub static TEST_POINT_SHADER: Lazy<String> =
Lazy::new(|| include_str!("wgsl/test/test_point.wgsl").to_string());

use crate::cuzk::utils::{calc_bitwidth, gen_mu_limbs, gen_one_limbs, gen_p_limbs, gen_rinv_limbs};
use crate::cuzk::utils::{calc_bitwidth, gen_mu_limbs, gen_one_limbs, gen_p_limbs, gen_rinv_limbs, MiscParams};

use super::{
msm::{P, PARAMS},
utils::{gen_p_limbs_plus_one, gen_r_limbs, gen_zero_limbs},
};

Expand All @@ -71,13 +70,13 @@ pub struct ShaderManager {

impl ShaderManager {
/// Create a new shader manager
pub fn new(word_size: usize, chunk_size: usize, input_size: usize) -> Self {
let p_bit_length = calc_bitwidth(&P);
let num_words = PARAMS.num_words;
let r = PARAMS.r.clone();
let rinv = PARAMS.rinv.clone();
println!("P: {P:?}");
println!("P limbs: {}", gen_p_limbs(&P, num_words, word_size));
pub fn new(word_size: usize, chunk_size: usize, input_size: usize, params: &MiscParams) -> Self {
let p_bit_length = calc_bitwidth(&params.p);
let num_words = params.num_words;
let r = params.r.clone();
let rinv = params.rinv.clone();
println!("P: {:?}", params.p);
println!("P limbs: {}", gen_p_limbs(&params.p, num_words, word_size));
println!("W_MASK: {:?}", (1 << word_size) - 1);
println!("R limbs: {}", gen_r_limbs(&r, num_words, word_size));
Self {
Expand All @@ -86,15 +85,15 @@ impl ShaderManager {
input_size,
num_words,
index_shift: 1 << (chunk_size - 1),
p_limbs: gen_p_limbs(&P, num_words, word_size),
p_limbs_plus_one: gen_p_limbs_plus_one(&P, num_words, word_size),
p_limbs: gen_p_limbs(&params.p, num_words, word_size),
p_limbs_plus_one: gen_p_limbs_plus_one(&params.p, num_words, word_size),
zero_limbs: gen_zero_limbs(num_words),
one_limbs: gen_one_limbs(num_words),
slack: num_words * word_size - p_bit_length,
w_mask: (1 << word_size) - 1,
n0: PARAMS.n0,
n0: params.n0,
r_limbs: gen_r_limbs(&r, num_words, word_size),
mu_limbs: gen_mu_limbs(&P, num_words, word_size),
mu_limbs: gen_mu_limbs(&params.p, num_words, word_size),
rinv_limbs: gen_rinv_limbs(&rinv, num_words, word_size),
}
}
Expand Down
70 changes: 37 additions & 33 deletions src/cuzk/test/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ff::PrimeField;
use group::{prime::PrimeCurveAffine, Group};
use halo2curves::bn256::{Fr, G1, G1Affine};
use halo2curves::CurveAffine;

use crate::cuzk::utils::to_words_le_from_field;

Expand Down Expand Up @@ -30,15 +30,15 @@ pub fn get_element(arr: &[i32], id: i32) -> i32 {
}
}

pub fn get_point_element(arr: &[G1Affine], id: i32) -> G1Affine {
pub fn get_point_element<C: CurveAffine>(arr: &[C], id: i32) -> C {
if id < 0 {
if (arr.len() as i32 + id) < 0 {
return G1Affine::identity();
return C::identity();
}
arr[arr.len() + id as usize]
} else {
if id >= arr.len() as i32 {
return G1Affine::identity();
return C::identity();
}
arr[id as usize]
}
Expand Down Expand Up @@ -137,18 +137,22 @@ pub fn decompose_scalars_signed<F: PrimeField>(
signed_slices[i] = limbs[i] as i32 + carry;
if signed_slices[i] >= l / 2 {
signed_slices[i] = -(l - signed_slices[i]);
if signed_slices[i] == -0 {
signed_slices[i] = 0;
}
// if signed_slices[i] == 0 {
// signed_slices[i] = 0;
// }
carry = 1;
} else {
carry = 0;
}
}

// We do not need to handle the case where the final carry equals 1, as the highest word of the field modulus (0x12ab) is smaller than 2^{16-1}
if carry == 1 {
panic!("final carry is 1");
// TODO: Review this
// panic!("final carry is 1");
println!("Carrying 1");
println!("Scalar: {:?}", scalar);
println!("Limbs: {:?}", limbs);
signed_slices.push(carry);
}
as_limbs.push(signed_slices.iter().map(|x| x + shift).collect());
}
Expand All @@ -163,19 +167,19 @@ pub fn decompose_scalars_signed<F: PrimeField>(
/**
* Perform SMVP with signed bucket indices
*/
pub fn cpu_smvp_signed(
pub fn cpu_smvp_signed<C: CurveAffine>(
subtask_idx: usize,
input_size: usize,
num_columns: usize,
chunk_size: usize,
all_csc_col_ptr: &[i32],
all_csc_val_idxs: &[i32],
points: &[G1Affine],
) -> Vec<G1> {
points: &[C],
) -> Vec<C::Curve> {
let l = 1 << chunk_size;
let h = l / 2;
let zero = G1::identity();
let mut buckets: Vec<G1> = vec![zero; num_columns / 2];
let zero = C::Curve::identity();
let mut buckets: Vec<C::Curve> = vec![zero; num_columns / 2];

let rp_offset = subtask_idx * (num_columns + 1);

Expand All @@ -197,7 +201,7 @@ pub fn cpu_smvp_signed(
let idx = subtask_idx as i32 * input_size as i32 + k;
let val = get_element(all_csc_val_idxs, idx);
let point = get_point_element(points, val);
sum += G1::from(point);
sum += C::Curve::from(point);
}

let bucket_idx;
Expand All @@ -219,23 +223,23 @@ pub fn cpu_smvp_signed(
}

/// Serial bucket reduction
pub fn serial_bucket_reduction(buckets: &[G1]) -> G1 {
pub fn serial_bucket_reduction<C: CurveAffine>(buckets: &[C::Curve]) -> C::Curve {
let mut indices = vec![];
for i in 1..buckets.len() {
indices.push(i);
}
indices.push(0);

let mut bucket_sum = G1::identity();
let mut bucket_sum = C::Curve::identity();
for i in 1..buckets.len() + 1 {
let b = buckets[indices[i - 1]] * Fr::from(i as u64);
let b = buckets[indices[i - 1]] * C::Scalar::from(i as u64);
bucket_sum += b;
}
bucket_sum
}

/// Perform running sum in the classic fashion - one siumulated thread only
pub fn running_sum_bucket_reduction(buckets: &[G1]) -> G1 {
pub fn running_sum_bucket_reduction<C: CurveAffine>(buckets: &[C::Curve]) -> C::Curve {
let n = buckets.len();
let mut m = buckets[0];
let mut g = m;
Expand All @@ -252,9 +256,9 @@ pub fn running_sum_bucket_reduction(buckets: &[G1]) -> G1 {

/// Perform running sum with simulated parallelism. It is up to the caller
/// to add the resulting points.
pub fn parallel_bucket_reduction(buckets: &[G1], num_threads: usize) -> Vec<G1> {
pub fn parallel_bucket_reduction<C: CurveAffine>(buckets: &[C::Curve], num_threads: usize) -> Vec<C::Curve> {
let buckets_per_thread = buckets.len() / num_threads;
let mut bucket_sums: Vec<G1> = vec![];
let mut bucket_sums: Vec<C::Curve> = vec![];

for thread_id in 0..num_threads {
let idx = if thread_id == 0 {
Expand All @@ -275,7 +279,7 @@ pub fn parallel_bucket_reduction(buckets: &[G1], num_threads: usize) -> Vec<G1>

let s = buckets_per_thread * (num_threads - thread_id - 1);
if s > 0 {
g += m * Fr::from(s as u64);
g += m * C::Scalar::from(s as u64);
}

bucket_sums.push(g);
Expand All @@ -284,13 +288,13 @@ pub fn parallel_bucket_reduction(buckets: &[G1], num_threads: usize) -> Vec<G1>
}

/// The first part of the parallel bucket reduction algo
pub fn parallel_bucket_reduction_1(
buckets: &[G1],
pub fn parallel_bucket_reduction_1<C: CurveAffine>(
buckets: &[C::Curve],
num_threads: usize,
) -> (Vec<G1>, Vec<G1>) {
) -> (Vec<C::Curve>, Vec<C::Curve>) {
let buckets_per_thread = buckets.len() / num_threads;
let mut g_points: Vec<G1> = vec![];
let mut m_points: Vec<G1> = vec![];
let mut g_points: Vec<C::Curve> = vec![];
let mut m_points: Vec<C::Curve> = vec![];

for thread_id in 0..num_threads {
let idx = if thread_id == 0 {
Expand All @@ -316,21 +320,21 @@ pub fn parallel_bucket_reduction_1(
}

/// The second part of the parallel bucket reduction algo
pub fn parallel_bucket_reduction_2(
g_points: Vec<G1>,
m_points: Vec<G1>,
pub fn parallel_bucket_reduction_2<C: CurveAffine>(
g_points: Vec<C::Curve>,
m_points: Vec<C::Curve>,
num_buckets: usize,
num_threads: usize,
) -> Vec<G1> {
) -> Vec<C::Curve> {
let buckets_per_thread = num_buckets / num_threads;
let mut result: Vec<G1> = vec![];
let mut result: Vec<C::Curve> = vec![];

for thread_id in 0..num_threads {
let mut g = g_points[thread_id];
let m = m_points[thread_id];
let s = buckets_per_thread * (num_threads - thread_id - 1);
if s > 0 {
g += m * Fr::from(s as u64);
g += m * C::Scalar::from(s as u64);
}
result.push(g);
}
Expand Down
Loading