Skip to content

Commit 8685353

Browse files
authored
Simplification Fixes (#125)
1 parent 6e9d4fe commit 8685353

File tree

10 files changed

+499
-77
lines changed

10 files changed

+499
-77
lines changed

liquidjava-verifier/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
<groupId>io.github.liquid-java</groupId>
1313
<artifactId>liquidjava-verifier</artifactId>
14-
<version>0.0.4</version>
14+
<version>0.0.8</version>
1515
<name>liquidjava-verifier</name>
1616
<description>LiquidJava Verifier</description>
1717
<url>https://github.com/liquid-java/liquidjava</url>

liquidjava-verifier/src/main/java/liquidjava/diagnostics/errors/RefinementError.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,18 @@
1212
*/
1313
public class RefinementError extends LJError {
1414

15-
private final String expected;
15+
private final ValDerivationNode expected;
1616
private final ValDerivationNode found;
1717

18-
public RefinementError(SourcePosition position, Expression expected, ValDerivationNode found,
18+
public RefinementError(SourcePosition position, ValDerivationNode expected, ValDerivationNode found,
1919
TranslationTable translationTable) {
20-
super("Refinement Error",
21-
String.format("%s is not a subtype of %s", found.getValue(), expected.toSimplifiedString()), position,
22-
translationTable);
23-
this.expected = expected.toSimplifiedString();
20+
super("Refinement Error", String.format("%s is not a subtype of %s", found.getValue(), expected.getValue()),
21+
position, translationTable);
22+
this.expected = expected;
2423
this.found = found;
2524
}
2625

27-
public String getExpected() {
26+
public ValDerivationNode getExpected() {
2827
return expected;
2928
}
3029

liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ public void processSubtyping(Predicate expectedType, List<GhostState> list, CtEl
5252
}
5353
boolean isSubtype = smtChecks(expected, premises, element.getPosition());
5454
if (!isSubtype)
55-
throw new RefinementError(element.getPosition(), expectedType.getExpression(),
56-
premisesBeforeChange.simplify(), map);
55+
throw new RefinementError(element.getPosition(), expectedType.simplify(), premisesBeforeChange.simplify(),
56+
map);
5757
}
5858

5959
/**
@@ -263,7 +263,7 @@ protected void throwRefinementError(SourcePosition position, Predicate expected,
263263
gatherVariables(found, lrv, mainVars);
264264
TranslationTable map = new TranslationTable();
265265
Predicate premises = joinPredicates(expected, mainVars, lrv, map).toConjunctions();
266-
throw new RefinementError(position, expected.getExpression(), premises.simplify(), map);
266+
throw new RefinementError(position, expected.simplify(), premises.simplify(), map);
267267
}
268268

269269
protected void throwStateRefinementError(SourcePosition position, Predicate found, Predicate expected)

liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ public void validateGhostInvocations(Context ctx, Factory f) throws LJError {
271271
if (this instanceof FunctionInvocation fi) {
272272
// get all ghosts with the matching name
273273
List<GhostFunction> candidates = ctx.getGhosts().stream().filter(g -> g.matches(fi.name)).toList();
274-
if (candidates.isEmpty())
274+
if (candidates.isEmpty())
275275
return; // not found error is thrown elsewhere
276276

277277
// find matching overload

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java

Lines changed: 81 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
1111
import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
1212

13+
import java.util.HashMap;
1314
import java.util.Map;
1415

1516
public class ConstantPropagation {
@@ -19,55 +20,71 @@ public class ConstantPropagation {
1920
* VariableResolver to extract variable equalities from the expression first. Returns a derivation node representing
2021
* the propagation steps taken.
2122
*/
22-
public static ValDerivationNode propagate(Expression exp) {
23+
public static ValDerivationNode propagate(Expression exp, ValDerivationNode previousOrigin) {
2324
Map<String, Expression> substitutions = VariableResolver.resolve(exp);
24-
return propagateRecursive(exp, substitutions);
25+
26+
// map of variable origins from the previous derivation tree
27+
Map<String, DerivationNode> varOrigins = new HashMap<>();
28+
if (previousOrigin != null) {
29+
extractVarOrigins(previousOrigin, varOrigins);
30+
}
31+
return propagateRecursive(exp, substitutions, varOrigins);
2532
}
2633

2734
/**
2835
* Recursively performs constant propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2)
2936
*/
30-
private static ValDerivationNode propagateRecursive(Expression exp, Map<String, Expression> subs) {
37+
private static ValDerivationNode propagateRecursive(Expression exp, Map<String, Expression> subs,
38+
Map<String, DerivationNode> varOrigins) {
3139

3240
// substitute variable
3341
if (exp instanceof Var var) {
3442
String name = var.getName();
3543
Expression value = subs.get(name);
3644
// substitution
37-
if (value != null)
38-
return new ValDerivationNode(value.clone(), new VarDerivationNode(name));
45+
if (value != null) {
46+
// check if this variable has an origin from a previous pass
47+
DerivationNode previousOrigin = varOrigins.get(name);
48+
49+
// preserve origin if value came from previous derivation
50+
DerivationNode origin = previousOrigin != null ? new VarDerivationNode(name, previousOrigin)
51+
: new VarDerivationNode(name);
52+
return new ValDerivationNode(value.clone(), origin);
53+
}
3954

4055
// no substitution
4156
return new ValDerivationNode(var, null);
4257
}
4358

4459
// lift unary origin
4560
if (exp instanceof UnaryExpression unary) {
46-
ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs);
47-
unary.setChild(0, operand.getValue());
61+
ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs, varOrigins);
62+
UnaryExpression cloned = (UnaryExpression) unary.clone();
63+
cloned.setChild(0, operand.getValue());
4864

49-
DerivationNode origin = operand.getOrigin() != null ? new UnaryDerivationNode(operand, unary.getOp())
50-
: null;
51-
return new ValDerivationNode(unary, origin);
65+
return operand.getOrigin() != null
66+
? new ValDerivationNode(cloned, new UnaryDerivationNode(operand, cloned.getOp()))
67+
: new ValDerivationNode(cloned, null);
5268
}
5369

5470
// lift binary origin
5571
if (exp instanceof BinaryExpression binary) {
56-
ValDerivationNode left = propagateRecursive(binary.getFirstOperand(), subs);
57-
ValDerivationNode right = propagateRecursive(binary.getSecondOperand(), subs);
58-
binary.setChild(0, left.getValue());
59-
binary.setChild(1, right.getValue());
60-
61-
DerivationNode origin = (left.getOrigin() != null || right.getOrigin() != null)
62-
? new BinaryDerivationNode(left, right, binary.getOperator()) : null;
63-
return new ValDerivationNode(binary, origin);
72+
ValDerivationNode left = propagateRecursive(binary.getFirstOperand(), subs, varOrigins);
73+
ValDerivationNode right = propagateRecursive(binary.getSecondOperand(), subs, varOrigins);
74+
BinaryExpression cloned = (BinaryExpression) binary.clone();
75+
cloned.setChild(0, left.getValue());
76+
cloned.setChild(1, right.getValue());
77+
78+
return (left.getOrigin() != null || right.getOrigin() != null)
79+
? new ValDerivationNode(cloned, new BinaryDerivationNode(left, right, cloned.getOperator()))
80+
: new ValDerivationNode(cloned, null);
6481
}
6582

6683
// recursively propagate children
6784
if (exp.hasChildren()) {
6885
Expression propagated = exp.clone();
6986
for (int i = 0; i < exp.getChildren().size(); i++) {
70-
ValDerivationNode child = propagateRecursive(exp.getChildren().get(i), subs);
87+
ValDerivationNode child = propagateRecursive(exp.getChildren().get(i), subs, varOrigins);
7188
propagated.setChild(i, child.getValue());
7289
}
7390
return new ValDerivationNode(propagated, null);
@@ -76,4 +93,49 @@ private static ValDerivationNode propagateRecursive(Expression exp, Map<String,
7693
// no propagation
7794
return new ValDerivationNode(exp, null);
7895
}
96+
97+
/**
98+
* Extracts the derivation nodes for variable values from the derivation tree This is so done so when we find "var
99+
* == value" in the tree, we store the derivation of the value So it can be preserved when var is substituted in
100+
* subsequent passes
101+
*/
102+
private static void extractVarOrigins(ValDerivationNode node, Map<String, DerivationNode> varOrigins) {
103+
if (node == null)
104+
return;
105+
106+
Expression value = node.getValue();
107+
DerivationNode origin = node.getOrigin();
108+
109+
// check for equality expressions
110+
if (value instanceof BinaryExpression binExp && "==".equals(binExp.getOperator())
111+
&& origin instanceof BinaryDerivationNode binOrigin) {
112+
Expression left = binExp.getFirstOperand();
113+
Expression right = binExp.getSecondOperand();
114+
115+
// extract variable name and value derivation from either side
116+
String varName = null;
117+
ValDerivationNode valueDerivation = null;
118+
119+
if (left instanceof Var var && right.isLiteral()) {
120+
varName = var.getName();
121+
valueDerivation = binOrigin.getRight();
122+
} else if (right instanceof Var var && left.isLiteral()) {
123+
varName = var.getName();
124+
valueDerivation = binOrigin.getLeft();
125+
}
126+
if (varName != null && valueDerivation != null && valueDerivation.getOrigin() != null) {
127+
varOrigins.put(varName, valueDerivation.getOrigin());
128+
}
129+
}
130+
131+
// recursively process the origin tree
132+
if (origin instanceof BinaryDerivationNode binOrigin) {
133+
extractVarOrigins(binOrigin.getLeft(), varOrigins);
134+
extractVarOrigins(binOrigin.getRight(), varOrigins);
135+
} else if (origin instanceof UnaryDerivationNode unaryOrigin) {
136+
extractVarOrigins(unaryOrigin.getOperand(), varOrigins);
137+
} else if (origin instanceof ValDerivationNode valOrigin) {
138+
extractVarOrigins(valOrigin, varOrigins);
139+
}
140+
}
79141
}

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,88 @@ public class ExpressionSimplifier {
1414
* Returns a derivation node representing the tree of simplifications applied
1515
*/
1616
public static ValDerivationNode simplify(Expression exp) {
17-
ValDerivationNode prop = ConstantPropagation.propagate(exp);
17+
ValDerivationNode fixedPoint = simplifyToFixedPoint(null, exp);
18+
return simplifyValDerivationNode(fixedPoint);
19+
}
20+
21+
/**
22+
* Recursively applies propagation and folding until the expression stops changing (fixed point) Stops early if the
23+
* expression simplifies to 'true', which means we've simplified too much
24+
*/
25+
private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current, Expression prevExp) {
26+
// apply propagation and folding
27+
ValDerivationNode prop = ConstantPropagation.propagate(prevExp, current);
1828
ValDerivationNode fold = ConstantFolding.fold(prop);
19-
return simplifyDerivationTree(fold);
29+
ValDerivationNode simplified = simplifyValDerivationNode(fold);
30+
Expression currExp = simplified.getValue();
31+
32+
// fixed point reached
33+
if (current != null && currExp.equals(current.getValue())) {
34+
return current;
35+
}
36+
37+
// continue simplifying
38+
return simplifyToFixedPoint(simplified, simplified.getValue());
2039
}
2140

2241
/**
2342
* Recursively simplifies the derivation tree by removing redundant conjuncts
2443
*/
25-
private static ValDerivationNode simplifyDerivationTree(ValDerivationNode node) {
44+
private static ValDerivationNode simplifyValDerivationNode(ValDerivationNode node) {
2645
Expression value = node.getValue();
2746
DerivationNode origin = node.getOrigin();
2847

2948
// binary expression with &&
30-
if (value instanceof BinaryExpression binExp) {
31-
if ("&&".equals(binExp.getOperator()) && origin instanceof BinaryDerivationNode binOrigin) {
32-
// recursively simplify children
33-
ValDerivationNode leftSimplified = simplifyDerivationTree(binOrigin.getLeft());
34-
ValDerivationNode rightSimplified = simplifyDerivationTree(binOrigin.getRight());
35-
36-
// check if either side is redundant
37-
if (isRedundant(leftSimplified.getValue()))
38-
return rightSimplified;
39-
if (isRedundant(rightSimplified.getValue()))
40-
return leftSimplified;
41-
42-
// return the conjunction with simplified children
43-
Expression newValue = new BinaryExpression(leftSimplified.getValue(), "&&", rightSimplified.getValue());
44-
DerivationNode newOrigin = new BinaryDerivationNode(leftSimplified, rightSimplified, "&&");
45-
return new ValDerivationNode(newValue, newOrigin);
49+
if (value instanceof BinaryExpression binExp && "&&".equals(binExp.getOperator())) {
50+
ValDerivationNode leftSimplified;
51+
ValDerivationNode rightSimplified;
52+
53+
if (origin instanceof BinaryDerivationNode binOrigin) {
54+
leftSimplified = simplifyValDerivationNode(binOrigin.getLeft());
55+
rightSimplified = simplifyValDerivationNode(binOrigin.getRight());
56+
} else {
57+
leftSimplified = simplifyValDerivationNode(new ValDerivationNode(binExp.getFirstOperand(), null));
58+
rightSimplified = simplifyValDerivationNode(new ValDerivationNode(binExp.getSecondOperand(), null));
4659
}
60+
61+
// check if either side is redundant
62+
if (isRedundant(leftSimplified.getValue()))
63+
return rightSimplified;
64+
if (isRedundant(rightSimplified.getValue()))
65+
return leftSimplified;
66+
67+
// collapse identical sides (x && x => x)
68+
if (leftSimplified.getValue().equals(rightSimplified.getValue())) {
69+
return leftSimplified;
70+
}
71+
72+
// collapse symmetric equalities (e.g. x == y && y == x => x == y)
73+
if (isSymmetricEquality(leftSimplified.getValue(), rightSimplified.getValue())) {
74+
return leftSimplified;
75+
}
76+
77+
// return the conjunction with simplified children
78+
Expression newValue = new BinaryExpression(leftSimplified.getValue(), "&&", rightSimplified.getValue());
79+
DerivationNode newOrigin = new BinaryDerivationNode(leftSimplified, rightSimplified, "&&");
80+
return new ValDerivationNode(newValue, newOrigin);
4781
}
4882
// no simplification
4983
return node;
5084
}
5185

86+
private static boolean isSymmetricEquality(Expression left, Expression right) {
87+
if (left instanceof BinaryExpression b1 && "==".equals(b1.getOperator()) && right instanceof BinaryExpression b2
88+
&& "==".equals(b2.getOperator())) {
89+
90+
Expression l1 = b1.getFirstOperand();
91+
Expression r1 = b1.getSecondOperand();
92+
Expression l2 = b2.getFirstOperand();
93+
Expression r2 = b2.getSecondOperand();
94+
return l1.equals(r2) && r1.equals(l2);
95+
}
96+
return false;
97+
}
98+
5299
/**
53100
* Checks if an expression is redundant (e.g. true or x == x)
54101
*/

0 commit comments

Comments
 (0)