36 changes: 21 additions & 15 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -49,18 +49,7 @@ dense | sparsegpt) ;;
;;
esac

#Iterate over list of qformats provided and check if they are valid
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian]" >&2
exit 1
;;
esac
done
IFS=" "
# Quant format / recipe validation is delegated to hf_ptq.py.

script_dir="$(dirname "$(readlink -f "$0")")"

@@ -72,7 +61,14 @@ fi

QFORMAT_MODIFIED="${QFORMAT//,/_}"

MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
# When using --recipe, build the model name from the recipe basename (without
# directory or .yaml suffix) so each recipe gets its own SAVE_PATH.
if [ -n "$RECIPE" ]; then
RECIPE_TAG=$(basename "$RECIPE" .yaml | sed 's/[^0-9a-zA-Z\-]/_/g')
MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_recipe_${RECIPE_TAG}
else
MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
fi

SAVE_PATH=${ROOT_SAVE_PATH}/saved_models_${MODEL_NAME}
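# Illustrative example (hypothetical paths): with --model=/models/Llama-3-8B and
# --recipe=recipes/nvfp4_mse.yaml, RECIPE_TAG resolves to "nvfp4_mse", MODEL_NAME to
# "Llama-3-8B_recipe_nvfp4_mse", and the checkpoint lands under
# ${ROOT_SAVE_PATH}/saved_models_Llama-3-8B_recipe_nvfp4_mse.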

@@ -177,11 +173,16 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH

if [[ "$MODEL_CONFIG_EXIST" == false ]]; then
echo "Quantizing original model..."
if [ -n "$RECIPE" ]; then
QUANT_SPEC_ARGS="--recipe=$RECIPE"
else
QUANT_SPEC_ARGS="--qformat=${QFORMAT// /,}"
fi
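# e.g. QUANT_SPEC_ARGS expands to "--recipe=recipes/nvfp4_mse.yaml" (hypothetical
# recipe path) in the first branch, or to "--qformat=fp8,nvfp4" when
# --quant=fp8,nvfp4 was passed.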
python hf_ptq.py \
--pyt_ckpt_path=$MODEL_PATH \
--export_path=$SAVE_PATH \
--sparsity_fmt=$SPARSITY_FMT \
--qformat="${QFORMAT// /,}" \
$QUANT_SPEC_ARGS \
--calib_size=$CALIB_SIZE \
--batch_size=$CALIB_BATCH_SIZE \
--inference_tensor_parallel=$TP \
@@ -203,7 +204,7 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
exit 0
fi

if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]]; then
if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]] || [[ "$RECIPE" == *"nvfp4"* ]]; then
cuda_major=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0 | cut -d. -f1)

if [ "$cuda_major" -lt 10 ]; then
@@ -212,6 +213,11 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
fi
fi

if [ -n "$RECIPE" ]; then
echo "Recipe $RECIPE used. Please deploy with TensorRT-LLM directly. Checkpoint export_path: $SAVE_PATH"
exit 0
fi

if [[ ! " fp8 nvfp4 bf16 fp16 " =~ " ${QFORMAT} " ]]; then
echo "Quant $QFORMAT specified. Please read TensorRT-LLM quantization support matrix https://nvidia.github.io/TensorRT-LLM/features/quantization.html#quantization-in-tensorrt-llm and use TensorRT-LLM for deployment. Checkpoint export_path: $SAVE_PATH"
exit 0
16 changes: 13 additions & 3 deletions examples/llm_ptq/scripts/parser.sh
@@ -20,6 +20,7 @@ parse_options() {
# Default values
MODEL_PATH=""
QFORMAT=""
RECIPE=""
KV_CACHE_QUANT=""
TP=1
PP=1
@@ -37,13 +38,14 @@ parse_options() {
CAST_MXFP4_TO_NVFP4=false

# Parse command-line options
ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@")
ARGS=$(getopt -o "" -l "model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@")

eval set -- "$ARGS"
while true; do
case "$1" in
--model ) MODEL_PATH="$2"; shift 2;;
--quant ) QFORMAT="$2"; shift 2;;
--recipe ) RECIPE="$2"; shift 2;;
--kv_cache_quant ) KV_CACHE_QUANT="$2"; shift 2;;
--tp ) TP="$2"; shift 2;;
--pp ) PP="$2"; shift 2;;
@@ -99,12 +101,19 @@ parse_options() {
fi

# Verify required options are provided
if [ -z "$MODEL_PATH" ] || [ -z "$QFORMAT" ] || [ -z "$TASKS" ]; then
echo "Usage: $0 --model=<MODEL_PATH> --quant=<QFORMAT> --tasks=<TASK,...>"
if [ -z "$MODEL_PATH" ] || [ -z "$TASKS" ] || ([ -z "$QFORMAT" ] && [ -z "$RECIPE" ]); then
echo "Usage: $0 --model=<MODEL_PATH> (--quant=<QFORMAT> | --recipe=<RECIPE>) --tasks=<TASK,...>"
echo "Optional args: --sparsity=<SPARSITY_FMT> --awq_block_size=<AWQ_BLOCK_SIZE> --calib=<CALIB_SIZE>"
exit 1
fi

# --quant and --recipe are mutually exclusive: --recipe is a full PTQ spec, while
# --quant selects a built-in qformat preset. Pick exactly one.
if [ -n "$QFORMAT" ] && [ -n "$RECIPE" ]; then
echo "Cannot specify both --quant and --recipe; pick one." >&2
exit 1
fi
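# Example invocations (illustrative; model and recipe paths are placeholders):
#   huggingface_example.sh --model=<MODEL_PATH> --quant=nvfp4 --tasks=quant
#   huggingface_example.sh --model=<MODEL_PATH> --recipe=recipes/nvfp4_mse.yaml --tasks=quant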

VALID_TASKS=("quant" "mmlu" "lm_eval" "livecodebench" "simple_eval")

for task in $(echo "$TASKS" | tr ',' ' '); do
@@ -135,6 +144,7 @@ parse_options() {
echo "================="
echo "model: $MODEL_PATH"
echo "quant: $QFORMAT"
echo "recipe: $RECIPE"
echo "tp (TensorRT-LLM Checkpoint only): $TP"
echo "pp (TensorRT-LLM Checkpoint only): $PP"
echo "sparsity: $SPARSITY_FMT"
16 changes: 15 additions & 1 deletion modelopt/torch/export/moe_utils.py
@@ -87,11 +87,25 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
and w_quantizer._amax.dim() >= 1
):
amax = w_quantizer._amax
# Static block-quant calibration (e.g. NVFP4 MSE FP8 sweep)
# produces a per-block _amax with shape (num_blocks_total, ...)
# where num_blocks_total = fused_total * blocks_per_row. That
# shape collapses the row axis we want to slice on. Restore the
# row dimension so the dim-0 slicing below splits gate / up
# correctly. No-op when _amax is already aligned with fused_total.
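# Illustrative (hypothetical sizes): with fused_total = 8192 fused rows and 448
# blocks per row, _amax arrives with 8192 * 448 elements and is viewed back to
# (8192, 448) so the dim-0 slice stays per fused row.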
if amax.numel() != fused_total and amax.numel() % fused_total == 0:
amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)
amax_dim0 = amax.shape[0]
if fused_total % amax_dim0 == 0:
slice_start = fused_start * amax_dim0 // fused_total
slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
w_quantizer.amax = amax[slice_start:slice_end].contiguous()
sliced = amax[slice_start:slice_end].contiguous()
# The amax setter refuses shape changes once `_amax` exists,
# so drop the existing buffer before re-registering with the
# sliced shape.
if hasattr(w_quantizer, "_amax"):
delattr(w_quantizer, "_amax")
w_quantizer.amax = sliced
else:
warnings.warn(
f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
32 changes: 32 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -1130,6 +1130,32 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None:
mod.revert_weight_conversion = original


def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None:
"""Coerce ``model.generation_config`` so it passes transformers' strict validation.

Some upstream HF checkpoints ship a ``generation_config.json`` that mixes
``do_sample=False`` with sampling-only attrs (``top_p``, ``top_k``, ...).
Newer transformers raise ``ValueError("GenerationConfig is invalid: ...")``
inside ``save_pretrained``, blocking export. We try a strict validate and
on failure flip ``do_sample`` to ``True`` so the upstream sampling intent
is preserved (rather than silently dropping ``top_p`` etc.). Quietly does
nothing if the model has no generation_config or it's already valid.
"""
gc = getattr(model, "generation_config", None)
if gc is None or not hasattr(gc, "validate"):
return
try:
gc.validate(strict=True)
return
except Exception:
pass
if not getattr(gc, "do_sample", False):
try:
gc.do_sample = True
except Exception:
pass
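# Illustrative failure mode (assumed, not taken from a specific checkpoint): a
# generation_config.json equivalent to GenerationConfig(do_sample=False, top_p=0.95)
# fails strict validation; after the flip above it saves as
# GenerationConfig(do_sample=True, top_p=0.95).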


def export_speculative_decoding(
model: torch.nn.Module,
dtype: torch.dtype | None = None,
@@ -1211,6 +1237,12 @@ def export_hf_checkpoint(
# modeling_utils does `from core_model_loading import revert_weight_conversion`.
_patches = _patch_revert_weight_conversion()

# Some upstream HF checkpoints ship a generation_config.json that fails
# transformers' strict validation on save (e.g. ``top_p`` set without
# ``do_sample=True`` — newer transformers raises). Flip ``do_sample`` to
# the sampling-attrs intent so save_pretrained can write the file.
_sanitize_generation_config_for_save(model)

try:
model.save_pretrained(
export_dir,
1 change: 1 addition & 0 deletions modelopt/torch/kernels/quantization/gemm/__init__.py
@@ -32,6 +32,7 @@
# fp4_kernel works on any CUDA GPU with triton
from .fp4_kernel import *
from .fp8_kernel import *
from .nvfp4_fp8_sweep import *

# fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
if torch.cuda.get_device_capability() >= (8, 9):
166 changes: 166 additions & 0 deletions modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
@@ -0,0 +1,166 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.

Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single
kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates
and emits the per-block ``best_amax`` directly.

The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see
:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on
the per-block scale is the identity, so the kernel can use
``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it
runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement).

Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``.
"""

import torch
import triton
import triton.language as tl

from .nvfp4_quant import fp4_round_magnitude

__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]


def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
"""Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
return fp8_values[valid_mask] / 448.0
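# For reference, these 126 values span the positive finite E4M3 range scaled by
# 1/448: from the smallest subnormal 2**-9 / 448 (~4.4e-6) up to 448 / 448 == 1.0,
# so best_amax = global_amax * candidate can range from near zero up to global_amax.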


# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
# BPP=16,nw=2: 6.06 ms BPP=32,nw=4: 6.06 ms BPP=64,nw=8: 5.08 ms
# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64
# would underfill the SMs.
_FP8_SWEEP_AUTOTUNE_CONFIGS = [
triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
]


@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
@triton.jit
def _fp8_scale_sweep_kernel(
x_ptr, # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
candidates_ptr, # [NUM_CANDIDATES] fp32
global_amax_ptr, # scalar fp32
best_amax_ptr, # [N_BLOCKS] fp32 output
N_BLOCKS,
BLOCK_SIZE: tl.constexpr,
NUM_CANDIDATES: tl.constexpr,
BLOCKS_PER_PROGRAM: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCKS_PER_PROGRAM
block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
block_mask = block_idx < N_BLOCKS

# Load weights for this tile and pre-compute their absolute values once.
# The squared error is sign-invariant since FP4 quant preserves sign:
# (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
# so we never need ``w`` itself again, dropping a tl.where + negation per element.
elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
elem_mask = block_mask[:, None]
w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))

global_amax = tl.load(global_amax_ptr).to(tl.float32)

best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)

# Loop over the 126 FP8 candidates (compile-time unrolled).
# Scales are guaranteed positive and finite (constructed from a positive candidate
# times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is
# unnecessary apart from the global_amax == 0 case handled below.
for k in tl.static_range(NUM_CANDIDATES):
c = tl.load(candidates_ptr + k).to(tl.float32)
scale = c * global_amax / 6.0
# Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is
# the same for every candidate, so any best_idx is fine.
scale_safe = tl.where(scale == 0.0, 1.0, scale)
q_mag = fp4_round_magnitude(w_abs / scale_safe)
diff = w_abs - q_mag * scale
loss = tl.sum(diff * diff, axis=1) # [BLOCKS_PER_PROGRAM]
is_better = loss < best_loss
best_loss = tl.where(is_better, loss, best_loss)
best_idx = tl.where(is_better, k, best_idx)

# Map each block's winning candidate index back to its amax = global_amax * c[best].
best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
best_amax = global_amax * best_c
tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)


def nvfp4_fp8_scale_sweep(
x: torch.Tensor,
global_amax: torch.Tensor,
block_size: int = 16,
candidates: torch.Tensor | None = None,
) -> torch.Tensor:
"""Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into
a single Triton kernel: every block's weight elements are loaded once, all 126
candidates are evaluated in registers, and the running argmin is kept inline.

Args:
x: Weight tensor on CUDA. Total element count must be divisible by
``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``.
global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``).
block_size: NVFP4 block size (typically 16).
candidates: Optional precomputed candidate tensor of shape ``[126]`` (must
be the FP8 E4M3 valid values divided by 448). Built lazily if omitted.

Returns:
``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``.
"""
if not x.is_cuda:
raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.")
if not isinstance(block_size, int) or block_size <= 0:
raise ValueError(f"block_size must be a positive int, got {block_size!r}.")
if x.numel() % block_size != 0:
raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")

if candidates is None:
candidates = fp8_scale_candidates(x.device)
candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32)
if candidates.ndim != 1 or candidates.numel() == 0:
raise ValueError(
f"candidates must be a non-empty 1-D tensor; got shape {tuple(candidates.shape)}."
)

n_blocks = x.numel() // block_size
x_flat = x.contiguous().view(-1)
global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1)
best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device)

grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),)
with torch.cuda.device(x.device):
_fp8_scale_sweep_kernel[grid](
x_flat,
candidates,
global_amax_f32,
best_amax,
n_blocks,
BLOCK_SIZE=block_size,
NUM_CANDIDATES=int(candidates.numel()),
)
return best_amax
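# Minimal usage sketch (illustrative shapes and variable names, not part of this
# module):
#
#     weight = module.weight.detach()                      # e.g. [4096, 4096], CUDA
#     per_block_amax = weight.abs().reshape(-1, 16).amax(dim=-1)
#     global_amax = per_block_amax.max().float()
#     best_amax = nvfp4_fp8_scale_sweep(weight, global_amax, block_size=16)
#     assert best_amax.shape == (weight.numel() // 16,)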