55import com .google .common .collect .Sets ;
66import de .peeeq .wurstscript .jassIm .*;
77import de .peeeq .wurstscript .translation .imtranslation .*;
8+ import de .peeeq .wurstscript .types .TypesHelper ;
89
910import java .util .*;
1011
@@ -104,7 +105,7 @@ private void inlineCall(ImFunction f, Element parent, int parentI, ImFunctionCal
104105 if (called == f ) {
105106 throw new Error ("cannot inline self." );
106107 }
107- List <ImStmt > stmts = Lists .newArrayList ();
108+ List <ImStmt > prefixStmts = Lists .newArrayList ();
108109 // save arguments to temp vars:
109110 List <ImExpr > args = call .getArguments ().removeAll ();
110111 Map <ImVar , ImVar > varSubtitutions = Maps .newLinkedHashMap ();
@@ -115,7 +116,7 @@ private void inlineCall(ImFunction f, Element parent, int parentI, ImFunctionCal
115116 f .getLocals ().add (tempVar );
116117 varSubtitutions .put (param , tempVar );
117118 // set temp var
118- stmts .add (JassIm .ImSet (arg .attrTrace (), JassIm .ImVarAccess (tempVar ), arg ));
119+ prefixStmts .add (JassIm .ImSet (arg .attrTrace (), JassIm .ImVarAccess (tempVar ), arg ));
119120 }
120121 // add locals
121122 for (ImVar l : called .getLocals ()) {
@@ -124,6 +125,7 @@ private void inlineCall(ImFunction f, Element parent, int parentI, ImFunctionCal
124125 varSubtitutions .put (l , newL );
125126 }
126127 // add body and replace params with tempvars
128+ List <ImStmt > copiedBody = Lists .newArrayList ();
127129 for (int i = 0 ; i < called .getBody ().size (); i ++) {
128130 ImStmt s = called .getBody ().get (i ).copy ();
129131 ImHelper .replaceVar (s , varSubtitutions );
@@ -138,22 +140,48 @@ public void visit(ImFunctionCall called) {
138140 });
139141
140142
141- stmts .add (s );
143+ copiedBody .add (s );
142144 }
143- // handle return
145+
146+ List <ImStmt > stmts = Lists .newArrayList ();
147+ stmts .addAll (prefixStmts );
148+
144149 ImExpr newExpr = null ;
145- if (stmts .size () > 0 ) {
146- ImStmt lastStmt = stmts .get (stmts .size () - 1 );
147- if (lastStmt instanceof ImReturn ) {
148- ImReturn ret = (ImReturn ) lastStmt ;
149- stmts .remove (stmts .size () - 1 );
150- ImExprOpt valOpt = ret .getReturnValue ();
151- if (valOpt instanceof ImExpr ) {
152- ImExpr val = (ImExpr ) valOpt .copy ();
153- ImHelper .replaceVar (val , varSubtitutions );
154- newExpr = ImStatementExpr (ImStmts (stmts ), val );
150+ if (maxOneReturn (called )) {
151+ // Fast path for existing single-return shape.
152+ stmts .addAll (copiedBody );
153+ if (!stmts .isEmpty ()) {
154+ ImStmt lastStmt = stmts .get (stmts .size () - 1 );
155+ if (lastStmt instanceof ImReturn ) {
156+ ImReturn ret = (ImReturn ) lastStmt ;
157+ stmts .remove (stmts .size () - 1 );
158+ ImExprOpt valOpt = ret .getReturnValue ();
159+ if (valOpt instanceof ImExpr ) {
160+ ImExpr val = (ImExpr ) valOpt .copy ();
161+ ImHelper .replaceVar (val , varSubtitutions );
162+ newExpr = ImStatementExpr (ImStmts (stmts ), val );
163+ }
155164 }
156165 }
166+ } else {
167+ // Multi-return path: rewrite returns to done-flag + optional return temp.
168+ ImVar doneVar = JassIm .ImVar (call .attrTrace (), TypesHelper .imBool (), "inlineDone" , false );
169+ f .getLocals ().add (doneVar );
170+ stmts .add (JassIm .ImSet (call .attrTrace (), JassIm .ImVarAccess (doneVar ), JassIm .ImBoolVal (false )));
171+
172+ ImVar retVar = null ;
173+ if (!(called .getReturnType () instanceof ImVoid )) {
174+ retVar = JassIm .ImVar (call .attrTrace (), called .getReturnType ().copy (), "inlineRet" , false );
175+ f .getLocals ().add (retVar );
176+ stmts .add (JassIm .ImSet (call .attrTrace (), JassIm .ImVarAccess (retVar ), ImHelper .defaultValueForComplexType (called .getReturnType ())));
177+ }
178+
179+ ImStmts rewritten = rewriteForEarlyReturns (JassIm .ImStmts (copiedBody ), doneVar , retVar );
180+ stmts .addAll (rewritten .removeAll ());
181+
182+ if (retVar != null ) {
183+ newExpr = ImStatementExpr (ImStmts (stmts ), JassIm .ImVarAccess (retVar ));
184+ }
157185 }
158186 if (newExpr == null ) {
159187 newExpr = ImHelper .statementExprVoid (ImStmts (stmts ));
@@ -162,6 +190,48 @@ public void visit(ImFunctionCall called) {
162190
163191 }
164192
193+ private ImStmts rewriteForEarlyReturns (ImStmts body , ImVar doneVar , ImVar retVar ) {
194+ ImStmts rewritten = JassIm .ImStmts ();
195+ for (ImStmt s : body ) {
196+ ImStmt transformed = rewriteStmtForEarlyReturn (s , doneVar , retVar );
197+ ImExpr notDone = JassIm .ImOperatorCall (de .peeeq .wurstscript .WurstOperator .NOT , JassIm .ImExprs (JassIm .ImVarAccess (doneVar )));
198+ rewritten .add (JassIm .ImIf (s .attrTrace (), notDone , JassIm .ImStmts (transformed ), JassIm .ImStmts ()));
199+ }
200+ return rewritten ;
201+ }
202+
203+ private ImStmt rewriteStmtForEarlyReturn (ImStmt s , ImVar doneVar , ImVar retVar ) {
204+ if (s instanceof ImReturn ) {
205+ ImReturn r = (ImReturn ) s ;
206+ ImStmts b = JassIm .ImStmts ();
207+ if (retVar != null && r .getReturnValue () instanceof ImExpr ) {
208+ ImExpr rv = (ImExpr ) r .getReturnValue ();
209+ rv .setParent (null );
210+ b .add (JassIm .ImSet (r .getTrace (), JassIm .ImVarAccess (retVar ), rv ));
211+ }
212+ b .add (JassIm .ImSet (r .getTrace (), JassIm .ImVarAccess (doneVar ), JassIm .ImBoolVal (true )));
213+ return ImHelper .statementExprVoid (b );
214+ } else if (s instanceof ImIf ) {
215+ ImIf imIf = (ImIf ) s ;
216+ ImStmts thenBlock = rewriteForEarlyReturns (imIf .getThenBlock ().copy (), doneVar , retVar );
217+ ImStmts elseBlock = rewriteForEarlyReturns (imIf .getElseBlock ().copy (), doneVar , retVar );
218+ return JassIm .ImIf (imIf .getTrace (), imIf .getCondition ().copy (), thenBlock , elseBlock );
219+ } else if (s instanceof ImLoop ) {
220+ ImLoop l = (ImLoop ) s ;
221+ ImStmts loopBody = JassIm .ImStmts ();
222+ loopBody .add (JassIm .ImExitwhen (l .getTrace (), JassIm .ImVarAccess (doneVar )));
223+ loopBody .addAll (rewriteForEarlyReturns (l .getBody ().copy (), doneVar , retVar ).removeAll ());
224+ return JassIm .ImLoop (l .getTrace (), loopBody );
225+ } else if (s instanceof ImVarargLoop ) {
226+ ImVarargLoop l = (ImVarargLoop ) s ;
227+ ImStmts loopBody = JassIm .ImStmts ();
228+ loopBody .add (JassIm .ImExitwhen (l .getTrace (), JassIm .ImVarAccess (doneVar )));
229+ loopBody .addAll (rewriteForEarlyReturns (l .getBody ().copy (), doneVar , retVar ).removeAll ());
230+ return JassIm .ImVarargLoop (l .getTrace (), loopBody , l .getLoopVar ());
231+ }
232+ return s ;
233+ }
234+
165235 private void rateInlinableFunctions () {
166236 for (Map .Entry <ImFunction , ImFunction > f : translator .getCalledFunctions ().entries ()) {
167237 incCallCount (f .getKey ());
@@ -288,9 +358,7 @@ private void collectInlinableFunctions() {
288358 // this is only relevant for lua, because in JASS they are eliminated before inlining
289359 continue ;
290360 }
291- if (maxOneReturn (f )) {
292- inlinableFunctions .add (f );
293- }
361+ inlinableFunctions .add (f );
294362 }
295363 }
296364
0 commit comments