Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1794,21 +1794,37 @@ def _parse_positional_args(context, node) -> Tuple[Var]:

x, beta, threshold = _parse_positional_args(context, node)

# Use numerically stable decomposition to prevent fp16 overflow on ANE.
# The native softplus MIL op computes log(1 + exp(x)), where exp(x)
# overflows in fp16 for x > ~10.4 on Apple Neural Engine (issue #2687).
#
# Stable form: softplus(x) = max(x, 0) + log(1 + exp(-|x|))
# Since -|x| <= 0, exp(-|x|) is always in (0, 1], so no overflow occurs.
# This matches the value_inference in coremltools' own softplus MIL op.
if beta == 1:
# this is the special case that Core ML softplus handles
res = mb.softplus(x=x, name=node.name)
abs_x = mb.abs(x=x)
neg_abs_x = mb.mul(x=-1.0, y=abs_x)
exp_val = mb.exp(x=neg_abs_x)
log_val = mb.log(x=mb.add(x=1.0, y=exp_val))
max_val = mb.maximum(x=x, y=0.0)
sp = mb.add(x=max_val, y=log_val)
else:
if x.rank == 4:
# can use Core ML softplus_parametric
C = x.shape[1]
alpha_br = np.repeat(1.0 / beta, C).astype("float32")
beta_br = np.repeat(beta, C).astype("float32")
res = mb.softplus_parametric(x=x, alpha=alpha_br, beta=beta_br, name=node.name)
else:
# have to generally decompose
beta_mul_x = mb.mul(x=beta, y=x)
softplus = mb.softplus(x=beta_mul_x)
res = mb.real_div(x=softplus, y=beta, name=node.name)
# For non-unit beta: softplus(x) = (1/beta) * softplus(beta * x)
# Apply the same stable decomposition to (beta * x).
beta_mul_x = mb.mul(x=beta, y=x)
abs_bx = mb.abs(x=beta_mul_x)
neg_abs_bx = mb.mul(x=-1.0, y=abs_bx)
exp_val = mb.exp(x=neg_abs_bx)
log_val = mb.log(x=mb.add(x=1.0, y=exp_val))
max_val = mb.maximum(x=beta_mul_x, y=0.0)
stable_sp = mb.add(x=max_val, y=log_val)
sp = mb.real_div(x=stable_sp, y=beta)

# Apply PyTorch's threshold: for beta * x > threshold, softplus(x) ≈ x,
# so return x directly. This matches PyTorch's exact behavior.
beta_x = mb.mul(x=beta, y=x)
cond = mb.greater(x=beta_x, y=threshold)
res = mb.select(cond=cond, a=x, b=sp, name=node.name)
context.add(res)


Expand All @@ -1817,8 +1833,17 @@ def mish(context, node):
inputs = _get_inputs(context, node, expected=1)
x = inputs[0]

softplus = mb.softplus(x=x)
tanh = mb.tanh(x=softplus)
# Mish(x) = x * tanh(softplus(x))
# Use numerically stable softplus to prevent fp16 overflow on ANE.
# See softplus converter above and issue #2687.
abs_x = mb.abs(x=x)
neg_abs_x = mb.mul(x=-1.0, y=abs_x)
exp_val = mb.exp(x=neg_abs_x)
log_val = mb.log(x=mb.add(x=1.0, y=exp_val))
max_val = mb.maximum(x=x, y=0.0)
stable_softplus = mb.add(x=max_val, y=log_val)

tanh = mb.tanh(x=stable_softplus)
res = mb.mul(x=x, y=tanh, name=node.name)
context.add(res)

Expand Down
47 changes: 36 additions & 11 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6721,17 +6721,9 @@ def test_softplus(
# executorch decomposes softplus to very basic log and exp
target_op = "exp"
else:
if beta is None or beta == 1:
# this is the special case that Core ML softplus handles
target_op = "softplus"
else:
if rank == 4:
# can use Core ML softplus_parametric
target_op = "softplus_parametric"
else:
# have to generally decompose to
# `x -> beta * x -> softplus(beta * x) -> softplus(beta * x) / beta`
target_op = "softplus"
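# The converter now lowers softplus to a numerically stable decomposition
# (abs / exp / log / maximum / add) followed by a select for the threshold,
# so no single op identifies it; skip the target_op check for
# non-executorch frontends.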
target_op = None

self.run_compare_torch(
input_shape,
Expand All @@ -6743,6 +6735,39 @@ def test_softplus(
target_op=target_op,
)

@pytest.mark.parametrize(
    "compute_unit, backend, frontend",
    itertools.product(compute_units, backends, frontends),
)
@pytest.mark.skipif(
    _macos_version() <= (10, 15),
    reason="Parametric SoftPlus segfaults on macOS 10.15 and below.",
)
def test_softplus_fp16_threshold(self, compute_unit, backend, frontend):
    """Check that a default ``nn.Softplus`` converts correctly for large inputs.

    Regression test for issue #2687: softplus exhibits a hard output
    collapse on ANE in fp16 at x ≈ 10.4 due to exp() overflow. The
    converted model is compared against PyTorch on a fixed input whose
    values lie below, near, and well above that point.
    """
    # Explicit values (not a random shape) so that the critical range is
    # always exercised: exp() overflows in fp16 at roughly x ≈ 11.
    x = torch.tensor(
        [[-5.0, 0.0, 5.0, 10.0, 10.4, 11.0, 15.0, 20.0, 25.0, 50.0]]
    )
    model = nn.Softplus()
    model.eval()

    # Call through `self`, as test_softplus does, rather than naming the
    # base class directly.
    self.run_compare_torch(
        x,
        model,
        # `x` is the actual input tensor, not a shape to sample from.
        input_as_shape=False,
        frontend=frontend,
        backend=backend,
        compute_unit=compute_unit,
    )

@pytest.mark.parametrize(
"compute_unit, backend, frontend, shape",
itertools.product(compute_units, backends, frontends, COMMON_SHAPES_ALL),
Expand Down