Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 59 additions & 35 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7668,57 +7668,69 @@ def _parse_positional_args(context, node) -> Tuple[Var]:
x = inputs[0]

if context.frontend == TorchFrontend.TORCHSCRIPT:
# torch.jit.trace does not distinguish by name `std` and `std.dim`,
# instead by nargs 2 or 4
# torch.jit.trace does not distinguish by name `var` and `var.dim`,
# instead by nargs 2 or 4. It also collapses the `unbiased` and
# `correction` overloads onto the same signature, so the second
# (nargs == 2) or third (nargs == 4) positional argument may carry
# either a bool `unbiased` or a Scalar `correction`. We treat it as
# `correction` since that subsumes `unbiased` (unbiased=False is
# correction=0, unbiased=True is correction=1).
keepdim = False
dim = None
if len(inputs) == 2:
unbiased = inputs[1]
correction = inputs[1]
if len(inputs) == 4:
dim = inputs[1]
unbiased = inputs[2]
correction = inputs[2]
keepdim = inputs[3]
else:
if node.kind == "var":
unbiased = inputs[1] if nargs > 1 else True
correction = inputs[1] if nargs > 1 else True
dim = None
keepdim = False
else:
assert node.kind == "var.dim"
assert nargs > 1
dim = inputs[1]
unbiased = inputs[2] if nargs > 2 else True
correction = inputs[2] if nargs > 2 else True
keepdim = inputs[3] if nargs > 3 else False

return x, dim, unbiased, keepdim
return x, dim, correction, keepdim

def _parse_keyword_args(context, node, dim, unbiased, keepdim) -> Tuple[Var]:
def _parse_keyword_args(context, node, dim, correction, keepdim) -> Tuple[Var]:
dim = _get_kwinputs(context, node, "dim", default=[dim])[0]
unbiased = _get_kwinputs(context, node, "unbiased", default=[unbiased])[0]
unbiased = _get_kwinputs(context, node, "unbiased", default=[None])[0]
if unbiased is not None:
correction = unbiased
correction = _get_kwinputs(context, node, "correction", default=[correction])[0]
keepdim = _get_kwinputs(context, node, "keepdim", default=[keepdim])[0]
return dim, unbiased, keepdim
return dim, correction, keepdim

def _translate_torch_args(dim, unbiased, keepdim) -> Tuple[Var]:
def _translate_torch_args(dim, correction, keepdim) -> Tuple[Var]:
if isinstance(dim, Var):
dim = dim.val
try:
dim = (int(dim),)
except:
pass

if isinstance(unbiased, Var):
unbiased = unbiased.val
if isinstance(correction, Var):
correction = correction.val
# `unbiased` is the bool special case of `correction`: a true/false flag
# maps to a divisor of N-1 / N, i.e. correction 1 / 0. A missing value
# defaults to torch's unbiased estimator (correction 1).
correction = 1 if correction is None else int(correction)

if isinstance(keepdim, Var):
keepdim = keepdim.val

return dim, unbiased, keepdim
return dim, correction, keepdim

x, dim, unbiased, keepdim = _parse_positional_args(context, node)
dim, unbiased, keepdim = _parse_keyword_args(context, node, dim, unbiased, keepdim)
axes, unbiased, keep_dims = _translate_torch_args(dim, unbiased, keepdim)
x, dim, correction, keepdim = _parse_positional_args(context, node)
dim, correction, keepdim = _parse_keyword_args(context, node, dim, correction, keepdim)
axes, correction, keep_dims = _translate_torch_args(dim, correction, keepdim)

y = _var(x, axes=axes, keep_dims=keep_dims, unbiased=unbiased)
y = _var(x, axes=axes, keep_dims=keep_dims, correction=correction)
context.add(y, node.name)


Expand Down Expand Up @@ -7783,56 +7795,68 @@ def _parse_positional_args(context, node) -> Tuple[Var]:

if context.frontend == TorchFrontend.TORCHSCRIPT:
# torch.jit.trace does not distinguish by name `std` and `std.dim`,
# instead by nargs 2 or 4
# instead by nargs 2 or 4. It also collapses the `unbiased` and
# `correction` overloads onto the same signature, so the second
# (nargs == 2) or third (nargs == 4) positional argument may carry
# either a bool `unbiased` or a Scalar `correction`. We treat it as
# `correction` since that subsumes `unbiased` (unbiased=False is
# correction=0, unbiased=True is correction=1).
keepdim = False
dim = None
if len(inputs) == 2:
unbiased = inputs[1]
correction = inputs[1]
if len(inputs) == 4:
dim = inputs[1]
unbiased = inputs[2]
correction = inputs[2]
keepdim = inputs[3]
else:
if node.kind == "std":
unbiased = inputs[1] if nargs > 1 else True
correction = inputs[1] if nargs > 1 else True
dim = None
keepdim = False
else:
assert node.kind == "std.dim"
assert nargs > 1
dim = inputs[1]
unbiased = inputs[2] if nargs > 2 else True
correction = inputs[2] if nargs > 2 else True
keepdim = inputs[3] if nargs > 3 else False

return x, dim, unbiased, keepdim
return x, dim, correction, keepdim

def _parse_keyword_args(context, node, dim, unbiased, keepdim) -> Tuple[Var]:
def _parse_keyword_args(context, node, dim, correction, keepdim) -> Tuple[Var]:
dim = _get_kwinputs(context, node, "dim", default=[dim])[0]
unbiased = _get_kwinputs(context, node, "unbiased", default=[unbiased])[0]
unbiased = _get_kwinputs(context, node, "unbiased", default=[None])[0]
if unbiased is not None:
correction = unbiased
correction = _get_kwinputs(context, node, "correction", default=[correction])[0]
keepdim = _get_kwinputs(context, node, "keepdim", default=[keepdim])[0]
return dim, unbiased, keepdim
return dim, correction, keepdim

def _translate_torch_args(dim, unbiased, keepdim) -> Tuple[Var]:
def _translate_torch_args(dim, correction, keepdim) -> Tuple[Var]:
if isinstance(dim, Var):
dim = dim.val
try:
dim = (int(dim),)
except:
pass

if isinstance(unbiased, Var):
unbiased = unbiased.val
if isinstance(correction, Var):
correction = correction.val
# `unbiased` is the bool special case of `correction`: a true/false flag
# maps to a divisor of N-1 / N, i.e. correction 1 / 0. A missing value
# defaults to torch's unbiased estimator (correction 1).
correction = 1 if correction is None else int(correction)

if isinstance(keepdim, Var):
keepdim = keepdim.val

return dim, unbiased, keepdim
return dim, correction, keepdim

x, dim, unbiased, keepdim = _parse_positional_args(context, node)
dim, unbiased, keepdim = _parse_keyword_args(context, node, dim, unbiased, keepdim)
axes, unbiased, keep_dims = _translate_torch_args(dim, unbiased, keepdim)
x, dim, correction, keepdim = _parse_positional_args(context, node)
dim, correction, keepdim = _parse_keyword_args(context, node, dim, correction, keepdim)
axes, correction, keep_dims = _translate_torch_args(dim, correction, keepdim)

variance = _var(x, axes=axes, keep_dims=keep_dims, unbiased=unbiased)
variance = _var(x, axes=axes, keep_dims=keep_dims, correction=correction)
standard_deviation = mb.sqrt(x=variance)
context.add(standard_deviation, node.name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7956,7 +7956,7 @@ def test_var_std_4_inputs(
backends,
frontends,
["var", "std"],
[0, 1],
[0, 1, 2],
[[0, 2], [1], [2]],
[True, False],
),
Expand Down