Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions ceno_zkvm/src/scheme/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> TraceCommitter<CpuBa
&self,
traces: BTreeMap<usize, witness::RowMajorMatrix<E::BaseField>>,
) -> (
Vec<MultilinearExtension<'a, E>>,
Vec<ArcMultilinearExtension<'a, E>>,
PCS::CommitmentWithWitness,
PCS::Commitment,
) {
Expand All @@ -576,22 +576,19 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> TraceCommitter<CpuBa
let prover_param = &self.backend.pp;
let pcs_data = PCS::batch_commit(prover_param, traces.into_values().collect_vec()).unwrap();
let commit = PCS::get_pure_commitment(&pcs_data);
let mles = PCS::get_arc_mle_witness_from_commitment(&pcs_data)
.into_par_iter()
.map(|mle| mle.as_ref().clone())
.collect::<Vec<_>>();
let mles = PCS::get_arc_mle_witness_from_commitment(&pcs_data);

(mles, pcs_data, commit)
}

fn extract_witness_mles<'a, 'b>(
&self,
witness_mles: &'b mut Vec<<CpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>,
witness_mles: &'b mut Vec<Arc<<CpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>>,
_pcs_data: &'b <CpuBackend<E, PCS> as ProverBackend>::PcsData,
) -> Box<
dyn Iterator<Item = Arc<<CpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>> + 'b,
> {
let iter = witness_mles.drain(..).map(Arc::new);
let iter = witness_mles.drain(..);
Box::new(iter)
}
}
Expand Down Expand Up @@ -1524,9 +1521,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> DeviceTransporter<Cp

fn transport_mles<'a>(
&self,
mles: &[MultilinearExtension<'a, E>],
mles: Vec<MultilinearExtension<'a, E>>,
) -> Vec<ArcMultilinearExtension<'a, E>> {
mles.iter().map(|mle| mle.clone().into()).collect_vec()
mles.into_iter().map(Arc::new).collect_vec()
}
}

Expand Down
10 changes: 5 additions & 5 deletions ceno_zkvm/src/scheme/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1670,7 +1670,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + 'static>
&self,
traces: BTreeMap<usize, witness::RowMajorMatrix<E::BaseField>>,
) -> (
Vec<<GpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>,
Vec<Arc<<GpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>>,
<GpuBackend<E, PCS> as ProverBackend>::PcsData,
PCS::Commitment,
) {
Expand Down Expand Up @@ -1811,7 +1811,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + 'static>

fn extract_witness_mles<'a, 'b>(
&self,
_witness_mles: &'b mut Vec<<GpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>,
_witness_mles: &'b mut Vec<Arc<<GpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>>,
pcs_data: &'b <GpuBackend<E, PCS> as ProverBackend>::PcsData,
) -> Box<
dyn Iterator<Item = Arc<<GpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>> + 'b,
Expand Down Expand Up @@ -3625,11 +3625,11 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + 'static>

fn transport_mles<'a>(
&self,
mles: &[MultilinearExtension<'a, E>],
mles: Vec<MultilinearExtension<'a, E>>,
) -> Vec<Arc<<GpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>> {
let cuda_hal = get_cuda_hal().unwrap();
mles.iter()
.map(|mle| Arc::new(MultilinearExtensionGpu::from_ceno(&cuda_hal, mle)))
mles.into_iter()
.map(|mle| Arc::new(MultilinearExtensionGpu::from_ceno(&cuda_hal, &mle)))
.collect_vec()
}
}
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/scheme/hal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,15 @@ pub trait TraceCommitter<PB: ProverBackend> {
&self,
traces: BTreeMap<usize, witness::RowMajorMatrix<<PB::E as ExtensionField>::BaseField>>,
) -> (
Vec<PB::MultilinearPoly<'a>>,
Vec<Arc<PB::MultilinearPoly<'a>>>,
PB::PcsData,
<PB::Pcs as PolynomialCommitmentScheme<PB::E>>::Commitment,
);

/// Return an iterator over witness polynomials so backends can decide how to source them
fn extract_witness_mles<'a, 'b>(
&self,
witness_mles: &'b mut Vec<PB::MultilinearPoly<'a>>,
witness_mles: &'b mut Vec<Arc<PB::MultilinearPoly<'a>>>,
pcs_data: &'b PB::PcsData, // used by GPU backend
) -> Box<dyn Iterator<Item = Arc<PB::MultilinearPoly<'a>>> + 'b>;
}
Expand Down Expand Up @@ -280,7 +280,7 @@ pub trait DeviceTransporter<PB: ProverBackend> {

fn transport_mles<'a>(
&self,
mles: &[MultilinearExtension<'a, PB::E>],
mles: Vec<MultilinearExtension<'a, PB::E>>,
) -> Vec<Arc<PB::MultilinearPoly<'a>>>;
}

Expand Down
9 changes: 5 additions & 4 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ impl<
// commit to witness traces in batch
#[cfg_attr(not(feature = "gpu"), allow(unused_mut))]
let (witness_mles, witness_data, witin_commit): (
Vec<PB::MultilinearPoly<'_>>,
Vec<Arc<PB::MultilinearPoly<'_>>>,
PB::PcsData,
PCS::Commitment,
) = {
Expand All @@ -490,7 +490,8 @@ impl<
gpu_device,
gpu_witness_traces,
);
let witness_mles = unsafe { std::mem::transmute(gpu_witness_mles) };
drop(gpu_witness_mles);
let witness_mles = Vec::new();
let witness_data = unsafe { std::mem::transmute_copy(&gpu_witness_data) };
std::mem::forget(gpu_witness_data);
(witness_mles, witness_data, witin_commit)
Expand Down Expand Up @@ -894,7 +895,7 @@ impl<
name_and_instances: Vec<(String, [usize; 2])>,
structural_rmms: Vec<witness::RowMajorMatrix<E::BaseField>>,
#[cfg(feature = "gpu")] witness_trace_rows: Vec<Option<usize>>,
#[allow(unused_mut)] mut witness_mles: Vec<PB::MultilinearPoly<'data>>,
#[allow(unused_mut)] mut witness_mles: Vec<Arc<PB::MultilinearPoly<'data>>>,
witness_data: &PB::PcsData,
mut fixed_mles: Vec<Arc<PB::MultilinearPoly<'data>>>,
challenges: [E; 2],
Expand Down Expand Up @@ -967,7 +968,7 @@ impl<
let structural_witness = info_span!("[ceno] transport_structural_witness")
.in_scope(|| {
let structural_mles = structural_rmm.to_mles();
self.device.transport_mles(&structural_mles)
self.device.transport_mles(structural_mles)
});
(witness_mle, structural_witness, None)
};
Expand Down
Loading