36 changes: 21 additions & 15 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -49,18 +49,7 @@ dense | sparsegpt) ;;
;;
esac

#Iterate over list of qformats provided and check if they are valid
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian]" >&2
exit 1
;;
esac
done
IFS=" "
# Quant format / recipe validation is delegated to hf_ptq.py.

script_dir="$(dirname "$(readlink -f "$0")")"

@@ -72,7 +61,14 @@ fi

QFORMAT_MODIFIED="${QFORMAT//,/_}"

MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
# When using --recipe, build the model name from the recipe basename (without
# directory or .yaml suffix) so each recipe gets its own SAVE_PATH.
if [ -n "$RECIPE" ]; then
RECIPE_TAG=$(basename "$RECIPE" .yaml | sed 's/[^0-9a-zA-Z\-]/_/g')
MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_recipe_${RECIPE_TAG}
else
MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
fi

SAVE_PATH=${ROOT_SAVE_PATH}/saved_models_${MODEL_NAME}

@@ -177,11 +173,16 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH

if [[ "$MODEL_CONFIG_EXIST" == false ]]; then
echo "Quantizing original model..."
if [ -n "$RECIPE" ]; then
QUANT_SPEC_ARGS="--recipe=$RECIPE"
else
QUANT_SPEC_ARGS="--qformat=${QFORMAT// /,}"
fi
python hf_ptq.py \
--pyt_ckpt_path=$MODEL_PATH \
--export_path=$SAVE_PATH \
--sparsity_fmt=$SPARSITY_FMT \
--qformat="${QFORMAT// /,}" \
$QUANT_SPEC_ARGS \
--calib_size=$CALIB_SIZE \
--batch_size=$CALIB_BATCH_SIZE \
--inference_tensor_parallel=$TP \
@@ -203,7 +204,7 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
exit 0
fi

if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]]; then
if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]] || [[ "$RECIPE" == *"nvfp4"* ]]; then
cuda_major=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0 | cut -d. -f1)

if [ "$cuda_major" -lt 10 ]; then
@@ -212,6 +213,11 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
fi
fi

if [ -n "$RECIPE" ]; then
echo "Recipe $RECIPE used. Please deploy with TensorRT-LLM directly. Checkpoint export_path: $SAVE_PATH"
exit 0
fi

if [[ ! " fp8 nvfp4 bf16 fp16 " =~ " ${QFORMAT} " ]]; then
echo "Quant $QFORMAT specified. Please read TensorRT-LLM quantization support matrix https://nvidia.github.io/TensorRT-LLM/features/quantization.html#quantization-in-tensorrt-llm and use TensorRT-LLM for deployment. Checkpoint export_path: $SAVE_PATH"
exit 0
16 changes: 13 additions & 3 deletions examples/llm_ptq/scripts/parser.sh
@@ -20,6 +20,7 @@ parse_options() {
# Default values
MODEL_PATH=""
QFORMAT=""
RECIPE=""
KV_CACHE_QUANT=""
TP=1
PP=1
@@ -37,13 +38,14 @@
CAST_MXFP4_TO_NVFP4=false

# Parse command-line options
ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@")
ARGS=$(getopt -o "" -l "model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@")

eval set -- "$ARGS"
while true; do
case "$1" in
--model ) MODEL_PATH="$2"; shift 2;;
--quant ) QFORMAT="$2"; shift 2;;
--recipe ) RECIPE="$2"; shift 2;;
--kv_cache_quant ) KV_CACHE_QUANT="$2"; shift 2;;
--tp ) TP="$2"; shift 2;;
--pp ) PP="$2"; shift 2;;
@@ -99,12 +101,19 @@ parse_options() {
fi

# Verify required options are provided
if [ -z "$MODEL_PATH" ] || [ -z "$QFORMAT" ] || [ -z "$TASKS" ]; then
echo "Usage: $0 --model=<MODEL_PATH> --quant=<QFORMAT> --tasks=<TASK,...>"
if [ -z "$MODEL_PATH" ] || [ -z "$TASKS" ] || ([ -z "$QFORMAT" ] && [ -z "$RECIPE" ]); then
echo "Usage: $0 --model=<MODEL_PATH> (--quant=<QFORMAT> | --recipe=<RECIPE>) --tasks=<TASK,...>"
echo "Optional args: --sparsity=<SPARSITY_FMT> --awq_block_size=<AWQ_BLOCK_SIZE> --calib=<CALIB_SIZE>"
exit 1
fi

# --quant and --recipe are mutually exclusive: --recipe is a full PTQ spec, while
# --quant selects a built-in qformat preset. Pick exactly one.
if [ -n "$QFORMAT" ] && [ -n "$RECIPE" ]; then
echo "Cannot specify both --quant and --recipe; pick one." >&2
exit 1
fi

VALID_TASKS=("quant" "mmlu" "lm_eval" "livecodebench" "simple_eval")

for task in $(echo "$TASKS" | tr ',' ' '); do
@@ -135,6 +144,7 @@ parse_options() {
echo "================="
echo "model: $MODEL_PATH"
echo "quant: $QFORMAT"
echo "recipe: $RECIPE"
echo "tp (TensorRT-LLM Checkpoint only): $TP"
echo "pp (TensorRT-LLM Checkpoint only): $PP"
echo "sparsity: $SPARSITY_FMT"
1 change: 1 addition & 0 deletions modelopt/torch/kernels/quantization/gemm/__init__.py
@@ -32,6 +32,7 @@
# fp4_kernel works on any CUDA GPU with triton
from .fp4_kernel import *
from .fp8_kernel import *
from .nvfp4_fp8_sweep import *

# fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
if torch.cuda.get_device_capability() >= (8, 9):
166 changes: 166 additions & 0 deletions modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
@@ -0,0 +1,166 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.

Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single
kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates
and emits the per-block ``best_amax`` directly.

The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see
:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on
the per-block scale is the identity, so the kernel can use
``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it
runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement).

Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``.
"""

import torch
import triton
import triton.language as tl

from .nvfp4_quant import fp4_round_magnitude

__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]


def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
"""Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
return fp8_values[valid_mask] / 448.0
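
# Illustrative sanity sketch (not exercised by the module itself) of the two facts the
# module docstring relies on: there are exactly 126 finite positive FP8 E4M3 bit
# patterns, and each candidate times 448 is an exact FP8 value, so re-encoding a
# per-block scale as FP8 reproduces it up to float32 rounding:
#
#   cands = fp8_scale_candidates()
#   assert cands.numel() == 126
#   roundtrip = (cands * 448.0).to(torch.float8_e4m3fn).float() / 448.0
#   assert torch.allclose(roundtrip, cands)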


# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
# BPP=16,nw=2: 6.06 ms BPP=32,nw=4: 6.06 ms BPP=64,nw=8: 5.08 ms
# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64
# would underfill the SMs.
_FP8_SWEEP_AUTOTUNE_CONFIGS = [
triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
]


@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
@triton.jit
def _fp8_scale_sweep_kernel(
x_ptr, # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
candidates_ptr, # [NUM_CANDIDATES] fp32
global_amax_ptr, # scalar fp32
best_amax_ptr, # [N_BLOCKS] fp32 output
N_BLOCKS,
BLOCK_SIZE: tl.constexpr,
NUM_CANDIDATES: tl.constexpr,
BLOCKS_PER_PROGRAM: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCKS_PER_PROGRAM
block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
block_mask = block_idx < N_BLOCKS

# Load weights for this tile and pre-compute their absolute values once.
# The squared error is sign-invariant since FP4 quant preserves sign:
# (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
# so we never need ``w`` itself again, dropping a tl.where + negation per element.
elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
elem_mask = block_mask[:, None]
w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))

global_amax = tl.load(global_amax_ptr).to(tl.float32)

best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)

# Loop over the 126 FP8 candidates (compile-time unrolled).
# Scales are guaranteed positive and finite (constructed from a positive candidate
# times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is
# unnecessary apart from the global_amax == 0 case handled below.
for k in tl.static_range(NUM_CANDIDATES):
c = tl.load(candidates_ptr + k).to(tl.float32)
scale = c * global_amax / 6.0
# Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is
# the same for every candidate, so any best_idx is fine.
scale_safe = tl.where(scale == 0.0, 1.0, scale)
q_mag = fp4_round_magnitude(w_abs / scale_safe)
diff = w_abs - q_mag * scale
loss = tl.sum(diff * diff, axis=1) # [BLOCKS_PER_PROGRAM]
is_better = loss < best_loss
best_loss = tl.where(is_better, loss, best_loss)
best_idx = tl.where(is_better, k, best_idx)

# Map each block's winning candidate index back to its amax = global_amax * c[best].
best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
best_amax = global_amax * best_c
tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)
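
# For reference, each block's result is equivalent to this (much slower) pure-PyTorch
# sweep, assuming a hypothetical elementwise helper ``fp4_round_pt`` that snaps each
# nonnegative value to the nearest FP4 E2M1 magnitude in {0, 0.5, 1, 1.5, 2, 3, 4, 6}
# (a stand-in for the Triton-side ``fp4_round_magnitude``):
#
#   w_abs = x.reshape(-1, block_size).abs().float()              # [N_BLOCKS, BS]
#   scales = candidates * global_amax / 6.0                      # [126]
#   q = fp4_round_pt(w_abs[:, None, :] / scales[None, :, None])  # [N_BLOCKS, 126, BS]
#   loss = ((w_abs[:, None, :] - q * scales[None, :, None]) ** 2).sum(-1)
#   best_amax = global_amax * candidates[loss.argmin(dim=1)]     # [N_BLOCKS]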


def nvfp4_fp8_scale_sweep(
x: torch.Tensor,
global_amax: torch.Tensor,
block_size: int = 16,
candidates: torch.Tensor | None = None,
) -> torch.Tensor:
"""Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into
a single Triton kernel: every block's weight elements are loaded once, all 126
candidates are evaluated in registers, and the running argmin is kept inline.

Args:
x: Weight tensor on CUDA. Total element count must be divisible by
``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``.
global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``).
block_size: NVFP4 block size (typically 16).
candidates: Optional precomputed candidate tensor of shape ``[126]`` (must
be the FP8 E4M3 valid values divided by 448). Built lazily if omitted.

Returns:
``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``.
"""
if not x.is_cuda:
raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.")
if not isinstance(block_size, int) or block_size <= 0:
raise ValueError(f"block_size must be a positive int, got {block_size!r}.")
if x.numel() % block_size != 0:
raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")

if candidates is None:
candidates = fp8_scale_candidates(x.device)
candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32)
if candidates.ndim != 1 or candidates.numel() == 0:
raise ValueError(
f"candidates must be a non-empty 1-D tensor; got shape {tuple(candidates.shape)}."
)

n_blocks = x.numel() // block_size
x_flat = x.contiguous().view(-1)
global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1)
best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device)

grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),)
with torch.cuda.device(x.device):
_fp8_scale_sweep_kernel[grid](
x_flat,
candidates,
global_amax_f32,
best_amax,
n_blocks,
BLOCK_SIZE=block_size,
NUM_CANDIDATES=int(candidates.numel()),
)
return best_amax
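
A minimal usage sketch for the new kernel (illustrative only; the weight shape is arbitrary and the import path assumes the ``gemm`` package re-export added in ``__init__.py`` above):

import torch

from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

# Any CUDA float tensor whose element count is divisible by the block size works.
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
global_amax = w.abs().amax().float()  # scalar fp32 global amax

best_amax = nvfp4_fp8_scale_sweep(w, global_amax, block_size=16)
assert best_amax.shape == (w.numel() // 16,)  # one fp32 best amax per 16-element block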