Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions coremltools/converters/mil/mil/passes/defs/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def transform_op(self, op) -> None:
before_op=op,
)
if param_target_dtype == "fp16":
self._check_overflow_to_inf(x, var)
self._check_underflow_to_zero(x, var)
Block._copy_metadata(var, x)

Expand Down Expand Up @@ -503,6 +504,28 @@ def _check_underflow_to_zero(self, new_var, var):
else:
new_var._sym_val.val = new_val.reshape(new_var.val.shape)

def _check_overflow_to_inf(self, new_var, var):
# Saturate finite fp32 values that overflowed to ±inf during the fp16 cast.
# Motivating case: attention masks that use torch.finfo(torch.float32).min
# (~-3.4e38) as the masked-out sentinel. Letting that become fp16 -inf
# produces NaN in softmax for fully-masked rows (e.g. Gemma-4).
# Genuine ±inf in the original fp32 value is preserved.
if new_var.val is None or var.val is None:
return
original_val = np.asarray(var.val).flatten()
new_val = np.asarray(new_var.val).flatten().astype(np.float16, copy=True)
if original_val.shape != new_val.shape:
return
overflow_mask = np.isfinite(original_val) & np.isinf(new_val)
if not np.any(overflow_mask):
return
new_val[overflow_mask & (original_val > 0)] = np.float16(65504.0)
new_val[overflow_mask & (original_val < 0)] = np.float16(-65504.0)
if np.isscalar(new_var.val):
new_var._sym_val.val = new_val[0]
else:
new_var._sym_val.val = new_val.reshape(new_var.val.shape)


@register_pass(namespace="common")
class add_fp16_cast(FP16ComputePrecision):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2287,6 +2287,59 @@ def prog(x):

assert get_op_types_in_program(prog) == ["clip"]

def test_fp16_overflow_finfo_min_const_does_not_nan(self):
    """
    Regression test for attention masks built from
    ``torch.finfo(torch.float32).min`` (``-3.4028e+38``) as the masked-out
    sentinel (e.g. Gemma-4). That value is finite in fp32 yet outside fp16's
    range, so a plain cast yields fp16 ``-inf`` and
    ``softmax(-inf - (-inf)) = NaN`` for fully-masked rows. The cast-time
    saturation must instead produce ``-65504`` so ``exp`` underflows to ``0``,
    which matches the intended masking semantics.
    """

    mask_shape = (4,)
    fp32_min = float(np.finfo(np.float32).min)

    @mb.program(input_specs=[mb.TensorSpec(shape=mask_shape)])
    def prog(x):
        sentinel = mb.const(val=np.full(mask_shape, fp32_min, dtype=np.float32))
        return mb.add(x=x, y=sentinel)

    apply_pass_and_basic_check(prog, "common::add_fp16_cast")

    main_ops = list(prog.functions["main"].operations)

    # The original fp32 const must remain intact — values used only through
    # the cast must not be mutated.
    surviving_fp32_consts = []
    for op in main_ops:
        if op.op_type != "const":
            continue
        if not op.outputs[0].is_tensor_or_scalar_of(dtype="fp32"):
            continue
        if op.val.val is None:
            continue
        if np.asarray(op.val.val).shape == mask_shape:
            surviving_fp32_consts.append(op)
    assert surviving_fp32_consts, "Expected the original fp32 mask const to still exist"
    assert np.isclose(np.asarray(surviving_fp32_consts[0].val.val).min(), fp32_min)

    # The inserted fp16 cast must have had its output saturated to ``-65504``
    # rather than overflowing to ``-inf``.
    cast_outputs = []
    for op in main_ops:
        if op.op_type != "cast":
            continue
        out = op.outputs[0]
        if out.is_tensor_or_scalar_of(dtype="fp16") and out.val is not None:
            cast_outputs.append(out)

    saturated = []
    for candidate in cast_outputs:
        if np.isclose(np.asarray(candidate.val).min(), -65504.0):
            saturated.append(candidate)

    assert saturated, (
        "Expected the fp16 cast of finfo(fp32).min to be saturated to -65504, "
        f"but found fp16 cast outputs: {[var.val for var in cast_outputs]}"
    )
    assert np.isfinite(np.asarray(saturated[0].val)).all()

def test_divide_by_zero_operation(self):
"""
Input graph:
Expand Down