Skip to content

Fix softplus and mish fp16 overflow on ANE via stable decomposition#2725

Open
Ashutosh0x wants to merge 1 commit into
apple:mainfrom
Ashutosh0x:fix/softplus-fp16-stable-decomposition-2687
Open

Fix softplus and mish fp16 overflow on ANE via stable decomposition#2725
Ashutosh0x wants to merge 1 commit into
apple:mainfrom
Ashutosh0x:fix/softplus-fp16-stable-decomposition-2687

Conversation

@Ashutosh0x
Copy link
Copy Markdown

@Ashutosh0x Ashutosh0x commented May 28, 2026

Problem

The native softplus MIL op computes log(1 + exp(x)), where exp(x) overflows in fp16 for x > ~10.4 on Apple Neural Engine, causing a hard, single-step output collapse to 0. This also affects nn.Mish (x * tanh(softplus(x))). CPU and GPU compute units are unaffected.

Additionally, PyTorch's threshold parameter (default 20) was being ignored by the converter.

Discovered while debugging fp16 precision in a KataGo-style network's Mish activations (see #2687).

Solution

Replace the native softplus op with the numerically stable equivalent:

softplus(x) = max(x, 0) + log(1 + exp(-|x|))

Since -|x| <= 0, exp(-|x|) is always in (0, 1], so no overflow can occur in any precision. This formula is already used by coremltools' own softplus MIL op value_inference.

Also apply PyTorch's threshold parameter: for beta * x > threshold, return x directly, matching PyTorch's exact semantics.

Changes

  • ops.py — softplus converter: Replaced mb.softplus() with the stable decomposition (abs -> mul(-1) -> exp -> add(1) -> log + maximum(x, 0) -> add). Applied same approach for all beta values (unit and non-unit). Added threshold via mb.select.
  • ops.py — mish converter: Applied the same stable softplus decomposition (Mish calls softplus internally).
  • test_torch_ops.py: Updated test_softplus to account for the new graph structure. Added test_softplus_fp16_threshold regression test with input values spanning the critical fp16 range: [-5, 0, 5, 10, 10.4, 11, 15, 20, 25, 50].

Testing

  • All existing test_softplus parametrized test cases remain (shapes, ranks, beta/threshold combinations, deployment targets)
  • New test_softplus_fp16_threshold specifically validates correctness at and beyond the ANE fp16 overflow point

Fixes #2687
Fixes #2359

…pple#2687)

The native softplus MIL op computes log(1 + exp(x)), where exp(x) overflows in fp16 for x > ~10.4 on Apple Neural Engine, causing a hard output collapse to 0. This also affects nn.Mish (x * tanh(softplus(x))).

Replace the native softplus op with the numerically stable equivalent: softplus(x) = max(x, 0) + log(1 + exp(-|x|)). Since -|x| <= 0, exp(-|x|) is always in (0,1], so no overflow can occur in any precision. This matches the value_inference formula already used in coremltools' own softplus MIL op definition.

Also apply PyTorch's threshold parameter (default 20) which was previously ignored: for beta*x > threshold, return x directly.

Changes: - Decompose softplus to stable form in PyTorch converter (ops.py) - Apply same fix to mish converter which calls softplus internally - Add test_softplus_fp16_threshold regression test with large inputs - Update test_softplus to account for new graph structure
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

1 participant