Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions extensions/native/circuit/cuda/include/native/sumcheck.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ using namespace native;
template <typename T> struct HeaderSpecificCols {
T pc;
T registers[5];
MemoryReadAuxCols<T> read_records[7];
MemoryReadAuxCols<T> read_records[8];
MemoryWriteAuxCols<T, EXT_DEG> write_records;
};

template <typename T> struct ProdSpecificCols {
T data_ptr;
T p[EXT_DEG * 2];
MemoryReadAuxCols<T> read_records[2];
MemoryReadAuxCols<T> read_records[1];
T p_evals[EXT_DEG];
MemoryWriteAuxCols<T, EXT_DEG> write_record;
T eval_rlc[EXT_DEG];
Expand All @@ -24,7 +24,7 @@ template <typename T> struct ProdSpecificCols {
template <typename T> struct LogupSpecificCols {
T data_ptr;
T pq[EXT_DEG * 4];
MemoryReadAuxCols<T> read_records[2];
MemoryReadAuxCols<T> read_records[1];
T p_evals[EXT_DEG];
T q_evals[EXT_DEG];
MemoryWriteAuxCols<T, EXT_DEG> write_records[2];
Expand Down
26 changes: 8 additions & 18 deletions extensions/native/circuit/cuda/src/sumcheck.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h
uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32();

if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) {
for (uint32_t i = 0; i < 7; ++i) {
for (uint32_t i = 0; i < 8; ++i) {
mem_fill_base(
mem_helper,
start_timestamp + i,
Expand All @@ -25,43 +25,33 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h
specific.slice_from(COL_INDEX(HeaderSpecificCols, write_records.base))
);
} else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base))
);
if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) {
mem_fill_base(
mem_helper,
start_timestamp + 1,
specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[1].base))
start_timestamp,
specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base))
);
mem_fill_base(
mem_helper,
start_timestamp + 2,
start_timestamp + 1,
specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base))
);
}
} else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base))
);
if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) {
mem_fill_base(
mem_helper,
start_timestamp + 1,
specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[1].base))
start_timestamp,
specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base))
);
mem_fill_base(
mem_helper,
start_timestamp + 2,
start_timestamp + 1,
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base))
);
mem_fill_base(
mem_helper,
start_timestamp + 3,
start_timestamp + 2,
specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base))
);
}
Expand Down
65 changes: 26 additions & 39 deletions extensions/native/circuit/src/sumcheck/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
alpha,
next.alpha,
);
builder
.when(next.prod_row + next.logup_row)
.assert_eq(max_round, next.max_round);
builder
.when(next.prod_row + next.logup_row)
.assert_eq(prod_nested_len, next.prod_nested_len);
Expand Down Expand Up @@ -223,21 +226,21 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
.when(next.prod_row + next.logup_row)
.assert_eq(
next.start_timestamp,
start_timestamp + AB::F::from_canonical_usize(7),
start_timestamp + AB::F::from_canonical_usize(8),
);
builder
.when(prod_row)
.when(next.prod_row + next.logup_row)
.assert_eq(
next.start_timestamp,
start_timestamp + AB::F::ONE + within_round_limit * AB::F::TWO,
start_timestamp + within_round_limit * AB::F::TWO,
);
builder
.when(logup_row)
.when(next.prod_row + next.logup_row)
.assert_eq(
next.start_timestamp,
start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3),
start_timestamp + within_round_limit * AB::F::from_canonical_usize(3),
);

// Termination condition
Expand Down Expand Up @@ -330,6 +333,19 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
)
.eval(builder, header_row);

// Read max_round
self.memory_bridge
.read(
MemoryAddress::new(
native_as,
register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN),
),
[max_round],
first_timestamp + AB::F::from_canonical_usize(7),
&header_row_specific.read_records[7],
)
.eval(builder, header_row);

// Write final result
self.memory_bridge
.write(
Expand All @@ -348,20 +364,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
let next_prod_row_specific: &ProdSpecificCols<AB::Var> =
next.specific[..ProdSpecificCols::<AB::Var>::width()].borrow();

self.memory_bridge
.read(
MemoryAddress::new(
native_as,
register_ptrs[0]
+ AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN)
+ (curr_prod_n - AB::F::ONE),
), // curr_prod_n starts at 1.
[max_round],
start_timestamp,
&prod_row_specific.read_records[0],
)
.eval(builder, prod_row);

// prod_row * within_round_limit =
// prod_in_round_evaluation + prod_next_round_evaluation
builder
Expand All @@ -385,8 +387,8 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
.read(
MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr),
prod_row_specific.p,
start_timestamp + AB::F::ONE,
&prod_row_specific.read_records[1],
start_timestamp,
&prod_row_specific.read_records[0],
)
.eval(builder, prod_row * within_round_limit);

Expand All @@ -402,7 +404,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG),
),
prod_row_specific.p_evals,
start_timestamp + AB::F::TWO,
start_timestamp + AB::F::ONE,
&prod_row_specific.write_record,
)
.eval(builder, prod_row * within_round_limit);
Expand Down Expand Up @@ -449,21 +451,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
let next_logup_row_specfic: &LogupSpecificCols<AB::Var> =
next.specific[..LogupSpecificCols::<AB::Var>::width()].borrow();

self.memory_bridge
.read(
MemoryAddress::new(
native_as,
register_ptrs[0]
+ AB::F::from_canonical_usize(EXT_DEG * 2)
+ num_prod_spec
+ (curr_logup_n - AB::F::ONE),
), // curr_logup_n starts at 1.
[max_round],
start_timestamp,
&logup_row_specific.read_records[0],
)
.eval(builder, logup_row);

// logup_row * within_round_limit =
// logup_in_round_evaluation + logup_next_round_evaluation
builder
Expand All @@ -488,8 +475,8 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
.read(
MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr),
logup_row_specific.pq,
start_timestamp + AB::F::ONE,
&logup_row_specific.read_records[1],
start_timestamp,
&logup_row_specific.read_records[0],
)
.eval(builder, logup_row * within_round_limit);

Expand All @@ -513,7 +500,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
+ (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG),
),
logup_row_specific.p_evals,
start_timestamp + AB::F::TWO,
start_timestamp + AB::F::ONE,
&logup_row_specific.write_records[0],
)
.eval(builder, logup_row * within_round_limit);
Expand All @@ -528,7 +515,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
* AB::F::from_canonical_usize(EXT_DEG),
),
logup_row_specific.q_evals,
start_timestamp + AB::F::from_canonical_usize(3),
start_timestamp + AB::F::TWO,
&logup_row_specific.write_records[1],
)
.eval(builder, logup_row * within_round_limit);
Expand Down
Loading
Loading