Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 23 additions & 220 deletions crates/lean_compiler/src/a_simplify_lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,7 @@ pub fn simplify_program(mut program: Program) -> Result<SimpleProgram, String> {
program.functions.remove(&name);
}

let mut mutable_loop_counter = Counter::new();
transform_mutable_in_loops_in_program(&mut program, &mut mutable_loop_counter)?;
check_no_loop_carried_mutables(&program)?;

let mut new_functions = BTreeMap::new();
let mut counters = Counters::default();
Expand Down Expand Up @@ -956,28 +955,6 @@ fn substitute_const_vars_in_expr(expr: &mut Expression, const_var_exprs: &BTreeM
changed
}

// ============================================================================
// TRANSFORMATION: Mutable variables in non-unrolled loops
// ============================================================================
//
// This transformation handles mutable variables that are modified inside
// non-unrolled loops by using buffers to store intermediate values.
//
// For a loop like:
// for i in start..end { x += i; }
//
// We transform it to:
// size = end - start;
// x_buff = Array(size + 1);
// x_buff[0] = x;
// for i in start..end {
// buff_idx = i - start;
// mut x_body = x_buff[buff_idx];
// x_body += i;
// x_buff[buff_idx + 1] = x_body;
// }
// x = x_buff[size];

/// Finds mutable variables that are:
/// 1. Defined OUTSIDE this block (external)
/// 2. Re-assigned INSIDE this block
Expand Down Expand Up @@ -1052,228 +1029,54 @@ fn find_assigned_external_vars_helper(
}
}

fn transform_mutable_in_loops_in_program(program: &mut Program, counter: &mut Counter) -> Result<(), String> {
for func in program.functions.values_mut() {
transform_mutable_in_loops_in_lines(&mut func.body, &program.const_arrays, counter, &BTreeSet::new())?;
/// Reject any `range` / `parallel_range` loop that reassigns a mutable variable
/// defined in an enclosing scope ("loop-carried mutable").
fn check_no_loop_carried_mutables(program: &Program) -> Result<(), String> {
for func in program.functions.values() {
check_loop_carried_mutables_in_lines(&func.body, &program.const_arrays, &BTreeSet::new())?;
}
Ok(())
}

fn transform_mutable_in_loops_in_lines(
lines: &mut Vec<Line>,
fn check_loop_carried_mutables_in_lines(
lines: &[Line],
const_arrays: &BTreeMap<String, ConstArrayValue>,
counter: &mut Counter,
outer_mut_vars: &BTreeSet<Var>,
) -> Result<(), String> {
let mut local_mut_vars = outer_mut_vars.clone();
let mut i = 0;
while i < lines.len() {
match &mut lines[i] {
Line::ForLoop { body, loop_kind, .. } if loop_kind.is_unroll() => {
transform_mutable_in_loops_in_lines(body, const_arrays, counter, &local_mut_vars)?;
i += 1;
}
for line in lines {
match line {
Line::ForLoop {
iterator,
start,
end,
body,
loop_kind: loop_kind @ (LoopKind::Range | LoopKind::ParallelRange),
loop_kind: LoopKind::Range | LoopKind::ParallelRange,
location,
..
} => {
let loop_kind = loop_kind.clone();
transform_mutable_in_loops_in_lines(body, const_arrays, counter, &local_mut_vars)?;
check_loop_carried_mutables_in_lines(body, const_arrays, &local_mut_vars)?;
let modified_vars = find_modified_external_vars(body, const_arrays, &local_mut_vars);

if modified_vars.is_empty() {
// No mutable variables modified, no transformation needed
i += 1;
continue;
}

if loop_kind.is_parallel() {
if !modified_vars.is_empty() {
return Err(format!(
"parallel loop at {location} carries mutable variable(s) {modified_vars:?} across iterations; use a sequential `range` loop"
"loop at {location} reassigns enclosing-scope mutable(s) {modified_vars:?}; \
loop-carried mutables are unsupported: use an explicit buffer (see zkDSL.md, \"For loops\")"
));
}

let suffix = counter.get_next();

// Generate the transformed code
let mut new_lines = Vec::new();

let location = *location;

// Create size variable: @loop_size_{suffix} = end - start
let size_var = format!("@loop_size_{suffix}");

new_lines.push(Line::Statement {
targets: vec![AssignmentTarget::Var {
var: size_var.clone(),
is_mutable: false,
}],
value: Expression::MathExpr(MathOperation::Sub, vec![end.clone(), start.clone()]),
location,
});

let mut var_to_buff: BTreeMap<Var, (Var, Var)> = BTreeMap::new(); // var -> (buff_name, body_name)

for var in &modified_vars {
let buff_name = format!("@loop_buff_{var}_{suffix}");
let body_name = format!("@loop_body_{var}_{suffix}");

// buff = Array(size + 1)
new_lines.push(Line::Statement {
targets: vec![AssignmentTarget::Var {
var: buff_name.clone(),
is_mutable: false,
}],
value: Expression::FunctionCall {
function_name: "Array".to_string(),
args: vec![Expression::MathExpr(
// TODO opti in case there is only one mutated var
MathOperation::Add,
vec![Expression::var(size_var.clone()), Expression::one()],
)],
location,
},
location,
});

// buff[0] = var (current value)
new_lines.push(Line::Statement {
targets: vec![AssignmentTarget::ArrayAccess {
array: buff_name.clone().into(),
index: Box::new(Expression::zero()),
}],
value: Expression::var(var.clone()),
location,
});

var_to_buff.insert(var.clone(), (buff_name, body_name));
}

// Transform the loop body
let iterator = iterator.clone();
let mut new_body = Vec::new();

// buff_idx = i - start (or just i when start is zero)
let buff_idx_var = format!("@loop_buff_idx_{suffix}");

new_body.push(Line::Statement {
targets: vec![AssignmentTarget::Var {
var: buff_idx_var.clone(),
is_mutable: false,
}],
value: Expression::MathExpr(
MathOperation::Sub,
vec![Expression::var(iterator.clone()), start.clone()],
),
location,
});

// For each modified variable: mut body_var = buff[buff_idx]
for (var, (buff_name, body_name)) in &var_to_buff {
new_body.push(Line::Statement {
targets: vec![AssignmentTarget::Var {
var: body_name.clone(),
is_mutable: true,
}],
value: Expression::ArrayAccess {
array: buff_name.clone().into(),
index: vec![Expression::Value(
VarOrConstMallocAccess::Var(buff_idx_var.clone()).into(),
)],
},
location,
});

// Replace all references to var with body_name in the original body
transform_vars_in_lines(body, &|v: &Var| {
if v == var {
VarTransform::Rename(body_name.clone())
} else {
VarTransform::Keep
}
});
}

// Add the original body (now modified to use body_vars)
new_body.append(body);

// next_idx = buff_idx + 1
let next_idx_var = format!("@loop_next_idx_{suffix}");
new_body.push(Line::Statement {
targets: vec![AssignmentTarget::Var {
var: next_idx_var.clone(),
is_mutable: false,
}],
value: Expression::MathExpr(
MathOperation::Add,
vec![Expression::var(buff_idx_var.clone()), Expression::one()],
),
location,
});

// For each modified variable: buff[next_idx] = body_var
for (buff_name, body_name) in var_to_buff.values() {
new_body.push(Line::Statement {
targets: vec![AssignmentTarget::ArrayAccess {
array: buff_name.clone().into(),
index: Expression::var(next_idx_var.clone()).into(),
}],
value: Expression::var(body_name.clone()),
location,
});
}

// Create the new loop
new_lines.push(Line::ForLoop {
iterator: iterator.clone(),
start: start.clone(),
end: end.clone(),
body: new_body,
loop_kind,
location,
});

// After the loop: var = buff[size]
for (var, (buff_name, _body_name)) in &var_to_buff {
new_lines.push(Line::Statement {
targets: vec![AssignmentTarget::Var {
var: var.clone(),
is_mutable: false,
}],
value: Expression::ArrayAccess {
array: buff_name.clone().into(),
index: vec![Expression::var(size_var.clone())],
},
location,
});
}

// Replace the original loop with the new lines
let num_new = new_lines.len();
lines.splice(i..=i, new_lines);
i += num_new;
}
line @ (Line::IfCondition { .. } | Line::Match { .. }) => {
for block in line.nested_blocks_mut() {
transform_mutable_in_loops_in_lines(block, const_arrays, counter, &local_mut_vars)?;
Line::ForLoop { body, .. } => {
check_loop_carried_mutables_in_lines(body, const_arrays, &local_mut_vars)?;
}
Line::IfCondition { .. } | Line::Match { .. } => {
for block in line.nested_blocks() {
check_loop_carried_mutables_in_lines(block, const_arrays, &local_mut_vars)?;
}
i += 1;
}
Line::Statement { targets, .. } => {
for target in targets {
if let AssignmentTarget::Var { var, is_mutable: true } = target {
local_mut_vars.insert(var.clone());
}
}
i += 1;
}
_ => {
i += 1;
}
_ => {}
}
}
Ok(())
Expand Down
11 changes: 11 additions & 0 deletions crates/lean_compiler/tests/test_data/error_100.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from snark_lib import *

# Error: a Mut carried across a `parallel_range` loop is rejected, same as `range`.


def main():
acc: Mut = 0
for i in parallel_range(0, 4):
acc = acc + i
assert acc == 6
return
13 changes: 13 additions & 0 deletions crates/lean_compiler/tests/test_data/error_101.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from snark_lib import *

# Error: the enclosing Mut `c` is reassigned inside an `if` nested in a `range`
# loop — detection must look inside nested blocks, not just the loop's top level.


def main():
c: Mut = 0
for i in range(0, 5):
if i == 2:
c = c + 1
assert c == 1
return
12 changes: 12 additions & 0 deletions crates/lean_compiler/tests/test_data/error_102.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from snark_lib import *

# Error: `counter` (enclosing Mut) is reassigned inside a nested `range` loop.


def main():
counter: Mut = 0
for i in range(0, 3):
for j in range(0, 2):
counter = counter + 1
assert counter == 6
return
12 changes: 12 additions & 0 deletions crates/lean_compiler/tests/test_data/error_99.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from snark_lib import *

# Error: `total` (a Mut from the enclosing scope) is reassigned inside a `range`
# loop. Loop-carried mutables are not supported; use an explicit buffer instead.


def main():
total: Mut = 0
for i in range(0, 5):
total = total + i
assert total == 10
return
20 changes: 0 additions & 20 deletions crates/lean_compiler/tests/test_data/program_130.py

This file was deleted.

11 changes: 0 additions & 11 deletions crates/lean_compiler/tests/test_data/program_131.py

This file was deleted.

11 changes: 0 additions & 11 deletions crates/lean_compiler/tests/test_data/program_132.py

This file was deleted.

Loading
Loading