Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 163 additions & 23 deletions ceno_zkvm/src/tables/shard_ram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,24 +133,29 @@ impl ShardRamRecord {
hasher.permute(input.clone())[0..SEPTIC_EXTENSION_DEGREE].into();
if let Some(p) = SepticPoint::from_x(x) {
let y6 = (p.y.0)[SEPTIC_EXTENSION_DEGREE - 1].to_canonical_u64();
let is_y_in_2nd_half = y6 >= (prime / 2);

// we negate y if needed
// to ensure read => y in [0, p/2) and write => y in [p/2, p)
let negate = match (self.is_to_write_set, is_y_in_2nd_half) {
(true, false) => true, // write, y in [0, p/2)
(false, true) => true, // read, y in [p/2, p)
_ => false,
};
// y6 == 0 is the 2-torsion exception: both signs collapse, so
// the in-circuit y-sign binding cannot distinguish read/write.
// Reject and try a new nonce. Probability per record ~ 1/p ≈ 2^-31.
if y6 != 0 {
let is_y_in_2nd_half = y6 >= (prime / 2);

// negate y to enforce host convention:
// read (is_to_write_set = 0) => y6 in [1, (p-1)/2]
// write (is_to_write_set = 1) => y6 in [(p+1)/2, p-1]
let negate = match (self.is_to_write_set, is_y_in_2nd_half) {
(true, false) => true, // write, y in [1, p/2), flip to upper half
(false, true) => true, // read, y in [p/2, p), flip to lower half
_ => false,
};

let point = if negate { -p } else { p };
let point = if negate { -p } else { p };

return ECPoint { nonce, point };
} else {
// try again with different nonce
nonce += 1;
input[6] = E::BaseField::from_canonical_u32(nonce);
return ECPoint { nonce, point };
}
}
// try again with different nonce
nonce += 1;
input[6] = E::BaseField::from_canonical_u32(nonce);
}
}
}
Expand Down Expand Up @@ -180,6 +185,9 @@ pub struct ShardRamConfig<E: ExtensionField> {
pub(crate) x: Vec<WitIn>,
pub(crate) y: Vec<WitIn>,
pub(crate) slope: Vec<WitIn>,
// 4-byte decomposition of `y6_lo`, the helper used to bind the sign of
// `y[SEPTIC_EXTENSION_DEGREE - 1]` to `is_global_write`. See `configure`.
pub(crate) y6_lo_bytes: [WitIn; 4],
pub(crate) perm_config: Poseidon2Config<E, 16, 7, 1, 4, 13>,
}

Expand Down Expand Up @@ -273,12 +281,50 @@ impl<E: ExtensionField> ShardRamConfig<E> {
cb.require_equal(|| "x = poseidon2's output", xi.expr(), hasher_output)?;
}

// both (x, y) and (x, -y) are valid ec points
// if is_global_write = 1, then y should be in [0, p/2)
// if is_global_write = 0, then y should be in [p/2, p)

// TODO: enforce 0 <= y < p/2 if is_global_write = 1
// enforce p/2 <= y < p if is_global_write = 0
// Both `(x, y)` and `(x, -y)` are valid ec points, so without an extra
// constraint a malicious prover could flip every leaf and still satisfy
// the EC sum against an updated `shard_rw_sum` (issue #1338). Bind the
// sign of `y[6]` to `is_global_write` using the host-side convention:
// read (is_global_write = 0) => y[6] in [1, (p-1)/2]
// write (is_global_write = 1) => y[6] in [(p+1)/2, p-1]
// We witness `y6_lo` ∈ [0, (p-1)/2) via four byte limbs with the top
// byte < 60 (BabyBear: (p-1)/2 = 60 · 2^24 exactly), then:
// read branch: y[6] = y6_lo + 1
// write branch: y[6] = p - 1 - y6_lo (equiv. y[6] + y6_lo + 1 ≡ 0)
// y[6] = 0 is the unique exception, blocked in both branches and
// rejected on the prover side by `ShardRamRecord::to_ec_point`.
debug_assert_eq!(
<E::BaseField as SmallField>::MODULUS_U64,
0x7800_0001,
"shard_ram y-sign constraint hardcodes BabyBear's (p-1)/2 = 60 * 2^24"
);
let y6_lo_bytes: [WitIn; 4] =
std::array::from_fn(|i| cb.create_witin(|| format!("y6_lo_b{i}")));
for (i, w) in y6_lo_bytes.iter().enumerate().take(3) {
cb.assert_byte(|| format!("y6_lo_b{i} byte"), w.expr())?;
}
// `lookup_ltu_byte` asserts (b3, 60) are bytes and that b3 < 60.
cb.lookup_ltu_byte(
y6_lo_bytes[3].expr(),
E::BaseField::from_canonical_u64(60).expr(),
Expression::ONE,
)?;
let y6_lo = y6_lo_bytes[0].expr()
+ y6_lo_bytes[1].expr() * E::BaseField::from_canonical_u64(1 << 8).expr()
+ y6_lo_bytes[2].expr() * E::BaseField::from_canonical_u64(1 << 16).expr()
+ y6_lo_bytes[3].expr() * E::BaseField::from_canonical_u64(1 << 24).expr();
let y6 = y[SEPTIC_EXTENSION_DEGREE - 1].expr();
// condition_require_equal: target = if cond then true_expr else false_expr
cb.condition_require_equal(
|| "y6 binds to is_global_write",
is_global_write.expr(),
y6,
// write: y[6] = p - 1 - y6_lo
E::BaseField::from_canonical_u64(<E::BaseField as SmallField>::MODULUS_U64 - 1).expr()
- y6_lo.clone(),
// read: y[6] = y6_lo + 1
y6_lo + Expression::ONE,
)?;

Ok(ShardRamConfig {
x,
Expand All @@ -292,6 +338,7 @@ impl<E: ExtensionField> ShardRamConfig<E> {
local_clk,
nonce,
is_global_write,
y6_lo_bytes,
perm_config,
})
}
Expand All @@ -315,7 +362,7 @@ impl<E: ExtensionField> ShardRamCircuit<E> {
fn assign_instance(
config: &ShardRamConfig<E>,
instance: &mut [E::BaseField],
_lk_multiplicity: &mut LkMultiplicity,
lk_multiplicity: &mut LkMultiplicity,
input: &ShardRamInput<E>,
) -> Result<(), crate::error::ZKVMError> {
// assign basic fields
Expand Down Expand Up @@ -350,6 +397,27 @@ impl<E: ExtensionField> ShardRamCircuit<E> {
instance[witin.id as usize] = *fe;
});

// Bind y[6]'s sign to is_global_write via `y6_lo` byte decomposition.
// See ShardRamConfig::configure. `to_ec_point` guarantees y6 != 0 and
// the matching half-of-field convention, so the subtraction below
// never underflows.
let prime = <E::BaseField as SmallField>::MODULUS_U64;
let y6_u64 = point.y.0[SEPTIC_EXTENSION_DEGREE - 1].to_canonical_u64();
let y6_lo_u64 = if record.is_to_write_set {
prime - 1 - y6_u64
} else {
y6_u64 - 1
};
for i in 0..4 {
let b = (y6_lo_u64 >> (8 * i)) & 0xff;
set_val!(instance, config.y6_lo_bytes[i], b);
}
for i in 0..3 {
let b = (y6_lo_u64 >> (8 * i)) & 0xff;
lk_multiplicity.assert_const_range(b, 8);
}
lk_multiplicity.lookup_ltu_byte((y6_lo_u64 >> 24) & 0xff, 60);

let ram_type = E::BaseField::from_canonical_u32(record.ram_type as u32);
let mut input = [E::BaseField::ZERO; 16];

Expand Down Expand Up @@ -725,7 +793,7 @@ impl<E: ExtensionField> ShardRamCircuit<E> {
#[cfg(test)]
mod tests {
use either::Either;
use ff_ext::{BabyBearExt4, FromUniformBytes, PoseidonField};
use ff_ext::{BabyBearExt4, FromUniformBytes, PoseidonField, SmallField};
use itertools::Itertools;
use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel};
use p3::babybear::BabyBear;
Expand Down Expand Up @@ -938,4 +1006,76 @@ mod tests {
.create_chip_proof(&mut task, &mut transcript)
.unwrap();
}

/// Regression test for issue #1338.
///
/// The host convention selected by `to_ec_point` is the one the in-circuit
/// y-sign constraint must enforce. We check:
/// 1. Honest read/write points land in the half of the field expected by
/// the circuit (`y6 != 0`, lower half for reads, upper half for writes).
/// 2. The "negate every leaf" attack from issue #1338, when fed through the
/// same `assign_instance` formula, produces `y6_lo` values whose top
/// byte is `>= 60`. The circuit's `lookup_ltu_byte(b3, 60, 1)` is
/// therefore unsatisfiable on the attacker's witness.
#[test]
fn test_shard_ram_y_sign_attack_breaks_byte_bound() {
let perm = <F as PoseidonField>::get_default_perm();
let prime = <F as SmallField>::MODULUS_U64;
let half_p_minus_1 = (prime - 1) / 2; // 60 * 2^24 for BabyBear

let write_record = ShardRamRecord {
addr: 0xdeadbeef,
ram_type: RAMType::Memory,
value: 0x1234_5678,
shard: 1,
local_clk: 7,
global_clk: 13,
is_to_write_set: true,
};
let read_record = ShardRamRecord {
addr: 0xfeedface,
ram_type: RAMType::Memory,
value: 0xcafe_babe,
shard: 2,
local_clk: 0,
global_clk: 19,
is_to_write_set: false,
};

for record in &[write_record, read_record] {
let point = record.to_ec_point::<E, Perm>(&perm).point;
let y6 = point.y.0[SEPTIC_EXTENSION_DEGREE - 1].to_canonical_u64();
// (1) Honest convention.
assert_ne!(y6, 0, "to_ec_point must reject the 2-torsion case y6=0");
if record.is_to_write_set {
assert!(
y6 > half_p_minus_1 && y6 < prime,
"honest write must have y6 in upper half, got {y6:#x}"
);
} else {
assert!(
(1..=half_p_minus_1).contains(&y6),
"honest read must have y6 in lower half, got {y6:#x}"
);
}

// (2) Attacker negates every leaf. Replay `assign_instance`'s
// y6_lo formula with the negated y6; the resulting top byte must
// land outside [0, 60), so `lookup_ltu_byte(b3, 60, 1)` fails.
let negated_y6 = prime - y6;
let attacked_y6_lo = if record.is_to_write_set {
prime - 1 - negated_y6
} else {
negated_y6 - 1
};
let b3 = (attacked_y6_lo >> 24) & 0xff;
assert!(
b3 >= 60,
"attack went undetected: is_to_write_set={}, negated y6={:#x}, attacked y6_lo={:#x} ⇒ b3={b3}",
record.is_to_write_set,
negated_y6,
attacked_y6_lo,
);
}
}
}
Loading