Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions src/coefficient_sumcheck.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::field::SumcheckField;
use ark_ff::Field;
use ark_poly::univariate::DensePolynomial;

Expand Down Expand Up @@ -46,7 +47,7 @@ pub struct CoefficientSumcheck<F: Field> {
/// }
/// }
/// ```
pub trait RoundPolyEvaluator<F: Field>: Sync {
pub trait RoundPolyEvaluator<F: SumcheckField>: Sync {
/// The degree of the round polynomial (number of coefficients = degree + 1).
fn degree(&self) -> usize;

Expand Down Expand Up @@ -79,7 +80,7 @@ pub trait RoundPolyEvaluator<F: Field>: Sync {
/// SIMD fast path for degree-1 with a single pairwise table.
///
/// Returns `[sum_even, sum_odd - sum_even]` = coefficients of `h(x) = c0 + c1*x`.
fn simd_evaluate_degree1<F: Field>(pw: &[F]) -> Vec<F> {
fn simd_evaluate_degree1<F: SumcheckField>(pw: &[F]) -> Vec<F> {
// Try SIMD dispatch for Goldilocks
#[cfg(all(
feature = "simd",
Expand Down Expand Up @@ -112,7 +113,7 @@ fn simd_evaluate_degree1<F: Field>(pw: &[F]) -> Vec<F> {
all(target_arch = "x86_64", target_feature = "avx512ifma")
)
))]
fn try_simd_evaluate_degree1<F: ark_ff::Field>(pw: &[F]) -> Option<Vec<F>> {
fn try_simd_evaluate_degree1<F: SumcheckField>(pw: &[F]) -> Option<Vec<F>> {
crate::simd_sumcheck::dispatch::try_simd_evaluate_degree1(pw)
}

Expand All @@ -128,7 +129,10 @@ fn try_simd_evaluate_degree1<F: ark_ff::Field>(pw: &[F]) -> Option<Vec<F>> {
all(target_arch = "x86_64", target_feature = "avx512ifma")
)
))]
fn try_simd_fused_reduce_evaluate<F: Field>(pw: &mut Vec<F>, challenge: F) -> Option<Vec<F>> {
fn try_simd_fused_reduce_evaluate<F: SumcheckField>(
pw: &mut Vec<F>,
challenge: F,
) -> Option<Vec<F>> {
crate::simd_sumcheck::dispatch::try_simd_fused_reduce_evaluate_degree1(pw, challenge)
}

Expand All @@ -139,13 +143,16 @@ fn try_simd_fused_reduce_evaluate<F: Field>(pw: &mut Vec<F>, challenge: F) -> Op
all(target_arch = "x86_64", target_feature = "avx512ifma")
)
)))]
fn try_simd_fused_reduce_evaluate<F: Field>(_pw: &mut Vec<F>, _challenge: F) -> Option<Vec<F>> {
fn try_simd_fused_reduce_evaluate<F: SumcheckField>(
_pw: &mut Vec<F>,
_challenge: F,
) -> Option<Vec<F>> {
None
}

/// Parallel evaluate using rayon (for heavy evaluators).
#[cfg(feature = "parallel")]
fn parallel_evaluate<F: Field>(
fn parallel_evaluate<F: SumcheckField>(
evaluator: &impl RoundPolyEvaluator<F>,
tablewise: &[Vec<Vec<F>>],
pairwise: &[Vec<F>],
Expand Down Expand Up @@ -184,7 +191,7 @@ fn parallel_evaluate<F: Field>(

/// Fallback when parallel feature is disabled.
#[cfg(not(feature = "parallel"))]
fn parallel_evaluate<F: Field>(
fn parallel_evaluate<F: SumcheckField>(
evaluator: &impl RoundPolyEvaluator<F>,
tablewise: &[Vec<Vec<F>>],
pairwise: &[Vec<F>],
Expand All @@ -209,7 +216,7 @@ fn parallel_evaluate<F: Field>(
/// Sequential evaluate (for trivial evaluators where rayon overhead dominates).
///
/// Fills `coeffs_out` with accumulated coefficients (zeroes it first).
fn sequential_evaluate_into<F: Field>(
fn sequential_evaluate_into<F: SumcheckField>(
evaluator: &impl RoundPolyEvaluator<F>,
tablewise: &[Vec<Vec<F>>],
pairwise: &[Vec<F>],
Expand Down
32 changes: 19 additions & 13 deletions src/inner_product_sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
//! additional cache-locality gains from reading all four strides
//! simultaneously.

use ark_ff::Field;
use crate::field::SumcheckField;
use alloc::vec::Vec;
#[cfg(feature = "parallel")]
use rayon::join;
#[cfg(feature = "parallel")]
Expand All @@ -29,7 +30,7 @@ use crate::transcript::ProverTranscript;

/// Legacy return type for `inner_product_sumcheck`.
#[derive(Debug, PartialEq)]
pub struct ProductSumcheck<F: Field> {
pub struct ProductSumcheck<F: SumcheckField> {
pub prover_messages: Vec<(F, F)>,
pub verifier_messages: Vec<F>,
pub final_evaluations: (F, F),
Expand All @@ -38,6 +39,7 @@ pub struct ProductSumcheck<F: Field> {
// ─── Workload threshold ─────────────────────────────────────────────────────

/// Target single-thread workload size for `T`. Close to L1 cache.
#[cfg(feature = "parallel")]
const fn workload_size<T: Sized>() -> usize {
#[cfg(all(target_arch = "aarch64", target_os = "macos"))]
const CACHE_SIZE: usize = 1 << 17;
Expand All @@ -63,7 +65,7 @@ const fn workload_size<T: Sized>() -> usize {

// ─── Scalar helpers ─────────────────────────────────────────────────────────

fn dot<F: Field>(a: &[F], b: &[F]) -> F {
fn dot<F: SumcheckField>(a: &[F], b: &[F]) -> F {
debug_assert_eq!(a.len(), b.len());
#[cfg(feature = "parallel")]
if a.len() > workload_size::<F>() {
Expand All @@ -72,7 +74,7 @@ fn dot<F: Field>(a: &[F], b: &[F]) -> F {
a.iter().zip(b).map(|(x, y)| *x * *y).sum()
}

fn scalar_mul<F: Field>(v: &mut [F], w: F) {
fn scalar_mul<F: SumcheckField>(v: &mut [F], w: F) {
for x in v.iter_mut() {
*x *= w;
}
Expand All @@ -83,8 +85,8 @@ fn scalar_mul<F: Field>(v: &mut [F], w: F) {
/// `(c0, c2)` of the round polynomial `q(x) = c0 + c1·x + c2·x²`.
///
/// Vectors `a` and `b` are implicitly zero-extended to the next power of two.
pub fn compute_sumcheck_polynomial<F: Field>(a: &[F], b: &[F]) -> (F, F) {
fn recurse<F: Field>(a0: &[F], a1: &[F], b0: &[F], b1: &[F]) -> (F, F) {
pub fn compute_sumcheck_polynomial<F: SumcheckField>(a: &[F], b: &[F]) -> (F, F) {
fn recurse<F: SumcheckField>(a0: &[F], a1: &[F], b0: &[F], b1: &[F]) -> (F, F) {
debug_assert_eq!(a0.len(), b0.len());
debug_assert_eq!(a1.len(), b1.len());
debug_assert!(a0.len() == a1.len());
Expand Down Expand Up @@ -138,8 +140,8 @@ pub fn compute_sumcheck_polynomial<F: Field>(a: &[F], b: &[F]) -> (F, F) {
///
/// `values` is implicitly zero-padded to the next power of two. On return,
/// the length is a power of two (or zero).
pub fn fold<F: Field>(values: &mut Vec<F>, weight: F) {
fn recurse_both<F: Field>(low: &mut [F], high: &[F], weight: F) {
pub fn fold<F: SumcheckField>(values: &mut Vec<F>, weight: F) {
fn recurse_both<F: SumcheckField>(low: &mut [F], high: &[F], weight: F) {
#[cfg(feature = "parallel")]
if low.len() > workload_size::<F>() {
let split = low.len() / 2;
Expand Down Expand Up @@ -174,7 +176,11 @@ pub fn fold<F: Field>(values: &mut Vec<F>, weight: F) {
}

/// Two-pass fold-then-compute; reference version kept for testing.
pub fn fold_and_compute_polynomial<F: Field>(a: &mut Vec<F>, b: &mut Vec<F>, weight: F) -> (F, F) {
pub fn fold_and_compute_polynomial<F: SumcheckField>(
a: &mut Vec<F>,
b: &mut Vec<F>,
weight: F,
) -> (F, F) {
fold(a, weight);
fold(b, weight);
compute_sumcheck_polynomial(a, b)
Expand All @@ -190,7 +196,7 @@ pub fn fold_and_compute_polynomial<F: Field>(a: &mut Vec<F>, b: &mut Vec<F>, wei
///
/// Falls back to the unfused path for small or non-pow2 inputs so the
/// implicit-zero tail accounting stays identical.
pub fn fused_fold_and_compute_polynomial<F: Field>(
pub fn fused_fold_and_compute_polynomial<F: SumcheckField>(
a: &mut Vec<F>,
b: &mut Vec<F>,
weight: F,
Expand All @@ -202,7 +208,7 @@ pub fn fused_fold_and_compute_polynomial<F: Field>(
}

#[allow(clippy::too_many_arguments)]
fn kernel<F: Field>(
fn kernel<F: SumcheckField>(
a0: &mut [F],
a1: &mut [F],
a2: &[F],
Expand Down Expand Up @@ -304,7 +310,7 @@ pub fn inner_product_sumcheck_partial<F, T, H>(
mut hook: H,
) -> ProductSumcheck<F>
where
F: Field,
F: SumcheckField,
T: ProverTranscript<F>,
H: FnMut(usize, &mut T),
{
Expand Down Expand Up @@ -364,7 +370,7 @@ pub fn inner_product_sumcheck<F, T, H>(
hook: H,
) -> ProductSumcheck<F>
where
F: Field,
F: SumcheckField,
T: ProverTranscript<F>,
H: FnMut(usize, &mut T),
{
Expand Down
4 changes: 0 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,12 @@ pub mod transcript;

// ─── Arkworks-dependent modules ─────────────────────────────────────────────

#[cfg(feature = "arkworks")]
mod inner_product_sumcheck;
#[cfg(feature = "arkworks")]
mod multilinear_sumcheck;

#[cfg(feature = "arkworks")]
pub use inner_product_sumcheck::{
inner_product_sumcheck, inner_product_sumcheck_partial, ProductSumcheck,
};
#[cfg(feature = "arkworks")]
pub use multilinear_sumcheck::{
compute_sumcheck_polynomial, fold, fused_fold_and_compute_polynomial, multilinear_sumcheck,
multilinear_sumcheck_partial, Sumcheck,
Expand Down
37 changes: 24 additions & 13 deletions src/multilinear_sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
//! 4 reads + 2 writes per quadruple (fused) vs. 6 reads + 2 writes
//! (fold + compute separately) — a ~33% memory-traffic reduction.

use ark_ff::Field;
use crate::field::SumcheckField;
use alloc::vec::Vec;
#[cfg(feature = "parallel")]
use rayon::join;
#[cfg(feature = "parallel")]
Expand All @@ -27,14 +28,15 @@ use crate::transcript::ProverTranscript;

/// Legacy return type for `multilinear_sumcheck`.
#[derive(Debug)]
pub struct Sumcheck<F: Field> {
pub struct Sumcheck<F: SumcheckField> {
pub prover_messages: Vec<(F, F)>,
pub verifier_messages: Vec<F>,
pub final_evaluation: F,
}

// ─── Workload threshold ─────────────────────────────────────────────────────

#[cfg(feature = "parallel")]
const fn workload_size<T: Sized>() -> usize {
#[cfg(all(target_arch = "aarch64", target_os = "macos"))]
const CACHE_SIZE: usize = 1 << 17;
Expand All @@ -60,15 +62,15 @@ const fn workload_size<T: Sized>() -> usize {

// ─── Scalar helpers ─────────────────────────────────────────────────────────

fn sum_slice<F: Field>(v: &[F]) -> F {
fn sum_slice<F: SumcheckField>(v: &[F]) -> F {
#[cfg(feature = "parallel")]
if v.len() > workload_size::<F>() {
return v.par_iter().copied().sum();
}
v.iter().copied().sum()
}

fn scalar_mul<F: Field>(v: &mut [F], w: F) {
fn scalar_mul<F: SumcheckField>(v: &mut [F], w: F) {
for x in v.iter_mut() {
*x *= w;
}
Expand All @@ -81,8 +83,8 @@ fn scalar_mul<F: Field>(v: &mut [F], w: F) {
/// `values` is implicitly zero-extended to the next power of two.
/// - `s0 = Σ v[0..L/2]` (low half, possibly with tail contributions)
/// - `s1 = Σ v[L/2..L]`
pub fn compute_sumcheck_polynomial<F: Field>(values: &[F]) -> (F, F) {
fn recurse<F: Field>(lo: &[F], hi: &[F]) -> (F, F) {
pub fn compute_sumcheck_polynomial<F: SumcheckField>(values: &[F]) -> (F, F) {
fn recurse<F: SumcheckField>(lo: &[F], hi: &[F]) -> (F, F) {
debug_assert_eq!(lo.len(), hi.len());

#[cfg(feature = "parallel")]
Expand Down Expand Up @@ -127,7 +129,7 @@ pub fn compute_sumcheck_polynomial<F: Field>(values: &[F]) -> (F, F) {
///
/// SIMD-accelerated for Goldilocks base field on NEON and AVX-512 IFMA.
/// Falls back to a scalar recursive `rayon::join` fold for other fields.
pub fn fold<F: Field>(values: &mut Vec<F>, weight: F) {
pub fn fold<F: SumcheckField>(values: &mut Vec<F>, weight: F) {
// SIMD fast path for base-field Goldilocks (MSB layout).
#[cfg(all(
feature = "simd",
Expand All @@ -142,7 +144,7 @@ pub fn fold<F: Field>(values: &mut Vec<F>, weight: F) {
return;
}
}
fn recurse_both<F: Field>(low: &mut [F], high: &[F], weight: F) {
fn recurse_both<F: SumcheckField>(low: &mut [F], high: &[F], weight: F) {
#[cfg(feature = "parallel")]
if low.len() > workload_size::<F>() {
let split = low.len() / 2;
Expand Down Expand Up @@ -176,21 +178,30 @@ pub fn fold<F: Field>(values: &mut Vec<F>, weight: F) {
}

/// Two-pass fold-then-compute. Reference only.
pub fn fold_and_compute_polynomial<F: Field>(values: &mut Vec<F>, weight: F) -> (F, F) {
pub fn fold_and_compute_polynomial<F: SumcheckField>(values: &mut Vec<F>, weight: F) -> (F, F) {
fold(values, weight);
compute_sumcheck_polynomial(values)
}

/// Fused fold + compute: folds `values` by `weight` *and* returns the
/// next-round `(s0, s1)` in one sweep over the quadruple
/// `(v[k], v[k+L/4], v[k+L/2], v[k+3L/4])`.
pub fn fused_fold_and_compute_polynomial<F: Field>(values: &mut Vec<F>, weight: F) -> (F, F) {
pub fn fused_fold_and_compute_polynomial<F: SumcheckField>(
values: &mut Vec<F>,
weight: F,
) -> (F, F) {
let l = values.len();
if !l.is_power_of_two() || l < 4 {
return fold_and_compute_polynomial(values, weight);
}

fn kernel<F: Field>(v0: &mut [F], v1: &mut [F], v2: &[F], v3: &[F], weight: F) -> (F, F) {
fn kernel<F: SumcheckField>(
v0: &mut [F],
v1: &mut [F],
v2: &[F],
v3: &[F],
weight: F,
) -> (F, F) {
debug_assert_eq!(v0.len(), v1.len());
debug_assert_eq!(v0.len(), v2.len());
debug_assert_eq!(v0.len(), v3.len());
Expand Down Expand Up @@ -258,7 +269,7 @@ pub fn multilinear_sumcheck_partial<F, T, H>(
mut hook: H,
) -> Sumcheck<F>
where
F: Field,
F: SumcheckField,
T: ProverTranscript<F>,
H: FnMut(usize, &mut T),
{
Expand Down Expand Up @@ -314,7 +325,7 @@ pub fn multilinear_sumcheck<F, T, H>(
hook: H,
) -> Sumcheck<F>
where
F: Field,
F: SumcheckField,
T: ProverTranscript<F>,
H: FnMut(usize, &mut T),
{
Expand Down
Loading
Loading