Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/multilinear_product/provers/blendy/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::BTreeSet;

use crate::{
messages::VerifierMessages,
multilinear::ReduceMode,
multilinear_product::{BlendyProductProver, BlendyProductProverConfig, TimeProductProver},
order_strategy::MSBOrder,
prover::Prover,
Expand Down Expand Up @@ -43,6 +44,7 @@ impl<F: Field, S: Stream<F>> Prover<F> for BlendyProductProver<F, S> {
streams: None,
num_variables: num_variables - last_round + 1,
inverse_four: F::from(4_u32).inverse().unwrap(),
reduce_mode: ReduceMode::Variablewise,
};

let stream_iterators = prover_config
Expand Down
7 changes: 5 additions & 2 deletions src/multilinear_product/provers/time/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::marker::PhantomData;

use ark_ff::Field;

use crate::{prover::ProductProverConfig, streams::Stream};
use crate::{multilinear::ReduceMode, prover::ProductProverConfig, streams::Stream};

pub struct TimeProductProverConfig<F, S>
where
Expand All @@ -11,6 +11,7 @@ where
{
pub num_variables: usize,
pub streams: Vec<S>,
pub reduce_mode: ReduceMode,
_f: PhantomData<F>,
}

Expand All @@ -19,10 +20,11 @@ where
F: Field,
S: Stream<F>,
{
pub fn new(num_variables: usize, streams: Vec<S>) -> Self {
pub fn new(num_variables: usize, streams: Vec<S>, reduce_mode: ReduceMode) -> Self {
Self {
num_variables,
streams,
reduce_mode,
_f: PhantomData::<F>,
}
}
Expand All @@ -33,6 +35,7 @@ impl<F: Field, S: Stream<F>> ProductProverConfig<F, S> for TimeProductProverConf
Self {
num_variables,
streams,
reduce_mode: ReduceMode::Variablewise,
_f: PhantomData::<F>,
}
}
Expand Down
275 changes: 66 additions & 209 deletions src/multilinear_product/provers/time/core.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
use ark_ff::Field;
use ark_std::vec::Vec;
#[cfg(feature = "parallel")]
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
};

use crate::streams::Stream;
use crate::{
multilinear::{
reductions::{pairwise, variablewise},
ReduceMode,
},
multilinear_product::provers::time::reductions::{
pairwise::{pairwise_product_evaluate, pairwise_product_evaluate_from_stream},
variablewise::{variablewise_product_evaluate, variablewise_product_evaluate_from_stream},
},
streams::Stream,
};

pub struct TimeProductProver<F: Field, S: Stream<F>> {
pub current_round: usize,
pub evaluations: Vec<Option<Vec<F>>>,
pub streams: Option<Vec<S>>,
pub num_variables: usize,
pub inverse_four: F,
pub reduce_mode: ReduceMode,
}

impl<F: Field, S: Stream<F>> TimeProductProver<F, S> {
Expand All @@ -27,217 +34,67 @@ impl<F: Field, S: Stream<F>> TimeProductProver<F, S> {
* from the streams (instead of the tables), which reduces max memory usage by 1/2
*/
pub fn vsbw_evaluate(&self) -> (F, F, F) {
// Initialize accumulators
let mut sum_half = F::ZERO;
let mut j_prime_table: ((F, F), (F, F)) = ((F::ZERO, F::ZERO), (F::ZERO, F::ZERO));

// Calculate the bitmask for the number of free variables
let bitmask: usize = 1 << (self.num_free_variables() - 1);

// Determine the length of evaluations to iterate through
let evaluations_len = match &self.evaluations[0] {
Some(evaluations) => evaluations.len(),
None => match &self.streams {
Some(streams) => 2usize.pow(streams[0].num_variables() as u32),
None => panic!("Both streams and evaluations cannot be None"),
match &self.evaluations[0] {
None => match self.reduce_mode {
ReduceMode::Variablewise => variablewise_product_evaluate_from_stream(
&self.streams.clone().unwrap(),
self.inverse_four,
),
ReduceMode::Pairwise => {
pairwise_product_evaluate_from_stream(&self.streams.clone().unwrap())
}
},
};

#[cfg(feature = "parallel")]
{
let p_evals = self.evaluations[0].as_deref();
let q_evals = self.evaluations[1].as_deref();
let streams = self.streams.as_ref();

let (acc00, acc01, acc10, acc11) = (0..evaluations_len / 2)
.into_par_iter()
// each worker gets its own local (j00, j01, j10, j11)
.fold(
|| (F::ZERO, F::ZERO, F::ZERO, F::ZERO),
|(mut j00, mut j01, mut j10, mut j11), i| {
// Load p,q at bit = 0 and 1. We only branch on the “source” (vec vs stream).
let (p0, p1) = if let Some(pe) = p_evals {
(pe[i], pe[i | bitmask])
} else {
let s =
&streams.expect("Both streams and evaluations cannot be None")[0];
(s.evaluation(i), s.evaluation(i | bitmask))
};

let (q0, q1) = if let Some(qe) = q_evals {
(qe[i], qe[i | bitmask])
} else {
let s =
&streams.expect("Both streams and evaluations cannot be None")[1];
(s.evaluation(i), s.evaluation(i | bitmask))
};

// Directly accumulate the 2x2 contributions (no temp x/y tables needed)
j00 += p0 * q0; // (0,0)
j11 += p1 * q1; // (1,1)
j01 += p0 * q1; // (0,1)
j10 += p1 * q0; // (1,0)

(j00, j01, j10, j11)
},
)
// combine thread-local partials
.reduce(
|| (F::ZERO, F::ZERO, F::ZERO, F::ZERO),
|(a00, a01, a10, a11), (b00, b01, b10, b11)| {
(a00 + b00, a01 + b01, a10 + b10, a11 + b11)
},
);
j_prime_table.0 .0 += acc00;
j_prime_table.0 .1 += acc01;
j_prime_table.1 .0 += acc10;
j_prime_table.1 .1 += acc11;
}

#[cfg(not(feature = "parallel"))]
// Iterate through evaluations
for i in 0..(evaluations_len / 2) {
// these must be zeroed out
let mut x_table: (F, F) = (F::ZERO, F::ZERO);
let mut y_table: (F, F) = (F::ZERO, F::ZERO);

// get all the values
let p_zero = match &self.evaluations[0] {
None => match &self.streams {
Some(streams) => streams[0].evaluation(i),
None => panic!("Both streams and evaluations cannot be None"),
},
Some(evaluations_p) => evaluations_p[i],
};
let q_zero = match &self.evaluations[1] {
None => match &self.streams {
Some(streams) => streams[1].evaluation(i),
None => panic!("Both streams and evaluations cannot be None"),
},
Some(evaluations_q) => evaluations_q[i],
};
let p_one = match &self.evaluations[0] {
None => match &self.streams {
Some(streams) => streams[0].evaluation(i | bitmask),
None => panic!("Both streams and evaluations cannot be None"),
},
Some(evaluations_p) => evaluations_p[i | bitmask],
};
let q_one = match &self.evaluations[1] {
None => match &self.streams {
Some(streams) => streams[1].evaluation(i | bitmask),
None => panic!("Both streams and evaluations cannot be None"),
},
Some(evaluations_q) => evaluations_q[i | bitmask],
};

// update tables
x_table.0 += p_zero;
y_table.0 += q_zero;
y_table.1 += q_one;
x_table.1 += p_one;

// update j_prime
j_prime_table.0 .0 = j_prime_table.0 .0 + x_table.0 * y_table.0;
j_prime_table.1 .1 = j_prime_table.1 .1 + x_table.1 * y_table.1;
j_prime_table.0 .1 = j_prime_table.0 .1 + x_table.0 * y_table.1;
j_prime_table.1 .0 = j_prime_table.1 .0 + x_table.1 * y_table.0;
Some(_evals) => {
let evals: Vec<Vec<F>> = self
.evaluations
.iter()
.filter_map(|opt| opt.clone()) // keep only Some(&Vec<F>)
.collect();
let evals_slice: &[Vec<F>] = &evals;
match self.reduce_mode {
ReduceMode::Variablewise => {
variablewise_product_evaluate(evals_slice, self.inverse_four)
}
ReduceMode::Pairwise => pairwise_product_evaluate(evals_slice),
}
}
}

// update
let sum_0 = j_prime_table.0 .0;
let sum_1 = j_prime_table.1 .1;
sum_half +=
j_prime_table.0 .0 + j_prime_table.1 .1 + j_prime_table.0 .1 + j_prime_table.1 .0;
sum_half *= self.inverse_four;

(sum_0, sum_1, sum_half)
}
pub fn vsbw_reduce_evaluations(&mut self, verifier_message: F, verifier_message_hat: F) {
for i in 0..self.evaluations.len() {
// Clone or initialize the evaluations vector
let mut evaluations = match &self.evaluations[i] {
Some(evaluations) => evaluations.clone(),
None => match &self.streams {
Some(streams) => vec![
F::ZERO;
2usize.pow(streams[i].num_variables().try_into().unwrap())
/ 2
],
None => panic!("Both streams and evaluations cannot be None"),
},
};

// Determine the length of evaluations to iterate through
let evaluations_len = match &self.evaluations[i] {
Some(evaluations) => evaluations.len() / 2,
None => evaluations.len(),
};

// Calculate what bit needs to be set to index the second half of the last round's evaluations
let setbit: usize = 1 << self.num_free_variables();

// Iterate through pairs of evaluations
for i0 in 0..evaluations_len {
let i1 = i0 | setbit;

// Get point evaluations for indices i0 and i1
let point_evaluation_i0 = match &self.evaluations[i] {
None => match &self.streams {
Some(streams) => streams[i].evaluation(i0),
None => panic!("Both streams and evaluations cannot be None"),
},
Some(evaluations) => evaluations[i0],
};
let point_evaluation_i1 = match &self.evaluations[i] {
None => match &self.streams {
Some(streams) => streams[i].evaluation(i1),
None => panic!("Both streams and evaluations cannot be None"),
},
Some(evaluations) => evaluations[i1],
};
// Update the i0-th evaluation based on the reduction operation
evaluations[i0] = point_evaluation_i0 * verifier_message_hat
+ point_evaluation_i1 * verifier_message;
}

#[cfg(feature = "parallel")]
let vm = verifier_message;
#[cfg(feature = "parallel")]
let vmh = verifier_message_hat;
#[cfg(feature = "parallel")]
match (&self.evaluations[i], &self.streams) {
// Read from slice
(Some(src), _) => {
evaluations.par_iter_mut().enumerate().for_each(
|(i0, out): (usize, &mut F)| {
let i1 = i0 | setbit;
let p0 = src[i0];
let p1 = src[i1];
*out = p0 * vmh + p1 * vm; // <- write through &mut
},
);
match &self.evaluations[0] {
None => {
let len = self.streams.clone().unwrap().len();
for i in 0..len {
self.evaluations[i] = Some(vec![]);
match self.reduce_mode {
ReduceMode::Variablewise => variablewise::reduce_evaluations_from_stream(
&self.streams.as_mut().unwrap()[i],
self.evaluations[i].as_mut().unwrap(),
verifier_message,
verifier_message_hat,
),
ReduceMode::Pairwise => pairwise::reduce_evaluations_from_stream(
&self.streams.as_mut().unwrap()[i],
self.evaluations[i].as_mut().unwrap(),
verifier_message,
),
}
}
// Read from stream
(None, Some(streams)) => {
let s = &streams[i];
evaluations.par_iter_mut().enumerate().for_each(
|(i0, out): (usize, &mut F)| {
let i1 = i0 | setbit;
let p0 = s.evaluation(i0);
let p1 = s.evaluation(i1);
*out = p0 * vmh + p1 * vm; // <- write through &mut
},
);
}
Some(_a) => {
for table in &mut self.evaluations {
match self.reduce_mode {
ReduceMode::Variablewise => variablewise::reduce_evaluations(
table.as_mut().unwrap(),
verifier_message,
verifier_message_hat,
),
ReduceMode::Pairwise => {
pairwise::reduce_evaluations(table.as_mut().unwrap(), verifier_message)
}
}
}
(None, None) => panic!("Both streams and evaluations cannot be None"),
}

// Truncate the evaluations vector to the correct length
evaluations.truncate(evaluations_len);

// Update the internal state with the new evaluations vector
self.evaluations[i] = Some(evaluations.clone());
}
}
}
1 change: 1 addition & 0 deletions src/multilinear_product/provers/time/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod config;
mod core;
mod prover;
pub mod reductions;

pub use config::TimeProductProverConfig;
pub use core::TimeProductProver;
1 change: 1 addition & 0 deletions src/multilinear_product/provers/time/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ impl<F: Field, S: Stream<F>> Prover<F> for TimeProductProver<F, S> {
streams: Some(prover_config.streams),
num_variables,
inverse_four: F::from(4_u32).inverse().unwrap(),
reduce_mode: prover_config.reduce_mode,
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/multilinear_product/provers/time/reductions/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod pairwise;
pub mod variablewise;
Loading
Loading