Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 20 additions & 33 deletions onnxscript/_internal/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,11 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.")
assert not iter.keywords, "Unsupported loop bound."
o_loop_bound = self._translate_expr(iter.args[0], "loop_bound")
onnx_cond_var = ir.Value(name=self.generate_unique_name("cond_in")) # TODO(Rama)
onnx_cond_var = make_value(
self.generate_unique_name("cond_in"),
onnx_types.BOOL,
self._source_of(loop_stmt),
)
i_cond_var = onnx_cond_var
cond_while = None
o_loop_condition = None # No condition for a for loop.
Expand All @@ -1220,8 +1224,12 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
)
python_loop_var_name = "infinite_loop"
o_loop_bound = None
i_cond_var = ir.Value(name=test.id) # TODO(Rama)
cond_while = ir.Value(name=test.id) # TODO(Rama)
i_cond_var = make_value(
self.generate_unique_name(test.id),
onnx_types.BOOL,
self._source_of(loop_stmt),
)
cond_while = test.id
onnx_cond_var = None
o_loop_condition = self._translate_name_expr(test)
# we need to go through all the instructions to see
Expand Down Expand Up @@ -1254,20 +1262,11 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
values.SymbolValue(onnx_loop_var, self._source_of(loop_stmt)),
)

self._current_fn.append_parameter(
make_value(
i_cond_var.name,
onnx_types.BOOL,
self._source_of(loop_stmt),
)
)
self._current_fn.append_parameter(i_cond_var)

for pv in loop_state_vars:
onnx_var_name = self.generate_unique_name(pv)
# TODO: retrieve the annotation for variable pv is any is specified.
# typeinfo = self._eval_constant_expr(pv.annotation)
typeinfo = None
parameter = make_value(onnx_var_name, typeinfo, self._source_of(loop_stmt))
parameter = make_value(onnx_var_name, None, self._source_of(loop_stmt))
self._current_fn.append_parameter(parameter)
self._bind(
pv,
Expand Down Expand Up @@ -1306,33 +1305,25 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
continue
self._translate_stmt(s)

onnx_cond_out_name = self.generate_unique_name("cond_out")

if cond_while is not None:
# Loop while
current_scope = self._current_scope()
if cond_while.name not in current_scope:
if cond_while not in current_scope:
self.fail(
loop_stmt,
f"Unable to find condition variable {cond_while.name} in known "
f"Unable to find condition variable {cond_while} in known "
f"variables {list(current_scope)!r}.",
)
onnx_cond_var = current_scope[cond_while.name].value
onnx_cond_var = current_scope[cond_while].value

self.emit(
[onnx_cond_out_name],
cond_out = self.emit1(
[self.generate_unique_name("cond_out")],
values.Op(self.default_opset, operator_name),
[condition_name or onnx_cond_var],
[],
)
self._current_fn.outputs.append(cond_out)

self._current_fn.outputs.append(
make_value(
onnx_cond_out_name,
onnx_types.BOOL,
self._source_of(loop_stmt),
)
)
for pv in loop_state_vars:
onnx_var = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt))
if onnx_var.name not in self._current_fn.assigned_names:
Expand All @@ -1342,11 +1333,7 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None:
# In this case, we create a copy of y, treating the statement as
# shorthand for "x = op.Identity(y)".
onnx_var = self._emit_copy(onnx_var, pv)
# TODO: retrieve variable type for the annotation if any.
typeinfo = None
self._current_fn.outputs.append(
make_value(onnx_var.name, typeinfo, self._source_of(loop_stmt))
)
self._current_fn.outputs.append(onnx_var)
body = self._exit_scope()
inputs = [o_loop_bound, o_loop_condition] + [
self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars
Expand Down
Loading