1010import liquidjava .rj_language .opt .derivation_node .ValDerivationNode ;
1111import liquidjava .rj_language .opt .derivation_node .VarDerivationNode ;
1212
13+ import java .util .HashMap ;
1314import java .util .Map ;
1415
1516public class ConstantPropagation {
@@ -19,55 +20,71 @@ public class ConstantPropagation {
1920 * VariableResolver to extract variable equalities from the expression first. Returns a derivation node representing
2021 * the propagation steps taken.
2122 */
22- public static ValDerivationNode propagate (Expression exp ) {
23+ public static ValDerivationNode propagate (Expression exp , ValDerivationNode previousOrigin ) {
2324 Map <String , Expression > substitutions = VariableResolver .resolve (exp );
24- return propagateRecursive (exp , substitutions );
25+
26+ // map of variable origins from the previous derivation tree
27+ Map <String , DerivationNode > varOrigins = new HashMap <>();
28+ if (previousOrigin != null ) {
29+ extractVarOrigins (previousOrigin , varOrigins );
30+ }
31+ return propagateRecursive (exp , substitutions , varOrigins );
2532 }
2633
2734 /**
2835 * Recursively performs constant propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2)
2936 */
30- private static ValDerivationNode propagateRecursive (Expression exp , Map <String , Expression > subs ) {
37+ private static ValDerivationNode propagateRecursive (Expression exp , Map <String , Expression > subs ,
38+ Map <String , DerivationNode > varOrigins ) {
3139
3240 // substitute variable
3341 if (exp instanceof Var var ) {
3442 String name = var .getName ();
3543 Expression value = subs .get (name );
3644 // substitution
37- if (value != null )
38- return new ValDerivationNode (value .clone (), new VarDerivationNode (name ));
45+ if (value != null ) {
46+ // check if this variable has an origin from a previous pass
47+ DerivationNode previousOrigin = varOrigins .get (name );
48+
49+ // preserve origin if value came from previous derivation
50+ DerivationNode origin = previousOrigin != null ? new VarDerivationNode (name , previousOrigin )
51+ : new VarDerivationNode (name );
52+ return new ValDerivationNode (value .clone (), origin );
53+ }
3954
4055 // no substitution
4156 return new ValDerivationNode (var , null );
4257 }
4358
4459 // lift unary origin
4560 if (exp instanceof UnaryExpression unary ) {
46- ValDerivationNode operand = propagateRecursive (unary .getChildren ().get (0 ), subs );
47- unary .setChild (0 , operand .getValue ());
61+ ValDerivationNode operand = propagateRecursive (unary .getChildren ().get (0 ), subs , varOrigins );
62+ UnaryExpression cloned = (UnaryExpression ) unary .clone ();
63+ cloned .setChild (0 , operand .getValue ());
4864
49- DerivationNode origin = operand .getOrigin () != null ? new UnaryDerivationNode ( operand , unary . getOp ())
50- : null ;
51- return new ValDerivationNode (unary , origin );
65+ return operand .getOrigin () != null
66+ ? new ValDerivationNode ( cloned , new UnaryDerivationNode ( operand , cloned . getOp ()))
67+ : new ValDerivationNode (cloned , null );
5268 }
5369
5470 // lift binary origin
5571 if (exp instanceof BinaryExpression binary ) {
56- ValDerivationNode left = propagateRecursive (binary .getFirstOperand (), subs );
57- ValDerivationNode right = propagateRecursive (binary .getSecondOperand (), subs );
58- binary .setChild (0 , left .getValue ());
59- binary .setChild (1 , right .getValue ());
60-
61- DerivationNode origin = (left .getOrigin () != null || right .getOrigin () != null )
62- ? new BinaryDerivationNode (left , right , binary .getOperator ()) : null ;
63- return new ValDerivationNode (binary , origin );
72+ ValDerivationNode left = propagateRecursive (binary .getFirstOperand (), subs , varOrigins );
73+ ValDerivationNode right = propagateRecursive (binary .getSecondOperand (), subs , varOrigins );
74+ BinaryExpression cloned = (BinaryExpression ) binary .clone ();
75+ cloned .setChild (0 , left .getValue ());
76+ cloned .setChild (1 , right .getValue ());
77+
78+ return (left .getOrigin () != null || right .getOrigin () != null )
79+ ? new ValDerivationNode (cloned , new BinaryDerivationNode (left , right , cloned .getOperator ()))
80+ : new ValDerivationNode (cloned , null );
6481 }
6582
6683 // recursively propagate children
6784 if (exp .hasChildren ()) {
6885 Expression propagated = exp .clone ();
6986 for (int i = 0 ; i < exp .getChildren ().size (); i ++) {
70- ValDerivationNode child = propagateRecursive (exp .getChildren ().get (i ), subs );
87+ ValDerivationNode child = propagateRecursive (exp .getChildren ().get (i ), subs , varOrigins );
7188 propagated .setChild (i , child .getValue ());
7289 }
7390 return new ValDerivationNode (propagated , null );
@@ -76,4 +93,49 @@ private static ValDerivationNode propagateRecursive(Expression exp, Map<String,
7693 // no propagation
7794 return new ValDerivationNode (exp , null );
7895 }
96+
97+ /**
98+ * Extracts the derivation nodes for variable values from the derivation tree This is so done so when we find "var
99+ * == value" in the tree, we store the derivation of the value So it can be preserved when var is substituted in
100+ * subsequent passes
101+ */
102+ private static void extractVarOrigins (ValDerivationNode node , Map <String , DerivationNode > varOrigins ) {
103+ if (node == null )
104+ return ;
105+
106+ Expression value = node .getValue ();
107+ DerivationNode origin = node .getOrigin ();
108+
109+ // check for equality expressions
110+ if (value instanceof BinaryExpression binExp && "==" .equals (binExp .getOperator ())
111+ && origin instanceof BinaryDerivationNode binOrigin ) {
112+ Expression left = binExp .getFirstOperand ();
113+ Expression right = binExp .getSecondOperand ();
114+
115+ // extract variable name and value derivation from either side
116+ String varName = null ;
117+ ValDerivationNode valueDerivation = null ;
118+
119+ if (left instanceof Var var && right .isLiteral ()) {
120+ varName = var .getName ();
121+ valueDerivation = binOrigin .getRight ();
122+ } else if (right instanceof Var var && left .isLiteral ()) {
123+ varName = var .getName ();
124+ valueDerivation = binOrigin .getLeft ();
125+ }
126+ if (varName != null && valueDerivation != null && valueDerivation .getOrigin () != null ) {
127+ varOrigins .put (varName , valueDerivation .getOrigin ());
128+ }
129+ }
130+
131+ // recursively process the origin tree
132+ if (origin instanceof BinaryDerivationNode binOrigin ) {
133+ extractVarOrigins (binOrigin .getLeft (), varOrigins );
134+ extractVarOrigins (binOrigin .getRight (), varOrigins );
135+ } else if (origin instanceof UnaryDerivationNode unaryOrigin ) {
136+ extractVarOrigins (unaryOrigin .getOperand (), varOrigins );
137+ } else if (origin instanceof ValDerivationNode valOrigin ) {
138+ extractVarOrigins (valOrigin , varOrigins );
139+ }
140+ }
79141}
0 commit comments