[Quantization] Fused Triton kernel for NVFP4 FP8 scale sweep search #1387
New file: modelopt/torch/kernels/quantization/gemm/_fp8_scale_candidates.py

@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| """Single source of truth for the NVFP4 FP8 scale-candidate set. | ||
|
|
||
| Pure PyTorch, no Triton dependency, so it can be imported from both the kernel | ||
| wrapper (which is triton-gated) and the reference Python sweep in the | ||
| :class:`NVFP4MSECalibrator` (which must work without triton too). | ||
| """ | ||
|
|
||
import torch


def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
    """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
    uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
    fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
    valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
    return fp8_values[valid_mask] / 448.0
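The candidate count and the round-trip identity relied on by the kernel below can be sanity-checked in a few lines of plain PyTorch; this is a quick illustration, assuming a PyTorch build with float8_e4m3fn support:

import torch

from modelopt.torch.kernels.quantization.gemm._fp8_scale_candidates import fp8_scale_candidates

cands = fp8_scale_candidates()
# 128 non-negative codes minus the zero code and the NaN code (0x7f) leaves 126.
assert cands.numel() == 126

# The identity behind the kernel's "no explicit FP8 cast": candidate * 448 is an
# exact FP8 E4M3 value, so the FP8 round-trip reproduces each candidate bit-exactly.
roundtrip = (cands * 448.0).to(torch.float8_e4m3fn).float() / 448.0
assert torch.equal(roundtrip, cands)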
New file: the fused Triton sweep kernel, exported as ``nvfp4_fp8_scale_sweep`` from modelopt.torch.kernels.quantization.gemm

@@ -0,0 +1,150 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.

Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single
kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates
and emits the per-block ``best_amax`` directly.

The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see
:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on
the per-block scale is the identity, so the kernel can use
``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it
runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement).

Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``.
"""

import torch
import triton
import triton.language as tl

from ._fp8_scale_candidates import fp8_scale_candidates
from .nvfp4_quant import fp4_round_magnitude

__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]

# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
#   BPP=16, nw=2: 6.06 ms   BPP=32, nw=4: 6.06 ms   BPP=64, nw=8: 5.08 ms
# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64
# would underfill the SMs.
_FP8_SWEEP_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
    triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
    triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
]

@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
@triton.jit
def _fp8_scale_sweep_kernel(
    x_ptr,            # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
    candidates_ptr,   # [NUM_CANDIDATES] fp32
    global_amax_ptr,  # scalar fp32
    best_amax_ptr,    # [N_BLOCKS] fp32 output
    N_BLOCKS,
    BLOCK_SIZE: tl.constexpr,
    NUM_CANDIDATES: tl.constexpr,
    BLOCKS_PER_PROGRAM: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCKS_PER_PROGRAM
    block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
    block_mask = block_idx < N_BLOCKS

    # Load weights for this tile and pre-compute their absolute values once.
    # The squared error is sign-invariant since FP4 quant preserves sign:
    #   (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
    # so we never need ``w`` itself again, dropping a tl.where + negation per element.
    elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
    elem_mask = block_mask[:, None]
    w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))

    global_amax = tl.load(global_amax_ptr).to(tl.float32)

    best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
    best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)

    # Loop over the 126 FP8 candidates (compile-time unrolled).
    # Scales are guaranteed positive and finite (constructed from a positive candidate
    # times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is
    # unnecessary apart from the global_amax == 0 case handled below.
    for k in tl.static_range(NUM_CANDIDATES):
        c = tl.load(candidates_ptr + k).to(tl.float32)
        scale = c * global_amax / 6.0
        # Avoid divide-by-zero when global_amax == 0; in that case w_abs is also zero
        # (global_amax = max|w|), so the loss is zero for every candidate either way.
        scale_safe = tl.where(scale == 0.0, 1.0, scale)
        q_mag = fp4_round_magnitude(w_abs / scale_safe)
        diff = w_abs - q_mag * scale_safe
        loss = tl.sum(diff * diff, axis=1)  # [BLOCKS_PER_PROGRAM]
        is_better = loss < best_loss
        best_loss = tl.where(is_better, loss, best_loss)
        best_idx = tl.where(is_better, k, best_idx)

    # Map each block's winning candidate index back to its amax = global_amax * c[best].
    best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
    best_amax = global_amax * best_c
    tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)

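For intuition, what each program instance computes can be transcribed into plain PyTorch. This is a sketch, not the shipped reference path; it assumes ``fp4_round_magnitude`` rounds to the nearest FP4 magnitude in {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and tie-breaking at rounding midpoints may differ from the Triton helper:

import torch

_FP4_MAGS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def reference_sweep(w_abs: torch.Tensor, candidates: torch.Tensor, global_amax: float) -> torch.Tensor:
    """w_abs: [n_blocks, block_size] fp32 magnitudes; returns per-block best_amax like the kernel."""
    best_loss = torch.full((w_abs.shape[0],), float("inf"))
    best_amax = torch.zeros(w_abs.shape[0])
    for c in candidates.tolist():
        scale = c * global_amax / 6.0
        if scale == 0.0:  # mirror the kernel's scale_safe guard for global_amax == 0
            scale = 1.0
        # Nearest-magnitude FP4 rounding (assumed semantics of fp4_round_magnitude).
        idx = (w_abs / scale).unsqueeze(-1).sub(_FP4_MAGS).abs().argmin(dim=-1)
        q_mag = _FP4_MAGS[idx]
        loss = ((w_abs - q_mag * scale) ** 2).sum(dim=-1)
        better = loss < best_loss
        best_loss = torch.where(better, loss, best_loss)
        best_amax = torch.where(better, torch.full_like(best_amax, c * global_amax), best_amax)
    return best_amax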

def nvfp4_fp8_scale_sweep(
    x: torch.Tensor,
    global_amax: torch.Tensor,
    block_size: int = 16,
) -> torch.Tensor:
    """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

    Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into
    a single Triton kernel: every block's weight elements are loaded once, all 126
    candidates are evaluated in registers, and the running argmin is kept inline.

    Args:
        x: Weight tensor on CUDA. Total element count must be divisible by
            ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``.
        global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``).
        block_size: NVFP4 block size (typically 16).

    Returns:
        ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``.
    """
    if not x.is_cuda:
        raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.")
    if not isinstance(block_size, int) or block_size <= 0:
        raise ValueError(f"block_size must be a positive int, got {block_size!r}.")
    if x.numel() % block_size != 0:
        raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")

    candidates = fp8_scale_candidates(x.device).to(dtype=torch.float32)

Review comment (coderabbitai[bot], resolved) on the line above:

[IMPORTANT Performance] The allocation is small (504 bytes) and the cost is negligible compared to the kernel itself, so this is not a blocking issue. However, if you want to eliminate it, a module-level cache (e.g., keyed by device) would avoid re-materializing the candidates on every call. Not urgent given the 42x speedup context, but would be a clean follow-up.
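A minimal sketch of that suggested follow-up, assuming a plain per-device dict is acceptable (the helper name ``_cached_fp8_scale_candidates`` is hypothetical, not part of the diff):

import torch

from modelopt.torch.kernels.quantization.gemm._fp8_scale_candidates import fp8_scale_candidates

# Hypothetical module-level cache: materialize the 126 candidates once per device
# and reuse the tensor on later sweep calls instead of rebuilding it each time.
_CANDIDATE_CACHE: dict[torch.device, torch.Tensor] = {}


def _cached_fp8_scale_candidates(device: torch.device | str) -> torch.Tensor:
    dev = torch.device(device)
    if dev not in _CANDIDATE_CACHE:
        _CANDIDATE_CACHE[dev] = fp8_scale_candidates(dev).to(dtype=torch.float32)
    return _CANDIDATE_CACHE[dev]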

    n_blocks = x.numel() // block_size
    x_flat = x.contiguous().view(-1)
    global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1)
    best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device)

    grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),)
    with torch.cuda.device(x.device):
        _fp8_scale_sweep_kernel[grid](
            x_flat,
            candidates,
            global_amax_f32,
            best_amax,
            n_blocks,
            BLOCK_SIZE=block_size,
            NUM_CANDIDATES=int(candidates.numel()),
        )
    return best_amax
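For reference, a minimal invocation of the wrapper under assumed shapes (requires a CUDA device with Triton installed; values are illustrative):

import torch

from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

# 4096 NVFP4 blocks of 16 elements each; any float dtype works (loaded as fp32).
w = torch.randn(4096, 16, device="cuda", dtype=torch.bfloat16)
global_amax = w.abs().amax().float()  # scalar fp32, matching reduce_amax semantics

best_amax = nvfp4_fp8_scale_sweep(w, global_amax, block_size=16)
assert best_amax.shape == (4096,) and best_amax.dtype == torch.float32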
Modified file: the MSE calibrator module defining NVFP4MSECalibrator

@@ -16,6 +16,7 @@
| """Calibrator that returns the MSE amax of all collected tensors.""" | ||
|
|
||
| import math | ||
| import os | ||
| from collections.abc import Callable | ||
|
|
||
| import torch | ||
|
|
@@ -172,7 +173,15 @@ def compute_amax(self, verbose: bool = False):
 class NVFP4MSECalibrator(MseCalibrator):
-    """Per-block FP8 scale sweep calibrator for NVFP4 static quantization."""
+    """Per-block FP8 scale sweep calibrator for NVFP4 static quantization.
+
+    Uses a fused Triton kernel as an internal fast path on the first ``collect`` call
+    when (a) ``error_func is None``, (b) the input tensor is on CUDA in the standard
+    blocked ``[n_blocks, block_size]`` layout, and (c) Triton + the kernel package are
+    importable. Falls back to the reference 126-step Python sweep otherwise (custom
+    ``error_func`` users, multi-``collect`` activation flows, CPU inputs, or when the
+    fast path is disabled via ``MODELOPT_NVFP4_TRITON_SWEEP=0``).
+    """

     def __init__(
         self,

@@ -185,16 +194,86 @@ def __init__(
         """Initialize NVFP4 MSE calibrator with per-block and global amax."""
         super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func)
         self._global_amax = global_amax
+        # Set by the Triton fast path on its (one-shot) collect; consumed by compute_amax.
+        self._best_amax_fast: torch.Tensor | None = None
     def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
         if candidates.ndim != 0:  # Called during final compute amax
             candidates = candidates.view_as(self._initial_amax)
         return torch.ones_like(self._initial_amax) * self._global_amax * candidates
     def _generate_candidates(self, device: torch.device) -> torch.Tensor:
-        """Generate 126 valid FP8 E4M3 scale candidates."""
-        uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
-        fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
-        valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
-        fp8_values = fp8_values[valid_mask]
-        return fp8_values / 448.0
+        """Generate the 126 valid FP8 E4M3 scale candidates."""
+        from modelopt.torch.kernels.quantization.gemm._fp8_scale_candidates import (
+            fp8_scale_candidates,
+        )
+
+        return fp8_scale_candidates(device)
+    def _can_use_triton_fast_path(self, x: torch.Tensor) -> bool:
+        """Whether the Triton fast path is usable for this ``collect`` input.
+
+        The kernel produces the final per-block amax in one shot, so it's only usable
+        when the caller wants the standard squared-error sweep on a single CUDA tensor
+        whose layout already matches the per-block amax.
+        """
+        if self._error_func is not None:
+            return False
+        if not x.is_cuda:
+            return False
+        if os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") == "0":
+            return False
+        if self._initial_amax is None:
+            return False
+        if x.ndim != 2 or x.shape[0] != int(self._initial_amax.numel()):
+            return False
+        try:
+            from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep  # noqa: F401
+        except ImportError:
+            return False
+        return True

+    @torch.no_grad()
+    def collect(self, x: torch.Tensor):
+        """Collect input statistics. Uses the Triton fast path when eligible."""
+        if self._best_amax_fast is not None:
+            raise RuntimeError(
+                "NVFP4MSECalibrator: the Triton fast path produced a final amax on a "
+                "previous collect() call; multi-collect after the fast path is not "
+                "supported. Call reset() to start a fresh cycle, set "
+                "MODELOPT_NVFP4_TRITON_SWEEP=0, or pass a non-None error_func to force "
+                "the reference path for activation-style accumulation."
+            )
+        # Fast path is eligible only on the first call, before the reference accumulator
+        # has produced any state.
+        if self._losses_sum is None and self._can_use_triton_fast_path(x):
+            from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep
+
+            best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=x.shape[-1])
+            # Match the original shape/dtype of the initial amax so downstream
+            # load_calib_amax behaves identically to the reference path.
+            self._best_amax_fast = best_flat.reshape(self._initial_amax.shape).to(
+                self._initial_amax.dtype
+            )
+            return
+        super().collect(x)

+    @torch.no_grad()
+    def compute_amax(self, verbose: bool = False):
+        """Return the per-block amax — from the fast path if it ran, else from the reference sweep."""
+        if self._best_amax_fast is not None:
+            return self._best_amax_fast
+        return super().compute_amax(verbose=verbose)
+    def reset(self):
+        """Reset per-cycle state. Keep ``_initial_amax`` so the calibrator stays reusable.
+
+        ``MseCalibrator.reset()`` intentionally drops ``_initial_amax`` to free memory in
+        the multi-step search, but the NVFP4 per-block amax is shape ``[num_blocks]`` —
+        small enough to keep so a follow-up ``collect()`` can run again on the same
+        calibrator instance.
+        """
+        self._best_amax_fast = None
+        self._losses_sum = None
+        self._candidates = None
+        self._amax = None
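To illustrate the one-shot flow end to end, a hypothetical wiring of the calibrator (the import path and the optionality of the remaining constructor arguments are assumptions; shapes are illustrative):

import torch

from modelopt.torch.quantization.calib import NVFP4MSECalibrator  # import path assumed

# 4096 NVFP4 blocks of 16 elements; the per-block amax has one entry per block.
w = torch.randn(4096, 16, device="cuda")
per_block_amax = w.abs().amax(dim=-1, keepdim=True)
global_amax = per_block_amax.amax().float()

calib = NVFP4MSECalibrator(amax=per_block_amax, global_amax=global_amax)
calib.collect(w)                  # first collect: Triton fast path if eligible
best_amax = calib.compute_amax()  # same shape/dtype as the initial per-block amax
calib.reset()                     # clears fast-path state; _initial_amax is kept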