Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion ceno_recursion/src/zkvm_verifier/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,21 @@ pub fn verify_zkvm_proof<C: Config<F = F>>(
let num_lks: Var<C::N> =
builder.eval(C::N::from_canonical_usize(chip_vk.get_cs().num_lks()));

// Chips with EC-sum ops carry an extra hypercube variable; the
// prover fills it with EC-tree internal nodes that are inactive
// via `selector_zero = 0` and thus collapse to dummy lookup
// queries. Mirror the native verifier's adjustment here so the
// dummy multiplicity matches the prover.
let ecc_row_factor: usize = if circuit_vk.get_cs().has_ecc_ops() {
2
} else {
1
};
// each padding instance contribute to (2^rotation_vars) dummy lookup padding
let next_pow2_instance: Var<C::N> =
let next_pow2_chip_rows: Var<C::N> =
pow_2(builder, chip_proof.log2_num_instances.get_var());
let next_pow2_instance: Var<C::N> =
builder.eval(next_pow2_chip_rows * C::N::from_canonical_usize(ecc_row_factor));
let num_padded_instance: Var<C::N> =
builder.eval(next_pow2_instance - chip_proof.sum_num_instances.clone());
let rotation_var: Var<C::N> = builder.constant(C::N::from_canonical_usize(
Expand Down
57 changes: 41 additions & 16 deletions ceno_zkvm/src/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,30 @@ pub fn generate_witness<'a, E: ExtensionField>(
&mut zkvm_witness,
)
}).unwrap();

// Assign continuation circuits (LocalFinal + ShardRam) before
// `finalize_lk_multiplicities`: ShardRam's per-row y6_lo byte /
// LTU lookups must land in `combined_lk_mlt` so the U8 / LTU
// table `mlt` columns balance the logup grand product. LocalFinal
// does not consume `combined_lk_mlt`, so running it pre-finalize
// is safe — `assign_table_circuit` tolerates a not-yet-finalized
// multiplicity by passing an empty slice.
info_span!("assign_continuation").in_scope(|| {
system_config
.mmu_config
.assign_continuation_circuit(
&system_config.zkvm_cs,
&shard_ctx,
&mut zkvm_witness,
&pi,
&emul_result.final_mem_state.reg,
&emul_result.final_mem_state.mem,
&emul_result.final_mem_state.hints,
&emul_result.final_mem_state.stack,
&emul_result.final_mem_state.heap,
)
}).unwrap();

info_span!("finalize_lk_multiplicities").in_scope(|| {
zkvm_witness.finalize_lk_multiplicities();
});
Expand Down Expand Up @@ -1535,6 +1559,23 @@ pub fn generate_witness<'a, E: ExtensionField>(
&mut cpu_witness,
)
.unwrap();
// Mirror the main path so `combined_lk_mlt` comparison stays
// meaningful: continuation pushes ShardRamCircuit's per-row
// y6_lo lookups into `lk_mlts` before finalize.
system_config
.mmu_config
.assign_continuation_circuit(
&system_config.zkvm_cs,
&cpu_shard_ctx,
&mut cpu_witness,
&pi,
&emul_result.final_mem_state.reg,
&emul_result.final_mem_state.mem,
&emul_result.final_mem_state.hints,
&emul_result.final_mem_state.stack,
&emul_result.final_mem_state.heap,
)
.unwrap();
cpu_witness.finalize_lk_multiplicities();

#[cfg(feature = "gpu")]
Expand Down Expand Up @@ -1626,22 +1667,6 @@ pub fn generate_witness<'a, E: ExtensionField>(
)
}).unwrap();

info_span!("assign_continuation").in_scope(|| {
system_config
.mmu_config
.assign_continuation_circuit(
&system_config.zkvm_cs,
&shard_ctx,
&mut zkvm_witness,
&pi,
&emul_result.final_mem_state.reg,
&emul_result.final_mem_state.mem,
&emul_result.final_mem_state.hints,
&emul_result.final_mem_state.stack,
&emul_result.final_mem_state.heap,
)
}).unwrap();

info_span!("assign_program_table").in_scope(|| {
zkvm_witness
.assign_table_circuit::<ProgramTableCircuit<E>>(
Expand Down
86 changes: 81 additions & 5 deletions ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,18 @@ pub fn gpu_batch_continuation_ec_on_device(
}

/// Try to run ShardRamCircuit assign_instances on GPU.
/// Returns `Ok(None)` if GPU is unavailable or disabled.
/// Returns `Ok(None)` if GPU is unavailable or disabled. On success the
/// y6_lo byte / LTU lookup multiplicity is derived from `steps` and pushed
/// into `lk_multiplicity` so the caller sees the same per-row contribution
/// the CPU `assign_instance` path would have made.
pub(crate) fn try_gpu_assign_shard_ram<E: ExtensionField>(
config: &ShardRamConfig<E>,
num_witin: usize,
num_structural_witin: usize,
lk_multiplicity: &mut crate::witness::LkMultiplicity,
steps: &[crate::tables::ShardRamInput<E>],
) -> Result<Option<crate::tables::RMMCollections<E::BaseField>>, ZKVMError> {
use crate::scheme::constants::SEPTIC_EXTENSION_DEGREE;
use ceno_gpu::{
Buffer, CudaHal,
bb31::CudaHalBB31,
Expand Down Expand Up @@ -496,6 +501,23 @@ pub(crate) fn try_gpu_assign_shard_ram<E: ExtensionField>(
);
}

// The GPU witness kernel above writes the row data but does not run
// the per-row `assign_instance` CPU path that pushes the y6_lo byte /
// LTU lookup multiplicity. Derive the same contribution from `steps`
// here so the caller's `lk_multiplicity` mirrors the CPU branch and
// `combined_lk_mlt` balances the U8 / LTU table `mlt` columns. Source
// of truth for the queries is `ShardRamConfig::configure`.
for step in steps {
let y6_lo = crate::tables::y6_lo_value::<E>(
step.ec_point.point.y.0[SEPTIC_EXTENSION_DEGREE - 1],
step.record.is_to_write_set,
);
for i in 0..3 {
lk_multiplicity.assert_const_range((y6_lo >> (8 * i)) & 0xff, 8);
}
lk_multiplicity.lookup_ltu_byte((y6_lo >> 24) & 0xff, 60);
}

Ok(Some([raw_witin, raw_structural_witin]))
}

Expand Down Expand Up @@ -749,7 +771,10 @@ pub(crate) fn try_gpu_assign_shard_ram_from_device<E: ExtensionField>(
}

/// Full GPU pipeline for assign_shared_circuit: device-resident EC merge + partition + assign.
/// Returns `Ok(None)` if GPU is unavailable, `Ok(Some(inputs))` on success.
/// Returns `Ok(None)` if GPU is unavailable, `Ok(Some((inputs, lk_mlt)))` on
/// success — `lk_mlt` carries the y6_lo byte / LTU lookup multiplicity that
/// `ShardRamConfig::configure` consumes (mirrors the per-row CPU push in
/// `ShardRamCircuit::assign_instance`).
#[allow(clippy::type_complexity)]
pub(crate) fn try_gpu_assign_shared_circuit<E: ExtensionField>(
shard_ctx: &crate::e2e::ShardContext,
Expand All @@ -762,16 +787,24 @@ pub(crate) fn try_gpu_assign_shared_circuit<E: ExtensionField>(
num_witin: usize,
num_structural_witin: usize,
max_chunk: usize,
) -> Result<Option<Vec<crate::structs::ChipInput<E>>>, ZKVMError> {
) -> Result<
Option<(
Vec<crate::structs::ChipInput<E>>,
gkr_iop::utils::lk_multiplicity::Multiplicity<u64>,
)>,
ZKVMError,
> {
use crate::{
instructions::gpu::{
chips::shard_ram::gpu_batch_continuation_ec_on_device,
dispatch::take_shared_device_buffers,
},
structs::{ChipInput, ZKVMWitnesses},
tables::{ShardRamCircuit, ShardRamRecord, TableCircuit},
witness::LkMultiplicity,
};
use ceno_gpu::Buffer;
use ff_ext::SmallField;
use gkr_iop::gpu::get_cuda_hal;
use rayon::prelude::*;
use tracing::info_span;
Expand Down Expand Up @@ -943,8 +976,51 @@ pub(crate) fn try_gpu_assign_shared_circuit<E: ExtensionField>(
total_records - num_writes,
);

// 7. GPU assign_instances from device buffer (chunked by max_cross_shard)
let record_u32s = std::mem::size_of::<ceno_gpu::common::witgen::types::GpuShardRamRecord>() / 4;
// GpuShardRamRecord (#[repr(C)]) layout — derived from shard_ram_record_to_gpu
// above: 4xu32 leader (addr, ram_type, value, _pad0), 3xu64
// (shard, local_clk, global_clk), 2xu32 (is_to_write_set, nonce),
// [u32; 7] point_x, [u32; 7] point_y. Total = 26 u32s.
debug_assert_eq!(record_u32s, 26, "GpuShardRamRecord layout changed");
const IS_TO_WRITE_SET_U32_OFFSET: usize = 10;
const POINT_Y6_U32_OFFSET: usize = 25;

// 6.5. Derive ShardRam's per-row y6_lo byte / LTU lookup multiplicity
// from the partitioned device buffer. Mirrors the per-row CPU push in
// `ShardRamCircuit::assign_instance`; the constraint these queries serve
// lives in `ShardRamConfig::configure` (y6_lo bytes + lookup_ltu_byte).
let lk_mlt = info_span!("gpu_shard_ram_derive_lk_mlt", n = total_records).in_scope(
|| -> Result<gkr_iop::utils::lk_multiplicity::Multiplicity<u64>, ZKVMError> {
if total_records == 0 {
return Ok(gkr_iop::utils::lk_multiplicity::Multiplicity::default());
}
let host_data: Vec<u32> = partitioned_buf.to_vec().map_err(|e| {
ZKVMError::InvalidWitness(
format!("[GPU full pipeline] partitioned_buf D2H: {e}").into(),
)
})?;
debug_assert_eq!(host_data.len(), total_records * record_u32s);
let prime = <E::BaseField as SmallField>::MODULUS_U64;
let lk_multiplicity = LkMultiplicity::default();
host_data.par_chunks_exact(record_u32s).for_each(|rec| {
let mut local = lk_multiplicity.clone();
let is_to_write_set = rec[IS_TO_WRITE_SET_U32_OFFSET] != 0;
let y6 = rec[POINT_Y6_U32_OFFSET] as u64;
let y6_lo = if is_to_write_set {
prime - 1 - y6
} else {
y6 - 1
};
for i in 0..3 {
local.assert_const_range((y6_lo >> (8 * i)) & 0xff, 8);
}
local.lookup_ltu_byte((y6_lo >> 24) & 0xff, 60);
});
Ok(lk_multiplicity.into_finalize_result())
},
)?;

// 7. GPU assign_instances from device buffer (chunked by max_cross_shard)

let circuit_inputs =
info_span!("shard_ram_assign_from_device", n = total_records).in_scope(|| {
Expand Down Expand Up @@ -993,7 +1069,7 @@ pub(crate) fn try_gpu_assign_shared_circuit<E: ExtensionField>(
total_records,
);

Ok(Some(circuit_inputs))
Ok(Some((circuit_inputs, lk_mlt)))
}

#[cfg(test)]
Expand Down
10 changes: 10 additions & 0 deletions ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ impl<E: ExtensionField> MmuConfig<E> {
Ok(())
}

/// Assign LocalFinalCircuit and ShardRamCircuit witnesses. Must run
/// *before* `ZKVMWitnesses::finalize_lk_multiplicities`:
/// - `ShardRamCircuit` accumulates its per-row y6_lo byte / LTU lookups
/// into `lk_mlts` via `assign_shared_circuit` (which threads a shared
/// `LkMultiplicity` through `assign_instances_with_lk_multiplicities`),
/// so they land in `combined_lk_mlt` and balance the U8 / LTU table
/// `mlt` columns.
/// - `LocalFinalCircuit` does not consume `combined_lk_mlt`; the regular
/// `assign_table_circuit` entry tolerates a not-yet-finalized
/// multiplicity by passing an empty slice.
#[allow(clippy::too_many_arguments)]
pub fn assign_continuation_circuit(
&self,
Expand Down
14 changes: 13 additions & 1 deletion ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -698,8 +698,20 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>>
// compute logup_sum padding
// getting the number of dummy padding item that we used in this opcode circuit
let num_lks = circuit_vk.get_cs().num_lks();
// Chips with EC-sum ops carry an extra hypercube variable (one extra
// log2 row dimension) that the prover fills with EC-tree internal
// nodes; those rows are not "active" instances and their lookup
// queries collapse to the dummy_table_item via `selector_zero = 0`.
// Mirror that here so the verifier subtracts the right number of
// dummy queries.
let ecc_row_factor = if circuit_vk.get_cs().has_ecc_ops() {
2
} else {
1
};
let padded_rows = next_pow2_instance_padding(num_instance) * ecc_row_factor;
// each padding instance contribute to (2^rotation_vars) dummy lookup padding
let num_padded_instance = (next_pow2_instance_padding(num_instance) - num_instance)
let num_padded_instance = (padded_rows - num_instance)
* (1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0));
// each instance contribute to (2^rotation_vars - rotated) dummy lookup padding
let num_instance_non_selected = num_instance
Expand Down
42 changes: 35 additions & 7 deletions ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{
ECPoint, MemFinalRecord, RMMCollections, ShardRamCircuit, ShardRamInput, ShardRamRecord,
TableCircuit,
},
witness::LkMultiplicity,
};
use ceno_emul::{Addr, CENO_PLATFORM, Platform, RegIdx, StepIndex, StepRecord, WordAddr};
use ff_ext::{ExtensionField, PoseidonField};
Expand Down Expand Up @@ -471,13 +472,15 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
config: &TC::TableConfig,
input: &TC::WitnessInput<'_>,
) -> Result<(), ZKVMError> {
assert!(self.combined_lk_mlt.is_some());
let cs = cs.get_cs(&TC::name()).unwrap();
let empty_mlt: Vec<FxHashMap<u64, usize>> = Vec::new();
// Scope the immutable borrow of `self.combined_lk_mlt` so the
// `self.witnesses.insert` mutable borrow below is legal.
let witness = TC::assign_instances(
config,
cs.zkvm_v1_css.num_witin as usize,
cs.zkvm_v1_css.num_structural_witin as usize,
self.combined_lk_mlt.as_ref().unwrap(),
self.combined_lk_mlt.as_ref().unwrap_or(&empty_mlt),
input,
)?;
let witness_instances = witness[0].num_instances();
Expand Down Expand Up @@ -645,19 +648,25 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
}
}

assert!(self.combined_lk_mlt.is_some());
assert!(self.combined_lk_mlt.is_none());
let cs = cs.get_cs(&ShardRamCircuit::<E>::name()).unwrap();
let n_global = global_input.len();
// `ShardRamCircuit::assign_instances` ignores the `multiplicity`
// argument (its lookup contribution is derived externally above), so
// an empty slice is sufficient here and matches the pre-finalize
// ordering: `combined_lk_mlt` is intentionally `None` at this point.
let lk_multiplicity = LkMultiplicity::default();
let circuit_inputs =
info_span!("shard_ram_assign_instances", n = n_global).in_scope(|| {
global_input
.par_chunks(shard_ctx.max_num_cross_shard_accesses)
.map(|shard_accesses| {
let witness = ShardRamCircuit::assign_instances(
let mut lk_multiplicity = lk_multiplicity.clone();
let witness = ShardRamCircuit::assign_instances_with_lk_multiplicities(
config,
cs.zkvm_v1_css.num_witin as usize,
cs.zkvm_v1_css.num_structural_witin as usize,
self.combined_lk_mlt.as_ref().unwrap(),
&mut lk_multiplicity,
shard_accesses,
)?;
let num_reads = shard_accesses
Expand All @@ -674,6 +683,15 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
})
.collect::<Result<Vec<_>, ZKVMError>>()
})?;

assert!(
self.lk_mlts
.insert(
ShardRamCircuit::<E>::name(),
lk_multiplicity.into_finalize_result()
)
.is_none()
);
// set num_read, num_write as separate instance
assert!(
self.witnesses
Expand All @@ -687,6 +705,11 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
/// Full GPU pipeline for assign_shared_circuit: keep data on device, minimal CPU roundtrips.
///
/// Returns Ok(true) if successful, Ok(false) if unavailable (no shared device buffers).
/// On success, inserts both `ChipInput` and `ShardRamCircuit`'s derived
/// lookup multiplicity (for the y6_lo byte / LTU queries) into
/// `self.witnesses` / `self.lk_mlts` so the subsequent
/// `finalize_lk_multiplicities` folds the contribution into
/// `combined_lk_mlt` — matching the CPU shortcut's invariant.
#[cfg(feature = "gpu")]
fn try_assign_shared_circuit_gpu(
&mut self,
Expand All @@ -695,7 +718,7 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
final_mem: &[(&'static str, Option<Range<Addr>>, &[MemFinalRecord])],
config: &<ShardRamCircuit<E> as TableCircuit<E>>::TableConfig,
) -> Result<bool, ZKVMError> {
assert!(self.combined_lk_mlt.is_some());
assert!(self.combined_lk_mlt.is_none());
let cs_inner = cs.get_cs(&ShardRamCircuit::<E>::name()).unwrap();
let num_witin = cs_inner.zkvm_v1_css.num_witin as usize;
let num_structural_witin = cs_inner.zkvm_v1_css.num_structural_witin as usize;
Expand All @@ -708,12 +731,17 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
num_structural_witin,
shard_ctx.max_num_cross_shard_accesses,
)? {
Some(circuit_inputs) => {
Some((circuit_inputs, lk_mlt)) => {
assert!(
self.witnesses
.insert(ShardRamCircuit::<E>::name(), circuit_inputs)
.is_none()
);
assert!(
self.lk_mlts
.insert(ShardRamCircuit::<E>::name(), lk_mlt)
.is_none()
);
Ok(true)
}
None => Ok(false),
Expand Down
Loading
Loading