Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ namespace riscv {
inline constexpr size_t RV64_REGISTER_NUM_LIMBS = 8;
inline constexpr size_t RV64_WORD_NUM_LIMBS = 4;
inline constexpr size_t RV64_CELL_BITS = 8;
inline constexpr size_t RV64_U16_LIMB_BITS = 2 * RV64_CELL_BITS;
inline constexpr size_t RV64_PTR_U16_LIMBS = RV64_WORD_NUM_LIMBS / 2;
inline constexpr size_t RV64_PTR_BITS = RV64_U16_LIMB_BITS * RV64_PTR_U16_LIMBS;
} // namespace riscv

namespace program {
Expand Down
26 changes: 25 additions & 1 deletion crates/circuits/primitives/cuda/include/primitives/utils.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include "primitives/constants.h"

#include <cstdint>
#include <cuda_runtime.h>

Expand All @@ -23,6 +25,28 @@ __device__ __forceinline__ uint32_t u32_from_bytes_le(const uint8_t *b) {
return (uint32_t)b[0] | ((uint32_t)b[1] << 8) | ((uint32_t)b[2] << 16) | ((uint32_t)b[3] << 24);
}

template <typename T, size_t NUM_LIMBS>
__device__ __forceinline__ void u32_to_le_u16_limbs(T (&out)[NUM_LIMBS], uint32_t value) {
static_assert(NUM_LIMBS == 2, "u32_to_le_u16_limbs expects two u16 cells");
out[0] = T(uint16_t(value));
out[1] = T(uint16_t(value >> 16));
}

template <typename T, size_t NUM_LIMBS>
__device__ __forceinline__ void bytes_to_le_u16_limbs(T (&out)[NUM_LIMBS], const uint8_t *bytes) {
#pragma unroll
for (size_t i = 0; i < NUM_LIMBS; i++) {
out[i] = T(u16_from_bytes_le(bytes + 2 * i));
}
}

__device__ __host__ __forceinline__ uint32_t scale_rv64_ptr_high_u16(
uint16_t high_u16,
uint32_t ptr_max_bits
) {
return uint32_t(high_u16) << (riscv::RV64_PTR_BITS - ptr_max_bits);
}

// Convert 4 bytes to a u32 in big endian order
// **SAFETY**: b has to be at least 4 bytes long
__device__ __forceinline__ uint32_t u32_from_bytes_be(const uint8_t *b) {
Expand Down Expand Up @@ -53,4 +77,4 @@ __device__ __host__ __forceinline__ uint64_t rotl64(uint64_t x, uint32_t n) {

__device__ __host__ __forceinline__ uint32_t rotr(uint32_t value, int n) {
return (value >> n) | (value << (32 - n));
}
}
32 changes: 16 additions & 16 deletions crates/circuits/sha2-air/src/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::{iter::once, marker::PhantomData};

use ndarray::s;
use openvm_circuit_primitives::{
bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::select, ColumnsAir,
SubAir,
bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::select,
var_range::VariableRangeCheckerBus, ColumnsAir, SubAir,
};
use openvm_stark_backend::{
interaction::{BusIndex, InteractionBuilder, PermutationCheckBus},
Expand All @@ -25,6 +25,8 @@ use crate::{
#[derive(Clone, Debug)]
pub struct Sha2BlockHasherSubAir<C: Sha2BlockHasherSubairConfig> {
pub bitwise_lookup_bus: BitwiseOperationLookupBus,
/// Range checker bus for digest-row `final_hash` limbs.
pub range_bus: VariableRangeCheckerBus,
pub row_idx_encoder: Encoder,
/// Internal bus for self-interactions in this AIR.
pub private_bus: PermutationCheckBus,
Expand All @@ -36,9 +38,14 @@ pub struct Sha2BlockHasherSubAir<C: Sha2BlockHasherSubairConfig> {
impl<C: Sha2BlockHasherSubairConfig> ColumnsAir for Sha2BlockHasherSubAir<C> {}

impl<C: Sha2BlockHasherSubairConfig> Sha2BlockHasherSubAir<C> {
pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, private_bus_idx: BusIndex) -> Self {
pub fn new(
bitwise_lookup_bus: BitwiseOperationLookupBus,
range_bus: VariableRangeCheckerBus,
private_bus_idx: BusIndex,
) -> Self {
Self {
bitwise_lookup_bus,
range_bus,
row_idx_encoder: Encoder::new(C::ROWS_PER_BLOCK + 1, 2, false), /* + 1 for dummy
* (padding) rows */
private_bus: PermutationCheckBus::new(private_bus_idx),
Expand Down Expand Up @@ -149,7 +156,7 @@ impl<C: Sha2BlockHasherSubairConfig> Sha2BlockHasherSubAir<C> {
) {
// Assert that the previous hash + work vars == final hash.
// That is, `next.prev_hash[i] + local.work_vars[i] == next.final_hash[i]`
// where addition is done modulo 2^32
// where addition is done modulo 2^32.
for i in 0..C::HASH_WORDS {
let mut carry = AB::Expr::ZERO;
for j in 0..C::WORD_U16S {
Expand All @@ -174,25 +181,18 @@ impl<C: Sha2BlockHasherSubairConfig> Sha2BlockHasherSubAir<C> {
1,
)
};
let final_hash_limb = compose::<AB::Expr>(
next.final_hash
.slice(s![i, j * 2..(j + 1) * 2])
.as_slice()
.unwrap(),
8,
);
let final_hash_limb: AB::Expr = next.final_hash[[i, j]].into();

carry = AB::Expr::from(AB::F::from_u32(1 << 16).inverse())
* (next.prev_hash[[i, j]] + work_var_limb + carry - final_hash_limb);
builder
.when(*next.flags.is_digest_row)
.assert_bool(carry.clone());
}
// constrain the final hash limbs two at a time since we can do two checks per
// interaction
for chunk in next.final_hash.row(i).as_slice().unwrap().chunks(2) {
self.bitwise_lookup_bus
.send_range(chunk[0], chunk[1])
// Range-check final-hash limbs for the digest-row addition.
for j in 0..C::WORD_U16S {
self.range_bus
.range_check(next.final_hash[[i, j]], C::WORD_U16_BITS)
.eval(builder, *next.flags.is_digest_row);
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/circuits/sha2-air/src/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ pub struct Sha2DigestCols<
pub hash: Sha2WorkVarsCols<T, WORD_BITS, ROUNDS_PER_ROW, WORD_U16S>,
pub schedule_helper:
Sha2MessageHelperCols<T, WORD_U16S, ROUNDS_PER_ROW, ROUNDS_PER_ROW_MINUS_ONE>,
/// The actual final hash values of the given block
/// The actual final hash values of the given block, as little-endian 16-bit limbs.
/// Note: the above `hash` will be equal to `final_hash` unless we are on the last block
pub final_hash: [[T; WORD_U8S]; HASH_WORDS],
pub final_hash: [[T; WORD_U16S]; HASH_WORDS],
/// The final hash of the previous block
/// Note: will be constrained using interactions with the chip itself
pub prev_hash: [[T; WORD_U16S]; HASH_WORDS],
Expand Down
2 changes: 2 additions & 0 deletions crates/circuits/sha2-air/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ pub trait Sha2BlockHasherSubairConfig: Send + Sync + Clone {

/// Number of 16-bit limbs in a SHA word
const WORD_U16S: usize = Self::WORD_BITS / 16;
/// Bit width of one u16 limb.
const WORD_U16_BITS: usize = Self::WORD_BITS / Self::WORD_U16S;
/// Number of 8-bit limbs in a SHA word
const WORD_U8S: usize = Self::WORD_BITS / 8;
/// Number of cells in a SHA block
Expand Down
26 changes: 12 additions & 14 deletions crates/circuits/sha2-air/src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ use std::{marker::PhantomData, ops::Range};

use openvm_circuit_primitives::{
bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, utils::compose,
var_range::SharedVariableRangeCheckerChip,
};
use openvm_stark_backend::p3_field::PrimeField32;
use sha2::{compress256, compress512, digest::generic_array::GenericArray};

use crate::{
big_sig0, big_sig0_field, big_sig1, big_sig1_field, ch, ch_field, get_flag_pt_array,
le_limbs_into_word, maj, maj_field, set_arrayview_from_u32_slice, small_sig0, small_sig0_field,
small_sig1, small_sig1_field, word_into_bits, word_into_u16_limbs, word_into_u8_limbs,
Sha2BlockHasherSubairConfig, Sha2DigestColsRefMut, Sha2RoundColsRef, Sha2RoundColsRefMut,
Sha2Variant, WrappingAdd,
small_sig1, small_sig1_field, word_into_bits, word_into_u16_limbs, Sha2BlockHasherSubairConfig,
Sha2DigestColsRefMut, Sha2RoundColsRef, Sha2RoundColsRefMut, Sha2Variant, WrappingAdd,
};

/// A helper struct for the SHA-2 trace generation.
Expand Down Expand Up @@ -79,6 +79,7 @@ impl<C: Sha2BlockHasherSubairConfig> Sha2BlockHasherFillerHelper<C> {
trace_start_col: usize,
input: &[C::Word],
bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>,
range_checker_chip: &SharedVariableRangeCheckerChip,
prev_hash: &[C::Word],
next_block_prev_hash: &[C::Word],
global_block_idx: u32,
Expand Down Expand Up @@ -287,23 +288,20 @@ impl<C: Sha2BlockHasherSubairConfig> Sha2BlockHasherFillerHelper<C> {
let final_hash: Vec<C::Word> = (0..C::HASH_WORDS)
.map(|i| work_vars[i].wrapping_add(prev_hash[i]))
.collect();
let final_hash_limbs: Vec<Vec<u32>> = final_hash
let final_hash_u16_limbs: Vec<Vec<u32>> = final_hash
.iter()
.map(|word| word_into_u8_limbs::<C>(*word))
.map(|word| word_into_u16_limbs::<C>(*word))
.collect();
// need to ensure final hash limbs are bytes, in order for
// prev_hash[i] + work_vars[i] == final_hash[i]
// to be constrained correctly
for word in final_hash_limbs.iter() {
for chunk in word.chunks(2) {
bitwise_lookup_chip.request_range(chunk[0], chunk[1]);
// Range-check the final-hash limbs so the digest-row addition
// `prev_hash + work_vars == final_hash` is constrained in u16 limbs.
for word in final_hash_u16_limbs.iter() {
for &limb in word.iter() {
range_checker_chip.add_count(limb, C::WORD_U16_BITS);
}
}
set_arrayview_from_u32_slice(
&mut cols.final_hash,
final_hash
.iter()
.flat_map(|word| word_into_u8_limbs::<C>(*word)),
final_hash_u16_limbs.iter().flatten().copied(),
);
set_arrayview_from_u32_slice(
&mut cols.prev_hash,
Expand Down
29 changes: 29 additions & 0 deletions crates/circuits/sha2-air/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::mem::size_of;

use ndarray::ArrayViewMut;
pub use openvm_circuit_primitives::utils::compose;
use openvm_circuit_primitives::{
Expand Down Expand Up @@ -304,6 +306,7 @@ pub fn constraint_word_addition<AB: AirBuilder, C: Sha2BlockHasherSubairConfig>(
}
}

/// Fill an array view from u32 values.
pub fn set_arrayview_from_u32_slice<F: PrimeField32, D: ndarray::Dimension>(
arrayview: &mut ArrayViewMut<F, D>,
data: impl IntoIterator<Item = u32>,
Expand All @@ -314,6 +317,18 @@ pub fn set_arrayview_from_u32_slice<F: PrimeField32, D: ndarray::Dimension>(
.for_each(|(x, y)| *x = y);
}

/// Fill an array view from u16 values.
pub fn set_arrayview_from_u16_slice<F: PrimeField32, D: ndarray::Dimension>(
arrayview: &mut ArrayViewMut<F, D>,
data: impl IntoIterator<Item = u16>,
) {
arrayview
.iter_mut()
.zip(data.into_iter().map(|x| F::from_u16(x)))
.for_each(|(x, y)| *x = y);
}

/// Fill an array view from u8 values.
pub fn set_arrayview_from_u8_slice<F: PrimeField32, D: ndarray::Dimension>(
arrayview: &mut ArrayViewMut<F, D>,
data: impl IntoIterator<Item = u8>,
Expand All @@ -323,3 +338,17 @@ pub fn set_arrayview_from_u8_slice<F: PrimeField32, D: ndarray::Dimension>(
.zip(data.into_iter().map(|x| F::from_u8(x)))
.for_each(|(x, y)| *x = y);
}

/// Fill an array view by decoding bytes as little-endian u16 cells.
pub fn set_arrayview_from_u16_le_bytes<F: PrimeField32, D: ndarray::Dimension>(
arrayview: &mut ArrayViewMut<F, D>,
bytes: &[u8],
) {
debug_assert_eq!(arrayview.len() * size_of::<u16>(), bytes.len());
arrayview
.iter_mut()
.zip(bytes.chunks_exact(size_of::<u16>()))
.for_each(|(slot, bytes)| {
*slot = F::from_u16(u16::from_le_bytes([bytes[0], bytes[1]]));
});
}
5 changes: 3 additions & 2 deletions crates/vm/src/system/memory/online.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,10 @@ impl<M: LinearMemory> AddressMap<M> {
(addr_space, cell_idx): Address,
len: usize,
) -> &[T] {
debug_assert_eq!(
assert_eq!(
size_of::<T>(),
self.config[addr_space as usize].layout.size()
self.config[addr_space as usize].layout.size(),
"typed slice access must use the AS cell type; use get_u8_slice for raw bytes"
);
let start = (cell_idx as usize) * size_of::<T>();
let mem = self.mem.get_unchecked(addr_space as usize);
Expand Down
8 changes: 5 additions & 3 deletions extensions/keccak256/circuit/cuda/include/keccakf_op.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ template <typename T> struct KeccakfOpCols {
T is_valid;
T timestamp;
T rd_ptr;
T buffer_ptr_limbs[RV64_WORD_NUM_LIMBS]; // 4 limbs
T preimage[KECCAK_WIDTH_BYTES]; // 200 bytes
T postimage[KECCAK_WIDTH_BYTES]; // 200 bytes
// Low 32 bits of [rd_ptr:8]_1 as u16 cells.
T buffer_ptr_limbs[RV64_PTR_U16_LIMBS];
// Keccak state as u16 cells.
T preimage[KECCAK_WIDTH_U16S];
T postimage[KECCAK_WIDTH_U16S];
MemoryReadAuxCols<T> rd_aux;
MemoryBaseAuxCols<T> buffer_word_aux[KECCAK_WIDTH_MEM_OPS]; // 25 words
};
Expand Down
6 changes: 4 additions & 2 deletions extensions/keccak256/circuit/cuda/include/xorin.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ struct XorinInstructionCols {
T input_reg_ptr;
T len_reg_ptr;
T buffer_ptr;
T buffer_ptr_limbs[riscv::RV64_WORD_NUM_LIMBS];
// Low 32 bits of [buffer_reg_ptr:8]_1 as u16 cells.
T buffer_ptr_limbs[RV64_PTR_U16_LIMBS];
T input_ptr;
T input_ptr_limbs[riscv::RV64_WORD_NUM_LIMBS];
// Low 32 bits of [input_reg_ptr:8]_1 as u16 cells.
T input_ptr_limbs[RV64_PTR_U16_LIMBS];
T len;
T len_limb;
T start_timestamp;
Expand Down
Loading
Loading