55import liquidjava .rj_language .ast .LiteralBoolean ;
66import liquidjava .rj_language .opt .derivation_node .BinaryDerivationNode ;
77import liquidjava .rj_language .opt .derivation_node .DerivationNode ;
8+ import liquidjava .rj_language .opt .derivation_node .UnaryDerivationNode ;
89import liquidjava .rj_language .opt .derivation_node .ValDerivationNode ;
910
1011public class ExpressionSimplifier {
@@ -15,7 +16,7 @@ public class ExpressionSimplifier {
1516 */
1617 public static ValDerivationNode simplify (Expression exp ) {
1718 ValDerivationNode fixedPoint = simplifyToFixedPoint (null , null , exp );
18- return simplifyDerivationTree (fixedPoint );
19+ return simplifyValDerivationNode (fixedPoint );
1920 }
2021
2122 /**
@@ -25,7 +26,7 @@ public static ValDerivationNode simplify(Expression exp) {
2526 private static ValDerivationNode simplifyToFixedPoint (ValDerivationNode current , ValDerivationNode previous ,
2627 Expression prevExp ) {
2728 // apply propagation and folding
28- ValDerivationNode prop = ConstantPropagation .propagate (prevExp );
29+ ValDerivationNode prop = ConstantPropagation .propagate (prevExp , current );
2930 ValDerivationNode fold = ConstantFolding .fold (prop );
3031 Expression currExp = fold .getValue ();
3132
@@ -41,7 +42,7 @@ private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current,
4142 /**
4243 * Recursively simplifies the derivation tree by removing redundant conjuncts
4344 */
44- private static ValDerivationNode simplifyDerivationTree (ValDerivationNode node ) {
45+ private static ValDerivationNode simplifyValDerivationNode (ValDerivationNode node ) {
4546 Expression value = node .getValue ();
4647 DerivationNode origin = node .getOrigin ();
4748
@@ -50,13 +51,12 @@ private static ValDerivationNode simplifyDerivationTree(ValDerivationNode node)
5051 ValDerivationNode leftSimplified ;
5152 ValDerivationNode rightSimplified ;
5253
53- // simplify children
5454 if (origin instanceof BinaryDerivationNode binOrigin ) {
55- leftSimplified = simplifyDerivationTree (binOrigin .getLeft ());
56- rightSimplified = simplifyDerivationTree (binOrigin .getRight ());
55+ leftSimplified = simplifyValDerivationNode (binOrigin .getLeft ());
56+ rightSimplified = simplifyValDerivationNode (binOrigin .getRight ());
5757 } else {
58- leftSimplified = simplifyDerivationTree (new ValDerivationNode (binExp .getFirstOperand (), null ));
59- rightSimplified = simplifyDerivationTree (new ValDerivationNode (binExp .getSecondOperand (), null ));
58+ leftSimplified = simplifyValDerivationNode (new ValDerivationNode (binExp .getFirstOperand (), null ));
59+ rightSimplified = simplifyValDerivationNode (new ValDerivationNode (binExp .getSecondOperand (), null ));
6060 }
6161
6262 // check if either side is redundant
@@ -65,7 +65,7 @@ private static ValDerivationNode simplifyDerivationTree(ValDerivationNode node)
6565 if (isRedundant (rightSimplified .getValue ()))
6666 return leftSimplified ;
6767
68- // check if children are equal (x && x => x)
68+ // collapse identical sides (x && x => x)
6969 if (leftSimplified .getValue ().toString ().equals (rightSimplified .getValue ().toString ())) {
7070 return leftSimplified ;
7171 }
@@ -75,10 +75,41 @@ private static ValDerivationNode simplifyDerivationTree(ValDerivationNode node)
7575 DerivationNode newOrigin = new BinaryDerivationNode (leftSimplified , rightSimplified , "&&" );
7676 return new ValDerivationNode (newValue , newOrigin );
7777 }
78+
79+ // simplify origin
80+ DerivationNode simplifiedOrigin = simplifyDerivationNode (origin );
81+ if (simplifiedOrigin != origin ) {
82+ return new ValDerivationNode (value , simplifiedOrigin );
83+ }
84+
7885 // no simplification
7986 return node ;
8087 }
8188
89+ private static DerivationNode simplifyDerivationNode (DerivationNode node ) {
90+ if (node == null )
91+ return null ;
92+ if (node instanceof ValDerivationNode val ) {
93+ return simplifyValDerivationNode (val );
94+ }
95+ if (node instanceof BinaryDerivationNode binary ) {
96+ ValDerivationNode left = simplifyValDerivationNode (binary .getLeft ());
97+ ValDerivationNode right = simplifyValDerivationNode (binary .getRight ());
98+ if (left != binary .getLeft () || right != binary .getRight ()) {
99+ return new BinaryDerivationNode (left , right , binary .getOp ());
100+ }
101+ return binary ;
102+ }
103+ if (node instanceof UnaryDerivationNode unary ) {
104+ ValDerivationNode operand = simplifyValDerivationNode (unary .getOperand ());
105+ if (operand != unary .getOperand ()) {
106+ return new UnaryDerivationNode (operand , unary .getOp ());
107+ }
108+ return unary ;
109+ }
110+ return node ;
111+ }
112+
82113 /**
83114 * Checks if an expression is redundant (e.g. true or x == x)
84115 */
0 commit comments