Skip to content

Commit b592ef5

Browse files
committed
Fix Losing Previous Simplification Nodes on Multiple Passes
Fixed by tracking previous origins in constant propagation
1 parent cba9719 commit b592ef5

File tree

3 files changed

+92
-23
lines changed

3 files changed

+92
-23
lines changed

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,16 @@ public class ConstantPropagation {
1919
* VariableResolver to extract variable equalities from the expression first. Returns a derivation node representing
2020
* the propagation steps taken.
2121
*/
22-
public static ValDerivationNode propagate(Expression exp) {
22+
public static ValDerivationNode propagate(Expression exp, DerivationNode defaultOrigin) {
2323
Map<String, Expression> substitutions = VariableResolver.resolve(exp);
24-
return propagateRecursive(exp, substitutions);
24+
return propagateRecursive(exp, substitutions, defaultOrigin);
2525
}
2626

2727
/**
2828
* Recursively performs constant propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2)
2929
*/
30-
private static ValDerivationNode propagateRecursive(Expression exp, Map<String, Expression> subs) {
30+
private static ValDerivationNode propagateRecursive(Expression exp, Map<String, Expression> subs,
31+
DerivationNode defaultOrigin) {
3132

3233
// substitute variable
3334
if (exp instanceof Var var) {
@@ -38,44 +39,44 @@ private static ValDerivationNode propagateRecursive(Expression exp, Map<String,
3839
return new ValDerivationNode(value.clone(), new VarDerivationNode(name));
3940

4041
// no substitution
41-
return new ValDerivationNode(var, null);
42+
return new ValDerivationNode(var, defaultOrigin);
4243
}
4344

4445
// lift unary origin
4546
if (exp instanceof UnaryExpression unary) {
46-
ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs);
47+
ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs, defaultOrigin);
4748
UnaryExpression cloned = (UnaryExpression) unary.clone();
4849
cloned.setChild(0, operand.getValue());
4950

5051
DerivationNode origin = operand.getOrigin() != null ? new UnaryDerivationNode(operand, cloned.getOp())
51-
: null;
52+
: defaultOrigin;
5253
return new ValDerivationNode(cloned, origin);
5354
}
5455

5556
// lift binary origin
5657
if (exp instanceof BinaryExpression binary) {
57-
ValDerivationNode left = propagateRecursive(binary.getFirstOperand(), subs);
58-
ValDerivationNode right = propagateRecursive(binary.getSecondOperand(), subs);
58+
ValDerivationNode left = propagateRecursive(binary.getFirstOperand(), subs, defaultOrigin);
59+
ValDerivationNode right = propagateRecursive(binary.getSecondOperand(), subs, defaultOrigin);
5960
BinaryExpression cloned = (BinaryExpression) binary.clone();
6061
cloned.setChild(0, left.getValue());
6162
cloned.setChild(1, right.getValue());
6263

6364
DerivationNode origin = (left.getOrigin() != null || right.getOrigin() != null)
64-
? new BinaryDerivationNode(left, right, cloned.getOperator()) : null;
65+
? new BinaryDerivationNode(left, right, cloned.getOperator()) : defaultOrigin;
6566
return new ValDerivationNode(cloned, origin);
6667
}
6768

6869
// recursively propagate children
6970
if (exp.hasChildren()) {
7071
Expression propagated = exp.clone();
7172
for (int i = 0; i < exp.getChildren().size(); i++) {
72-
ValDerivationNode child = propagateRecursive(exp.getChildren().get(i), subs);
73+
ValDerivationNode child = propagateRecursive(exp.getChildren().get(i), subs, defaultOrigin);
7374
propagated.setChild(i, child.getValue());
7475
}
75-
return new ValDerivationNode(propagated, null);
76+
return new ValDerivationNode(propagated, defaultOrigin);
7677
}
7778

7879
// no propagation
79-
return new ValDerivationNode(exp, null);
80+
return new ValDerivationNode(exp, defaultOrigin);
8081
}
8182
}

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import liquidjava.rj_language.ast.LiteralBoolean;
66
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
77
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
8+
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
89
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
910

1011
public class ExpressionSimplifier {
@@ -15,7 +16,7 @@ public class ExpressionSimplifier {
1516
*/
1617
public static ValDerivationNode simplify(Expression exp) {
1718
ValDerivationNode fixedPoint = simplifyToFixedPoint(null, null, exp);
18-
return simplifyDerivationTree(fixedPoint);
19+
return simplifyValDerivationNode(fixedPoint);
1920
}
2021

2122
/**
@@ -25,7 +26,7 @@ public static ValDerivationNode simplify(Expression exp) {
2526
private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current, ValDerivationNode previous,
2627
Expression prevExp) {
2728
// apply propagation and folding
28-
ValDerivationNode prop = ConstantPropagation.propagate(prevExp);
29+
ValDerivationNode prop = ConstantPropagation.propagate(prevExp, current);
2930
ValDerivationNode fold = ConstantFolding.fold(prop);
3031
Expression currExp = fold.getValue();
3132

@@ -41,7 +42,7 @@ private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current,
4142
/**
4243
* Recursively simplifies the derivation tree by removing redundant conjuncts
4344
*/
44-
private static ValDerivationNode simplifyDerivationTree(ValDerivationNode node) {
45+
private static ValDerivationNode simplifyValDerivationNode(ValDerivationNode node) {
4546
Expression value = node.getValue();
4647
DerivationNode origin = node.getOrigin();
4748

@@ -50,13 +51,12 @@ private static ValDerivationNode simplifyDerivationTree(ValDerivationNode node)
5051
ValDerivationNode leftSimplified;
5152
ValDerivationNode rightSimplified;
5253

53-
// simplify children
5454
if (origin instanceof BinaryDerivationNode binOrigin) {
55-
leftSimplified = simplifyDerivationTree(binOrigin.getLeft());
56-
rightSimplified = simplifyDerivationTree(binOrigin.getRight());
55+
leftSimplified = simplifyValDerivationNode(binOrigin.getLeft());
56+
rightSimplified = simplifyValDerivationNode(binOrigin.getRight());
5757
} else {
58-
leftSimplified = simplifyDerivationTree(new ValDerivationNode(binExp.getFirstOperand(), null));
59-
rightSimplified = simplifyDerivationTree(new ValDerivationNode(binExp.getSecondOperand(), null));
58+
leftSimplified = simplifyValDerivationNode(new ValDerivationNode(binExp.getFirstOperand(), null));
59+
rightSimplified = simplifyValDerivationNode(new ValDerivationNode(binExp.getSecondOperand(), null));
6060
}
6161

6262
// check if either side is redundant
@@ -65,7 +65,7 @@ private static ValDerivationNode simplifyDerivationTree(ValDerivationNode node)
6565
if (isRedundant(rightSimplified.getValue()))
6666
return leftSimplified;
6767

68-
// check if children are equal (x && x => x)
68+
// collapse identical sides (x && x => x)
6969
if (leftSimplified.getValue().toString().equals(rightSimplified.getValue().toString())) {
7070
return leftSimplified;
7171
}
@@ -75,10 +75,41 @@ private static ValDerivationNode simplifyDerivationTree(ValDerivationNode node)
7575
DerivationNode newOrigin = new BinaryDerivationNode(leftSimplified, rightSimplified, "&&");
7676
return new ValDerivationNode(newValue, newOrigin);
7777
}
78+
79+
// simplify origin
80+
DerivationNode simplifiedOrigin = simplifyDerivationNode(origin);
81+
if (simplifiedOrigin != origin) {
82+
return new ValDerivationNode(value, simplifiedOrigin);
83+
}
84+
7885
// no simplification
7986
return node;
8087
}
8188

89+
private static DerivationNode simplifyDerivationNode(DerivationNode node) {
90+
if (node == null)
91+
return null;
92+
if (node instanceof ValDerivationNode val) {
93+
return simplifyValDerivationNode(val);
94+
}
95+
if (node instanceof BinaryDerivationNode binary) {
96+
ValDerivationNode left = simplifyValDerivationNode(binary.getLeft());
97+
ValDerivationNode right = simplifyValDerivationNode(binary.getRight());
98+
if (left != binary.getLeft() || right != binary.getRight()) {
99+
return new BinaryDerivationNode(left, right, binary.getOp());
100+
}
101+
return binary;
102+
}
103+
if (node instanceof UnaryDerivationNode unary) {
104+
ValDerivationNode operand = simplifyValDerivationNode(unary.getOperand());
105+
if (operand != unary.getOperand()) {
106+
return new UnaryDerivationNode(operand, unary.getOp());
107+
}
108+
return unary;
109+
}
110+
return node;
111+
}
112+
82113
/**
83114
* Checks if an expression is redundant (e.g. true or x == x)
84115
*/

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,6 @@ void testComplexArithmeticWithMultipleOperations() {
308308
void testFixedPointSimplification() {
309309
// Given: x == -y && y == a / b && a == 6 && b == 3
310310
// Expected: x == -2
311-
312311
Expression varX = new Var("x");
313312
Expression varY = new Var("y");
314313
Expression varA = new Var("a");
@@ -332,6 +331,44 @@ void testFixedPointSimplification() {
332331
// Then
333332
assertNotNull(result, "Result should not be null");
334333
assertEquals("x == -2", result.getValue().toString(), "Expected result to be x == -2");
334+
335+
// Compare derivation tree structure
336+
337+
// Origin of y (value 2) - right operand of result
338+
ValDerivationNode originY = new ValDerivationNode(new LiteralInt(2), new VarDerivationNode("y"));
339+
UnaryDerivationNode originNeg2 = new UnaryDerivationNode(originY, "-");
340+
ValDerivationNode rightNode = new ValDerivationNode(new LiteralInt(-2), originNeg2);
341+
342+
// Origin of x - left operand of result
343+
// 6 (from a) / 3 (from b) -> 2
344+
ValDerivationNode val6 = new ValDerivationNode(new LiteralInt(6), new VarDerivationNode("a"));
345+
ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), new VarDerivationNode("b"));
346+
BinaryDerivationNode divOp = new BinaryDerivationNode(val6, val3, "/");
347+
ValDerivationNode val2FromDiv = new ValDerivationNode(new LiteralInt(2), divOp);
348+
349+
// y == 2 (from y == 6 / 3)
350+
ValDerivationNode valYNode = new ValDerivationNode(new Var("y"), null);
351+
BinaryDerivationNode eqYOp = new BinaryDerivationNode(valYNode, val2FromDiv, "==");
352+
ValDerivationNode yEq2 = new ValDerivationNode(new BinaryExpression(new Var("y"), "==", new LiteralInt(2)),
353+
eqYOp);
354+
355+
// x == -y
356+
ValDerivationNode xEqNegY = new ValDerivationNode(
357+
new BinaryExpression(new Var("x"), "==", new UnaryExpression("-", new Var("y"))), null);
358+
359+
// x == -y && y == 2
360+
BinaryDerivationNode andOp1 = new BinaryDerivationNode(xEqNegY, yEq2, "&&");
361+
ValDerivationNode xEqNegYAndYEq2 = new ValDerivationNode(
362+
new BinaryExpression(xEqNegY.getValue(), "&&", yEq2.getValue()), andOp1);
363+
364+
// Left node x has origin pointing to the previous simplification's tree
365+
ValDerivationNode leftNode = new ValDerivationNode(new Var("x"), xEqNegYAndYEq2);
366+
367+
// Root equality
368+
BinaryDerivationNode rootOrigin = new BinaryDerivationNode(leftNode, rightNode, "==");
369+
ValDerivationNode expected = new ValDerivationNode(result.getValue(), rootOrigin);
370+
371+
assertDerivationEquals(expected, result, "Derivation tree structure");
335372
}
336373

337374
@Test
@@ -447,7 +484,7 @@ void testSameExpressionTwiceShouldSimplifyToSingle() {
447484
assertEquals("a + b == 1", result.getValue().toString(),
448485
"Same expression twice should be simplified to a single equality");
449486
}
450-
487+
451488
@Test
452489
void testCircularDependencyShouldNotSimplify() {
453490
// Given: x == y && y == x

0 commit comments

Comments
 (0)