Skip to content
1 change: 1 addition & 0 deletions modelopt/torch/kernels/quantization/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
# fp4_kernel works on any CUDA GPU with triton
from .fp4_kernel import *
from .fp8_kernel import *
from .nvfp4_fp8_sweep import *

# fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
if torch.cuda.get_device_capability() >= (8, 9):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Single source of truth for the NVFP4 FP8 scale-candidate set.

Pure PyTorch, no Triton dependency, so it can be imported from both the kernel
wrapper (which is triton-gated) and the reference Python sweep in the
:class:`NVFP4MSECalibrator` (which must work without triton too).
"""

import torch


def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
"""Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
return fp8_values[valid_mask] / 448.0
150 changes: 150 additions & 0 deletions modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.

Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single
kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates
and emits the per-block ``best_amax`` directly.

The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see
:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on
the per-block scale is the identity, so the kernel can use
``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it
runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement).

Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``.
"""

import torch
import triton
import triton.language as tl

from ._fp8_scale_candidates import fp8_scale_candidates
from .nvfp4_quant import fp4_round_magnitude

__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]


# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
#   BPP=16,nw=2: 6.06 ms   BPP=32,nw=4: 6.06 ms   BPP=64,nw=8: 5.08 ms
# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64
# would underfill the SMs.
# NOTE: autotuning is keyed on N_BLOCKS (see @triton.autotune on the kernel), so
# the winning config is re-selected per distinct problem size and cached by Triton.
_FP8_SWEEP_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
    triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
    triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
]


@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
@triton.jit
def _fp8_scale_sweep_kernel(
    x_ptr,  # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
    candidates_ptr,  # [NUM_CANDIDATES] fp32
    global_amax_ptr,  # scalar fp32
    best_amax_ptr,  # [N_BLOCKS] fp32 output
    N_BLOCKS,
    BLOCK_SIZE: tl.constexpr,
    NUM_CANDIDATES: tl.constexpr,
    BLOCKS_PER_PROGRAM: tl.constexpr,
):
    """For each NVFP4 block, pick the FP8 scale candidate minimizing quantization MSE."""
    # Each program instance owns BLOCKS_PER_PROGRAM consecutive NVFP4 blocks.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCKS_PER_PROGRAM
    block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
    block_mask = block_idx < N_BLOCKS

    # Load weights for this tile and pre-compute their absolute values once.
    # The squared error is sign-invariant since FP4 quant preserves sign:
    #   (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
    # so we never need ``w`` itself again, dropping a tl.where + negation per element.
    elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
    elem_mask = block_mask[:, None]
    w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))

    global_amax = tl.load(global_amax_ptr).to(tl.float32)

    # Running per-block argmin over candidates, kept in registers across the loop.
    best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
    best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)

    # Loop over the NUM_CANDIDATES (126) FP8 candidates, compile-time unrolled.
    # Scales are guaranteed positive and finite (constructed from a positive candidate
    # times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is
    # unnecessary apart from the global_amax == 0 case handled below.
    for k in tl.static_range(NUM_CANDIDATES):
        c = tl.load(candidates_ptr + k).to(tl.float32)
        scale = c * global_amax / 6.0
        # Avoid divide-by-zero when global_amax == 0; in that case w_abs is also zero
        # (global_amax = max|w|), so the loss is zero for every candidate either way.
        scale_safe = tl.where(scale == 0.0, 1.0, scale)
        q_mag = fp4_round_magnitude(w_abs / scale_safe)
        diff = w_abs - q_mag * scale_safe
        loss = tl.sum(diff * diff, axis=1)  # [BLOCKS_PER_PROGRAM]
        is_better = loss < best_loss
        best_loss = tl.where(is_better, loss, best_loss)
        best_idx = tl.where(is_better, k, best_idx)

    # Map each block's winning candidate index back to its amax = global_amax * c[best].
    best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
    best_amax = global_amax * best_c
    tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)


# Per-device cache for the (small) candidate tensor. Without it, every call —
# once per weight quantizer in the mse_calibrate loop, potentially hundreds of
# times for large models — re-allocates and re-computes the same 126 values.
_CANDIDATES_CACHE: dict[torch.device, torch.Tensor] = {}


def _cached_candidates(device: torch.device) -> torch.Tensor:
    """Return the fp32 FP8 scale candidates for ``device``, computed once per device."""
    cached = _CANDIDATES_CACHE.get(device)
    if cached is None:
        cached = fp8_scale_candidates(device).to(dtype=torch.float32)
        _CANDIDATES_CACHE[device] = cached
    return cached


def nvfp4_fp8_scale_sweep(
    x: torch.Tensor,
    global_amax: torch.Tensor,
    block_size: int = 16,
) -> torch.Tensor:
    """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

    Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into
    a single Triton kernel: every block's weight elements are loaded once, all 126
    candidates are evaluated in registers, and the running argmin is kept inline.

    Args:
        x: Weight tensor on CUDA. Total element count must be divisible by
            ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``.
        global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``).
        block_size: NVFP4 block size (typically 16).

    Returns:
        ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``.

    Raises:
        ValueError: If ``x`` is not on CUDA, ``block_size`` is not a positive int,
            or ``x.numel()`` is not divisible by ``block_size``.
    """
    if not x.is_cuda:
        raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.")
    if not isinstance(block_size, int) or block_size <= 0:
        raise ValueError(f"block_size must be a positive int, got {block_size!r}.")
    if x.numel() % block_size != 0:
        raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")

    candidates = _cached_candidates(x.device)
    n_blocks = x.numel() // block_size
    x_flat = x.contiguous().view(-1)
    # fp32 semantics are maintained inside the kernel (inputs loaded with
    # .to(tl.float32)); only the amax scalar is promoted here.
    global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1)
    best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device)

    def grid(meta):
        # One program per tile of BLOCKS_PER_PROGRAM blocks (autotuned).
        return (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),)

    with torch.cuda.device(x.device):
        _fp8_scale_sweep_kernel[grid](
            x_flat,
            candidates,
            global_amax_f32,
            best_amax,
            n_blocks,
            BLOCK_SIZE=block_size,
            NUM_CANDIDATES=int(candidates.numel()),
        )
    return best_amax
93 changes: 86 additions & 7 deletions modelopt/torch/quantization/calib/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Calibrator that returns the MSE amax of all collected tensors."""

import math
import os
from collections.abc import Callable

import torch
Expand Down Expand Up @@ -172,7 +173,15 @@ def compute_amax(self, verbose: bool = False):


class NVFP4MSECalibrator(MseCalibrator):
"""Per-block FP8 scale sweep calibrator for NVFP4 static quantization."""
"""Per-block FP8 scale sweep calibrator for NVFP4 static quantization.

Uses a fused Triton kernel as an internal fast path on the first ``collect`` call
when (a) ``error_func is None``, (b) the input tensor is on CUDA in the standard
blocked ``[n_blocks, block_size]`` layout, and (c) Triton + the kernel package are
importable. Falls back to the reference 126-step Python sweep otherwise (custom
``error_func`` users, multi-``collect`` activation flows, CPU inputs, or when the
fast path is disabled via ``MODELOPT_NVFP4_TRITON_SWEEP=0``).
"""

def __init__(
self,
Expand All @@ -185,16 +194,86 @@ def __init__(
"""Initialize NVFP4 MSE calibrator with per-block and global amax."""
super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func)
self._global_amax = global_amax
# Set by the Triton fast path on its (one-shot) collect; consumed by compute_amax.
self._best_amax_fast: torch.Tensor | None = None

def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
    """Scale candidate ratio(s) by the global amax, broadcast to the per-block amax shape."""
    if candidates.ndim != 0:
        # Non-scalar candidates arrive during the final compute-amax pass;
        # reshape them to line up with the per-block initial amax.
        candidates = candidates.view_as(self._initial_amax)
    scaled = self._global_amax * candidates
    return torch.ones_like(self._initial_amax) * scaled

def _generate_candidates(self, device: torch.device) -> torch.Tensor:
    """Generate the 126 valid FP8 E4M3 scale candidates on ``device``.

    Delegates to the shared helper so this calibrator and the Triton kernel
    wrapper agree on the exact candidate set instead of duplicating its
    construction here.
    """
    # Local import so the kernels package is only required when candidates
    # are actually generated, not at module-import time.
    from modelopt.torch.kernels.quantization.gemm._fp8_scale_candidates import (
        fp8_scale_candidates,
    )

    return fp8_scale_candidates(device)

def _can_use_triton_fast_path(self, x: torch.Tensor) -> bool:
    """Whether the Triton fast path is usable for this ``collect`` input.

    The kernel produces the final per-block amax in one shot, so it's only usable
    when the caller wants the standard squared-error sweep on a single CUDA tensor
    whose layout already matches the per-block amax.
    """
    # A custom error function means the caller wants the reference sweep.
    if self._error_func is not None:
        return False
    # The kernel only runs on CUDA tensors.
    if not x.is_cuda:
        return False
    # Explicit opt-out escape hatch for debugging.
    if os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") == "0":
        return False
    per_block_amax = self._initial_amax
    if per_block_amax is None:
        return False
    # Input must already be laid out as [n_blocks, block_size].
    if x.ndim != 2 or x.shape[0] != int(per_block_amax.numel()):
        return False
    # Finally, the kernel package (and Triton) must be importable.
    try:
        from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep  # noqa: F401
    except ImportError:
        return False
    return True

@torch.no_grad()
def collect(self, x: torch.Tensor):
    """Collect input statistics. Uses the Triton fast path when eligible.

    The fast path produces the FINAL per-block amax in a single call, so a
    second ``collect`` afterwards has no well-defined accumulation semantics
    and raises instead of silently mixing results.
    """
    if self._best_amax_fast is not None:
        raise RuntimeError(
            "NVFP4MSECalibrator: the Triton fast path produced a final amax on a "
            "previous collect() call; multi-collect after the fast path is not "
            "supported. Call reset() to start a fresh cycle, set "
            "MODELOPT_NVFP4_TRITON_SWEEP=0, or pass a non-None error_func to force "
            "the reference path for activation-style accumulation."
        )
    # Fast path is eligible only on the first call, before the reference accumulator
    # has produced any state.
    if self._losses_sum is None and self._can_use_triton_fast_path(x):
        from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

        best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=x.shape[-1])
        # Match the original shape/dtype of the initial amax so downstream
        # load_calib_amax behaves identically to the reference path.
        self._best_amax_fast = best_flat.reshape(self._initial_amax.shape).to(
            self._initial_amax.dtype
        )
        return
    # Reference path: accumulate per-candidate losses via the base class.
    super().collect(x)

@torch.no_grad()
def compute_amax(self, verbose: bool = False):
    """Return the per-block amax — from the fast path if it ran, else from the reference sweep."""
    fast_result = self._best_amax_fast
    if fast_result is None:
        # Fast path never ran; fall back to the accumulated reference sweep.
        return super().compute_amax(verbose=verbose)
    return fast_result

def reset(self):
    """Reset per-cycle state. Keep ``_initial_amax`` so the calibrator stays reusable.

    ``MseCalibrator.reset()`` intentionally drops ``_initial_amax`` to free memory in
    the multi-step search, but the NVFP4 per-block amax is shape ``[num_blocks]`` —
    small enough to keep so a follow-up ``collect()`` can run again on the same
    calibrator instance.
    """
    # Clear everything a new collect/compute cycle rebuilds; deliberately
    # does NOT call super().reset(), which would discard _initial_amax.
    for attr_name in ("_best_amax_fast", "_losses_sum", "_candidates", "_amax"):
        setattr(self, attr_name, None)
4 changes: 3 additions & 1 deletion modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,9 @@ def mse_calibrate(
continue

if fp8_scale_sweep and is_nvfp4_static:
# Replace calibrator with NVFP4MSECalibrator
# NVFP4MSECalibrator internally selects a fused Triton kernel for
# the standard squared-error sweep; set MODELOPT_NVFP4_TRITON_SWEEP=0
# to force the reference Python sweep for debugging.
module._calibrator = NVFP4MSECalibrator(
amax=initial_amax,
axis=module._calibrator._axis,
Expand Down
Loading
Loading