Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2186,7 +2186,58 @@ def _translate_torch_args(dim, keepdim) -> Tuple[Var]:
reduction_op_type = "reduce_sum"
else:
assert node_kind == "logsumexp"
reduction_op_type = "reduce_log_sum_exp"
# Use a numerically stable decomposition to prevent fp16 overflow
# on ANE. The native reduce_log_sum_exp MIL op computes
# log(Σ exp(x_i)); in fp16 the accumulated sum Σ exp(x_i) exceeds the
# largest finite value (65504) once all C reduced elements are above
# log(65504/C) (e.g., ~7.63 for C=32 channels) — issue #2690.
#
# Stable form: logsumexp(x) = max(x) + log(Σ exp(x - max(x)))
# By subtracting max first, all exp() arguments are <= 0, so
# exp() values are in (0, 1] — no overflow in any precision.
# This matches the value_inference in coremltools' own
# reduce_log_sum_exp MIL op definition.
if (
builtin_to_string(x.dtype)
not in SSAOpRegistry._get_core_op_cls("reduce_log_sum_exp").supported_dtypes()
):
x = mb.cast(x=x, dtype="fp32")

reduce_kwargs = {"x": x}
if axes is not None:
reduce_kwargs["axes"] = axes
reduce_kwargs["keep_dims"] = True

# Step 1: max(x) along reduction axes (keep dims for broadcasting)
x_max = mb.reduce_max(**reduce_kwargs)

# Step 2: x - max(x) (broadcast subtraction)
x_shifted = mb.sub(x=x, y=x_max)

# Step 3: exp(x - max(x)) — all values in (0, 1], safe in fp16
x_exp = mb.exp(x=x_shifted)

# Step 4: sum(exp(x - max(x))) along axes
sum_kwargs = {"x": x_exp}
if axes is not None:
sum_kwargs["axes"] = axes
if keep_dims is not None:
sum_kwargs["keep_dims"] = keep_dims
x_sum = mb.reduce_sum(**sum_kwargs)

# Step 5: log(sum) + max
x_log = mb.log(x=x_sum)

# When keep_dims is falsy, recompute max without keep_dims (instead of
# squeezing the kept-dims result) so its shape matches the reduced x_log
if not keep_dims:
max_kwargs = {"x": x}
if axes is not None:
max_kwargs["axes"] = axes
x_max = mb.reduce_max(**max_kwargs)

res = mb.add(x=x_log, y=x_max, name=node.name)
context.add(res)
return

if (
builtin_to_string(x.dtype)
not in SSAOpRegistry._get_core_op_cls(reduction_op_type).supported_dtypes()
Expand Down
28 changes: 28 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12227,6 +12227,34 @@ def test_logsumexp(self, compute_unit, backend, frontend, shape, dim):
compute_unit=compute_unit,
)

@pytest.mark.parametrize(
    "compute_unit, backend, frontend, keepdim",
    itertools.product(compute_units, backends, frontends, (True, False)),
)
def test_logsumexp_fp16_overflow(self, compute_unit, backend, frontend, keepdim):
    """Check logsumexp on inputs above the fp16 sum-overflow threshold.

    This is a regression test for issue #2690: channel-reduce logsumexp
    exhibits a hard output collapse on ANE in fp16 at x ≈ 7.63 because
    the sum of exp() terms overflows. The fix decomposes logsumexp into
    the numerically stable form: max(x) + log(Σ exp(x - max(x))).

    ``keepdim`` is parametrized because the reduced output has a
    different shape for each setting, and both must match torch.
    """
    # Every one of the C=32 reduced channels is 8.0, which sits just above
    # the sum-overflow point log(65504 / C) ≈ 7.63 for fp16.
    x = torch.full((1, 32, 4, 4), 8.0)
    model = ModuleWrapper(
        function=torch.logsumexp,
        kwargs={"dim": 1, "keepdim": keepdim},
    )
    TorchBaseTest.run_compare_torch(
        x,
        model,
        input_as_shape=False,
        frontend=frontend,
        backend=backend,
        compute_unit=compute_unit,
    )

@pytest.mark.parametrize(
"compute_unit, backend, frontend, keepdim",
itertools.product(compute_units, backends, frontends, (True, False)),
Expand Down