Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 83 additions & 5 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2228,9 +2228,48 @@ def _parse_keyword_args(context, node, dtype) -> Var:
res = mb.cumsum(x=x, axis=dim, name=node.name)
else:
assert node.kind in ("logcumsumexp", "_logcumsumexp")
exp = mb.exp(x=x)
cumsumexp = mb.cumsum(x=exp, axis=dim)
res = mb.log(x=cumsumexp, name=node.name)
# Use numerically stable decomposition to prevent fp16 overflow
# on ANE. The naive log(cumsum(exp(x))) computes exp(x) on raw
# input, which overflows in fp16 for x > ~11.09 (since
# exp(11.09) ~ 65,504 = fp16 max).
#
# Stable form:
# logcumsumexp(x) = max(x) + log(cumsum(exp(x - max(x))))
#
# We use the global max over the entire axis rather than a
# running (cumulative) max because MIL does not provide a
# cummax op. The global max is always >= the running max at
# every position, so exp(x_i - global_max) <= exp(x_i -
# running_max_i) <= 1 for all i. This guarantees no overflow.
# The trade-off is slightly more underflow for early positions
# when a much larger value appears later, but this does not
# affect correctness -- those contributions are genuinely
# negligible. This is the same max-shift pattern used in the
# logsumexp stable decomposition.

# Step 1: global max along the cumsum axis (keep dims for
# broadcasting)
x_max = mb.reduce_max(x=x, axes=[dim], keep_dims=True,
name=node.name + "_max")

# Step 2: x - max(x) (broadcast subtraction)
x_shifted = mb.sub(x=x, y=x_max,
name=node.name + "_shifted")

# Step 3: exp(x - max(x)) -- all values in (0, 1], safe in fp16
x_exp = mb.exp(x=x_shifted,
name=node.name + "_exp")

# Step 4: cumulative sum of shifted exponentials
x_cumsum = mb.cumsum(x=x_exp, axis=dim,
name=node.name + "_cumsum")

# Step 5: log(cumsum(...))
x_log = mb.log(x=x_cumsum,
name=node.name + "_log")

# Step 6: add back the global max
res = mb.add(x=x_log, y=x_max, name=node.name)

context.add(res)

Expand Down Expand Up @@ -5915,8 +5954,47 @@ def _parse_positional_args(context, node) -> Tuple[Var]:

x, axis = _parse_positional_args(context, node)

res = mb.softmax(x=x, axis=axis, name=node.name + "_softmax")
res = mb.log(x=res, name=node.name)
# Use numerically stable decomposition to prevent fp16 underflow
# on ANE. The naive log(softmax(x)) first computes softmax
# probabilities, which underflow to 0 in fp16 for non-dominant
# classes (any probability below ~6e-5). Then log(0) produces
# -inf, silently corrupting the output.
#
# Stable form:
# log_softmax(x) = x - max(x) - log(sum(exp(x - max(x))))
#
# By subtracting max(x) first, all exp() arguments are <= 0, so
# exp() values are in (0, 1] -- no overflow. The log of the sum
# is computed directly, avoiding the underflow-prone intermediate
# softmax probabilities.
#
# This matches PyTorch's own fused log_softmax implementation and
# the approach used in coremltools' TensorFlow frontend for
# cross-entropy loss (_softmax_cross_entropy_with_logits).

# Step 1: max(x) along softmax axis (keep dims for broadcasting)
x_max = mb.reduce_max(x=x, axes=[axis], keep_dims=True,
name=node.name + "_max")

# Step 2: x - max(x) (broadcast subtraction)
x_shifted = mb.sub(x=x, y=x_max,
name=node.name + "_shifted")

# Step 3: exp(x - max(x)) -- all values in (0, 1], safe in fp16
x_exp = mb.exp(x=x_shifted,
name=node.name + "_exp")

# Step 4: sum(exp(x - max(x))) along softmax axis
x_sum = mb.reduce_sum(x=x_exp, axes=[axis], keep_dims=True,
name=node.name + "_sum")

# Step 5: log(sum(...))
x_log_sum = mb.log(x=x_sum,
name=node.name + "_log_sum")

# Step 6: (x - max(x)) - log(sum(exp(x - max(x))))
res = mb.sub(x=x_shifted, y=x_log_sum, name=node.name)

context.add(res)


Expand Down
71 changes: 71 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6534,6 +6534,47 @@ def test_softmax(self, compute_unit, backend, frontend, shape):
shape, model, frontend=frontend, backend=backend, compute_unit=compute_unit
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend, shape",
itertools.product(compute_units, backends, frontends, COMMON_SHAPES_ALL),
)
def test_log_softmax(self, compute_unit, backend, frontend, shape):
model = ModuleWrapper(function=torch.nn.functional.log_softmax, kwargs={"dim": -1})
self.run_compare_torch(
shape, model, frontend=frontend, backend=backend, compute_unit=compute_unit
)

@pytest.mark.parametrize(
"compute_unit, backend",
itertools.product(compute_units, backends),
)
def test_log_softmax_fp16_no_neg_inf(self, compute_unit, backend):
"""Regression test: log_softmax must not produce -inf for
non-dominant classes in fp16. The naive log(softmax(x))
underflows when softmax probabilities fall below fp16 minimum
(~6e-5), causing log(0) = -inf on Apple Neural Engine."""

class LogSoftmaxModel(nn.Module):
def forward(self, x):
return torch.nn.functional.log_softmax(x, dim=-1)

model = LogSoftmaxModel().eval()

# Input with one dominant class and moderate spread — this is
# the exact scenario that breaks with naive log(softmax(x))
# in fp16 because the non-dominant softmax outputs underflow
# to 0 before log is applied.
x = torch.tensor([[0.0, 0.0, 0.0, 50.0, 0.0, 0.0, 0.0, 0.0]])

self.run_compare_torch(
x.shape,
model,
backend=backend,
compute_unit=compute_unit,
input_as_shape=False,
torch_input_values=[x],
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend, range_val",
itertools.product(
Expand Down Expand Up @@ -12459,6 +12500,36 @@ def test_logcumsumexp(self, compute_unit, backend, frontend, axis):
input_shape, model, frontend=frontend, backend=backend, compute_unit=compute_unit
)

@pytest.mark.parametrize(
"compute_unit, backend",
itertools.product(compute_units, backends),
)
def test_logcumsumexp_fp16_large_input(self, compute_unit, backend):
"""Regression test: logcumsumexp must not overflow to inf for
inputs above ~11.09 in fp16. The naive log(cumsum(exp(x)))
computes exp(x) on raw input, which overflows at the fp16
maximum (65,504) on Apple Neural Engine."""

class LogCumSumExpModel(nn.Module):
def forward(self, x):
return torch.logcumsumexp(x, dim=-1)

model = LogCumSumExpModel().eval()

# Input with values in the range that triggers fp16 overflow
# in the naive exp(x) computation. exp(12) ≈ 162,755 which
# exceeds fp16 max of 65,504.
x = torch.tensor([[1.0, 5.0, 10.0, 12.0, 15.0, 20.0, 50.0]])

self.run_compare_torch(
x.shape,
model,
backend=backend,
compute_unit=compute_unit,
input_as_shape=False,
torch_input_values=[x],
)


class TestHannWindow(TorchBaseTest):
@pytest.mark.parametrize(
Expand Down