Skip to content

Commit f7266a4

Browse files
committed
support inlining functions with multiple returns
1 parent 75dd08c commit f7266a4

2 files changed

Lines changed: 107 additions & 17 deletions

File tree

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/ImInliner.java

Lines changed: 85 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import com.google.common.collect.Sets;
66
import de.peeeq.wurstscript.jassIm.*;
77
import de.peeeq.wurstscript.translation.imtranslation.*;
8+
import de.peeeq.wurstscript.types.TypesHelper;
89

910
import java.util.*;
1011

@@ -104,7 +105,7 @@ private void inlineCall(ImFunction f, Element parent, int parentI, ImFunctionCal
104105
if (called == f) {
105106
throw new Error("cannot inline self.");
106107
}
107-
List<ImStmt> stmts = Lists.newArrayList();
108+
List<ImStmt> prefixStmts = Lists.newArrayList();
108109
// save arguments to temp vars:
109110
List<ImExpr> args = call.getArguments().removeAll();
110111
Map<ImVar, ImVar> varSubtitutions = Maps.newLinkedHashMap();
@@ -115,7 +116,7 @@ private void inlineCall(ImFunction f, Element parent, int parentI, ImFunctionCal
115116
f.getLocals().add(tempVar);
116117
varSubtitutions.put(param, tempVar);
117118
// set temp var
118-
stmts.add(JassIm.ImSet(arg.attrTrace(), JassIm.ImVarAccess(tempVar), arg));
119+
prefixStmts.add(JassIm.ImSet(arg.attrTrace(), JassIm.ImVarAccess(tempVar), arg));
119120
}
120121
// add locals
121122
for (ImVar l : called.getLocals()) {
@@ -124,6 +125,7 @@ private void inlineCall(ImFunction f, Element parent, int parentI, ImFunctionCal
124125
varSubtitutions.put(l, newL);
125126
}
126127
// add body and replace params with tempvars
128+
List<ImStmt> copiedBody = Lists.newArrayList();
127129
for (int i = 0; i < called.getBody().size(); i++) {
128130
ImStmt s = called.getBody().get(i).copy();
129131
ImHelper.replaceVar(s, varSubtitutions);
@@ -138,22 +140,48 @@ public void visit(ImFunctionCall called) {
138140
});
139141

140142

141-
stmts.add(s);
143+
copiedBody.add(s);
142144
}
143-
// handle return
145+
146+
List<ImStmt> stmts = Lists.newArrayList();
147+
stmts.addAll(prefixStmts);
148+
144149
ImExpr newExpr = null;
145-
if (stmts.size() > 0) {
146-
ImStmt lastStmt = stmts.get(stmts.size() - 1);
147-
if (lastStmt instanceof ImReturn) {
148-
ImReturn ret = (ImReturn) lastStmt;
149-
stmts.remove(stmts.size() - 1);
150-
ImExprOpt valOpt = ret.getReturnValue();
151-
if (valOpt instanceof ImExpr) {
152-
ImExpr val = (ImExpr) valOpt.copy();
153-
ImHelper.replaceVar(val, varSubtitutions);
154-
newExpr = ImStatementExpr(ImStmts(stmts), val);
150+
if (maxOneReturn(called)) {
151+
// Fast path for existing single-return shape.
152+
stmts.addAll(copiedBody);
153+
if (!stmts.isEmpty()) {
154+
ImStmt lastStmt = stmts.get(stmts.size() - 1);
155+
if (lastStmt instanceof ImReturn) {
156+
ImReturn ret = (ImReturn) lastStmt;
157+
stmts.remove(stmts.size() - 1);
158+
ImExprOpt valOpt = ret.getReturnValue();
159+
if (valOpt instanceof ImExpr) {
160+
ImExpr val = (ImExpr) valOpt.copy();
161+
ImHelper.replaceVar(val, varSubtitutions);
162+
newExpr = ImStatementExpr(ImStmts(stmts), val);
163+
}
155164
}
156165
}
166+
} else {
167+
// Multi-return path: rewrite returns to done-flag + optional return temp.
168+
ImVar doneVar = JassIm.ImVar(call.attrTrace(), TypesHelper.imBool(), "inlineDone", false);
169+
f.getLocals().add(doneVar);
170+
stmts.add(JassIm.ImSet(call.attrTrace(), JassIm.ImVarAccess(doneVar), JassIm.ImBoolVal(false)));
171+
172+
ImVar retVar = null;
173+
if (!(called.getReturnType() instanceof ImVoid)) {
174+
retVar = JassIm.ImVar(call.attrTrace(), called.getReturnType().copy(), "inlineRet", false);
175+
f.getLocals().add(retVar);
176+
stmts.add(JassIm.ImSet(call.attrTrace(), JassIm.ImVarAccess(retVar), ImHelper.defaultValueForComplexType(called.getReturnType())));
177+
}
178+
179+
ImStmts rewritten = rewriteForEarlyReturns(JassIm.ImStmts(copiedBody), doneVar, retVar);
180+
stmts.addAll(rewritten.removeAll());
181+
182+
if (retVar != null) {
183+
newExpr = ImStatementExpr(ImStmts(stmts), JassIm.ImVarAccess(retVar));
184+
}
157185
}
158186
if (newExpr == null) {
159187
newExpr = ImHelper.statementExprVoid(ImStmts(stmts));
@@ -162,6 +190,48 @@ public void visit(ImFunctionCall called) {
162190

163191
}
164192

193+
private ImStmts rewriteForEarlyReturns(ImStmts body, ImVar doneVar, ImVar retVar) {
194+
ImStmts rewritten = JassIm.ImStmts();
195+
for (ImStmt s : body) {
196+
ImStmt transformed = rewriteStmtForEarlyReturn(s, doneVar, retVar);
197+
ImExpr notDone = JassIm.ImOperatorCall(de.peeeq.wurstscript.WurstOperator.NOT, JassIm.ImExprs(JassIm.ImVarAccess(doneVar)));
198+
rewritten.add(JassIm.ImIf(s.attrTrace(), notDone, JassIm.ImStmts(transformed), JassIm.ImStmts()));
199+
}
200+
return rewritten;
201+
}
202+
203+
private ImStmt rewriteStmtForEarlyReturn(ImStmt s, ImVar doneVar, ImVar retVar) {
204+
if (s instanceof ImReturn) {
205+
ImReturn r = (ImReturn) s;
206+
ImStmts b = JassIm.ImStmts();
207+
if (retVar != null && r.getReturnValue() instanceof ImExpr) {
208+
ImExpr rv = (ImExpr) r.getReturnValue();
209+
rv.setParent(null);
210+
b.add(JassIm.ImSet(r.getTrace(), JassIm.ImVarAccess(retVar), rv));
211+
}
212+
b.add(JassIm.ImSet(r.getTrace(), JassIm.ImVarAccess(doneVar), JassIm.ImBoolVal(true)));
213+
return ImHelper.statementExprVoid(b);
214+
} else if (s instanceof ImIf) {
215+
ImIf imIf = (ImIf) s;
216+
ImStmts thenBlock = rewriteForEarlyReturns(imIf.getThenBlock().copy(), doneVar, retVar);
217+
ImStmts elseBlock = rewriteForEarlyReturns(imIf.getElseBlock().copy(), doneVar, retVar);
218+
return JassIm.ImIf(imIf.getTrace(), imIf.getCondition().copy(), thenBlock, elseBlock);
219+
} else if (s instanceof ImLoop) {
220+
ImLoop l = (ImLoop) s;
221+
ImStmts loopBody = JassIm.ImStmts();
222+
loopBody.add(JassIm.ImExitwhen(l.getTrace(), JassIm.ImVarAccess(doneVar)));
223+
loopBody.addAll(rewriteForEarlyReturns(l.getBody().copy(), doneVar, retVar).removeAll());
224+
return JassIm.ImLoop(l.getTrace(), loopBody);
225+
} else if (s instanceof ImVarargLoop) {
226+
ImVarargLoop l = (ImVarargLoop) s;
227+
ImStmts loopBody = JassIm.ImStmts();
228+
loopBody.add(JassIm.ImExitwhen(l.getTrace(), JassIm.ImVarAccess(doneVar)));
229+
loopBody.addAll(rewriteForEarlyReturns(l.getBody().copy(), doneVar, retVar).removeAll());
230+
return JassIm.ImVarargLoop(l.getTrace(), loopBody, l.getLoopVar());
231+
}
232+
return s;
233+
}
234+
165235
private void rateInlinableFunctions() {
166236
for (Map.Entry<ImFunction, ImFunction> f : translator.getCalledFunctions().entries()) {
167237
incCallCount(f.getKey());
@@ -288,9 +358,7 @@ private void collectInlinableFunctions() {
288358
// this is only relevant for lua, because in JASS they are eliminated before inlining
289359
continue;
290360
}
291-
if (maxOneReturn(f)) {
292-
inlinableFunctions.add(f);
293-
}
361+
inlinableFunctions.add(f);
294362
}
295363
}
296364

de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/OptimizerTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,28 @@ public void testInlineAnnotation() throws IOException {
903903
assertTrue(inlined.contains("function noot"));
904904
}
905905

906+
@Test
907+
public void inlinerSupportsMultiReturn() throws IOException {
908+
testAssertOkLines(true,
909+
"package test",
910+
"native testSuccess()",
911+
"function absLike(int x) returns int",
912+
" if x >= 0",
913+
" return x",
914+
" return 0 - x",
915+
"init",
916+
" let a = absLike(-4)",
917+
" let b = absLike(3)",
918+
" if a == 4 and b == 3",
919+
" testSuccess()",
920+
"endpackage"
921+
);
922+
923+
String inlined = Files.toString(new File("test-output/OptimizerTests_inlinerSupportsMultiReturn_inl.j"), Charsets.UTF_8);
924+
assertFalse(inlined.contains("call absLike"),
925+
"Expected multi-return function calls to be inlined in _inl output.");
926+
}
927+
906928

907929
@Test
908930
public void moveTowardsBug() { // see #737

0 commit comments

Comments
 (0)