Fix log_softmax fp16 underflow and logcumsumexp fp16 overflow on ANE via stable decomposition #2727

Open
Ashutosh0x wants to merge 1 commit into apple:main from Ashutosh0x:fix/log-softmax-fp16-underflow

Ashutosh0x commented May 29, 2026

Summary

Fix two fp16 numerical stability bugs in the PyTorch frontend converter that cause silent output corruption on Apple Neural Engine:

  1. log_softmax: Produces -inf for non-dominant classes when softmax probabilities underflow to 0 in fp16
  2. logcumsumexp: Overflows to inf for inputs > ~11.09 because exp() is applied to raw input without stabilization

Both fixes use the standard max-shift decomposition -- the same pattern applied in PR #2725 (softplus/mish) and PR #2726 (logsumexp).

Changes

converters/mil/frontend/torch/ops.py

log_softmax (line 5904): Replace naive log(softmax(x)) with:
log_softmax(x) = x - max(x) - log(sum(exp(x - max(x))))
This avoids computing tiny intermediate softmax probabilities that underflow to 0 in fp16.
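
The difference between the two forms can be reproduced outside the converter. The following is a minimal NumPy sketch, not the converter code: it emulates fp16 with np.float16, assumes the softmax step is itself max-shifted (so it fails by underflow rather than overflow), and uses hypothetical helper names. NumPy keeps fp16 subnormals, so the exact underflow point may differ from ANE hardware.

```python
import numpy as np


def softmax_fp16(x):
    # Assumed max-shifted softmax, evaluated entirely in fp16.
    x = x.astype(np.float16)
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def log_softmax_naive_fp16(x):
    # Naive form: log(softmax(x)). Probabilities that underflow to 0 become -inf.
    with np.errstate(divide="ignore"):
        return np.log(softmax_fp16(x))


def log_softmax_stable_fp16(x):
    # Stable form: x - max(x) - log(sum(exp(x - max(x)))).
    x = x.astype(np.float16)
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


# One dominant class (x = 50) and two non-dominant classes.
x = np.array([50.0, 0.0, 0.0], dtype=np.float16)
naive = log_softmax_naive_fp16(x)
stable = log_softmax_stable_fp16(x)
print("naive :", naive)
print("stable:", stable)
```

In this emulation the naive form returns -inf for the two non-dominant classes, while the stable form returns finite values because the log is never taken of an underflowed probability.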

logcumsumexp (line 2230): Replace naive log(cumsum(exp(x))) with:
logcumsumexp(x) = max(x) + log(cumsum(exp(x - max(x))))
By subtracting the global max first, all exp() arguments are <= 0, keeping values in (0, 1].

Note on global max for logcumsumexp: The global max is used rather than a running (cumulative) max because MIL does not provide a cummax op. The global max is always >= the running max at every position, so exp(x_i - global_max) <= 1 for all i, guaranteeing no overflow. The trade-off is more underflow at early positions when a much larger value appears later. The contributions of those early elements to later cumulative sums are genuinely negligible, but if exp(x_i - global_max) underflows to 0 for every element of a prefix, the cumulative sum there is 0 and the output at those positions becomes log(0) = -inf even though the exact result is finite. A future optimization could introduce a running max if a cummax MIL op becomes available.
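
The global-max shift and its trade-off can be explored with a small NumPy sketch. This is an fp16 emulation with hypothetical helper names for a 1-D input, not the MIL implementation, and ANE hardware may round or flush differently. It compares the naive and shifted forms on inputs above the overflow threshold, then probes a case where a large value appears after much smaller ones.

```python
import numpy as np


def logcumsumexp_naive_fp16(x):
    # Naive form: log(cumsum(exp(x))); exp overflows fp16 once x exceeds ~11.09.
    x = x.astype(np.float16)
    with np.errstate(over="ignore"):
        return np.log(np.cumsum(np.exp(x)))


def logcumsumexp_shifted_fp16(x):
    # Shifted form: max(x) + log(cumsum(exp(x - max(x)))), using the global max.
    x = x.astype(np.float16)
    m = x.max()
    with np.errstate(divide="ignore"):
        return m + np.log(np.cumsum(np.exp(x - m)))


# Inputs above the fp16 exp() overflow threshold.
x = np.array([10.0, 12.0, 20.0], dtype=np.float16)
naive = logcumsumexp_naive_fp16(x)
shifted = logcumsumexp_shifted_fp16(x)
reference = np.logaddexp.accumulate(x.astype(np.float64))  # float64 reference
print("naive    :", naive)
print("shifted  :", shifted)
print("reference:", reference)

# Trade-off probe: small early values followed by a much larger one.
early = logcumsumexp_shifted_fp16(np.array([0.0, 0.0, 50.0], dtype=np.float16))
print("early    :", early)
```

In this emulation the naive form overflows to inf from the second position on, while the shifted form tracks the float64 reference to fp16 precision. In the probe, exp(-50) is below the smallest fp16 subnormal, so the first two cumulative sums are 0 and their logs are -inf; a running max would avoid that.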

converters/mil/frontend/torch/test/test_torch_ops.py

  • Added test_log_softmax -- standard parametrized test across all shapes/backends/frontends
  • Added test_log_softmax_fp16_no_neg_inf -- regression test with dominant-class input (x=50)
  • Added test_logcumsumexp_fp16_large_input -- regression test with fp16-overflow inputs (x up to 50)

Fixes

Related

Pattern

This is part of a systematic effort to fix the exp()-based operations in the converter that overflow or underflow in fp16 on ANE. The overflow cases share one root cause: exp(x) applied without first bounding x exceeds the fp16 maximum (65,504) once x > ~11.09. The fix follows one pattern: subtract the maximum before computing exp(), then add it back after log().

| Op | Overflow threshold | Fix | PR |
|---|---|---|---|
| softplus | x > 10.4 | max(x,0) + log(1+exp(-abs(x))) | #2725 |
| mish | x > 10.4 (via softplus) | same as softplus | #2725 |
| logsumexp | x > 7.63 (C=32) | max + log(sum(exp(x-max))) | #2726 |
| log_softmax | probabilities < 6e-5 | x - max - log(sum(exp(x-max))) | this PR |
| logcumsumexp | x > 11.09 | max + log(cumsum(exp(x-max))) | this PR |
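
The exp()-related thresholds can be checked with a few lines of arithmetic. The sketch below is illustrative only and its names are hypothetical: it derives the point at which exp(x), or a sum of C equal terms exp(x), exceeds the fp16 maximum, and then confirms the C=32 case with an fp16 emulation in NumPy. It does not model the softplus row or ANE-specific rounding.

```python
import math

import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)          # largest finite fp16 value (65504)
FP16_MIN_NORMAL = float(np.finfo(np.float16).tiny)  # smallest normal fp16 value (~6.1e-5)


def exp_overflow_threshold(num_terms=1):
    # x above which num_terms equal copies of exp(x) sum past the fp16 maximum.
    return math.log(FP16_MAX / num_terms)


single = exp_overflow_threshold()       # plain exp(x), as in logcumsumexp
summed_32 = exp_overflow_threshold(32)  # sum over C=32 equal terms, as in logsumexp
print(f"exp(x) overflows for x > {single:.2f}")
print(f"sum of 32 exp(x) overflows for x > {summed_32:.2f}")
print(f"smallest normal fp16 value: {FP16_MIN_NORMAL:.2e}")

# Emulated check just above the C=32 threshold: 32 inputs equal to 8.
x = np.full(32, 8.0, dtype=np.float16)
with np.errstate(over="ignore"):
    naive = np.log(np.exp(x).sum())       # the sum exceeds 65504 and becomes inf
m = x.max()
stable = m + np.log(np.exp(x - m).sum())  # max-shifted form stays finite
print("naive :", naive)
print("stable:", stable)
```

The smallest normal fp16 value printed here is the likely origin of the 6e-5 figure in the log_softmax row, assuming values below the normal range are flushed to zero.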

…via stable decomposition

log_softmax: The naive log(softmax(x)) produces -inf for non-dominant
classes in fp16 because softmax outputs underflow to 0, then log(0) = -inf.
The stable form x - max(x) - log(sum(exp(x - max(x)))) avoids computing
tiny intermediate probabilities directly.

logcumsumexp: The naive log(cumsum(exp(x))) overflows in fp16 for x > ~11.09
since exp(11.09) exceeds fp16 max (65,504). The stable form shifts by the
global maximum first so all exp() arguments are <= 0, keeping values in (0,1].

Both fixes follow the same max-shift pattern used in the logsumexp stable
decomposition (PR apple#2726) and the softplus stable decomposition (PR apple#2725).

Added regression tests with extreme fp16 inputs for both ops.