Skip to content

Working branch for NVFP4 MSE#1403

Draft
cjluo-nv wants to merge 8 commits intomainfrom
chenjiel/recipe-nvfp4-experts-mse-fp8-cast-kv-2
Draft

Working branch for NVFP4 MSE#1403
cjluo-nv wants to merge 8 commits intomainfrom
chenjiel/recipe-nvfp4-experts-mse-fp8-cast-kv-2

Conversation

@cjluo-nv
Copy link
Copy Markdown
Collaborator

@cjluo-nv cjluo-nv commented May 7, 2026

What does this PR do?

Type of change: ?

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

cjluo-nv and others added 7 commits May 4, 2026 20:57
Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a single
fused Triton kernel that loads each NVFP4 block once, evaluates all 126 valid
FP8 E4M3 scale candidates in registers, and emits the per-block best amax
directly. For our specific candidate set (FP8 representable values / 448) the
FP8 round-trip on the per-block scale is the identity, so the kernel uses
`scale = candidate * global_amax / 6.0` and runs on any CUDA + Triton.

Triton-backed calibrator is on by default for `mse_calibrate(... fp8_scale_sweep=True)`;
set `MODELOPT_NVFP4_TRITON_SWEEP=0` to fall back to the reference for debugging.

Measured ~7.4x speedup on a B300 over the reference NVFP4MSECalibrator
(8192x4096 weight, ~2M NVFP4 blocks: 176.67 ms -> 23.81 ms). Bit-identical to
the reference for typical block counts; on multi-million-block weights an
occasional adjacent-candidate tie-break can differ at the fp32-noise level
(observed 2 / 2,097,152 blocks; per-block MSE within 1e-7 relative).

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
…ner loop

Two follow-on optimizations to the fused FP8 scale sweep kernel:

1. @triton.autotune over (BLOCKS_PER_PROGRAM, num_warps): a hand-sweep on B300
   showed the previous default (BPP=4, num_warps=4) at 23.7 ms left ~4x on the
   table — best config (BPP=64, num_warps=8) lands at ~5 ms. Three configs are
   included to cover small/medium/large N_BLOCKS without flooding compile time.

2. Drop the sign-handling tl.where: since FP4 quantization preserves sign,
   (w - w_q)^2 == (|w| - |w_q|)^2, so the kernel works on |w| throughout and
   skips one tl.where + negation per element per candidate.

Result on the same 8192x4096 weight (~2M blocks) on B300:
  reference NVFP4MSECalibrator:   176.68 ms
  triton  TritonNVFP4MSECalibrator: 4.23 ms
  speedup: 41.8x  (was 7.4x)

This is ~1.2x above the rough pure-compute floor (~240 GF / 67 TF/s ~= 3.6 ms),
so the kernel is now near saturation and further wins would need an algorithmic
change (candidate pruning, etc.).

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Addresses review comments on PR #1387:

- TritonNVFP4MSECalibrator.reset() now leaves the calibrator reusable: shape /
  dtype / n_blocks of the initial amax are stashed in __init__, so collect()
  no longer depends on _initial_amax surviving reset(). Adds an x.ndim==2
  assertion in collect() since the weight quantizer always reshapes upstream.
- nvfp4_fp8_scale_sweep validates inputs cleanly instead of using assert
  (which is stripped by python -O): rejects non-CUDA tensors, non-positive
  block_size, and empty / non-1D candidates with ValueError. Skips the
  per-element finite/positive check on candidates since it would scan a 126-
  entry tensor on every kernel call.
- mse_calibrate hoists the MODELOPT_NVFP4_TRITON_SWEEP env-var lookup out of
  the per-quantizer loop and resolves to the calibrator class once.
- Updates test_reset_allows_recollect to verify the new reuse contract; adds
  test_input_validation covering the new ValueErrors.

The duplicate fp8_scale_candidates implementation in the kernel file and
NVFP4MSECalibrator._generate_candidates() is left in place: deduplicating
would force the reference path to import from the kernel module, which is
gated behind Triton availability. The FP8 E4M3 spec is fixed and the parity
test exercises both paths against each other.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
…recipe support in scripts

- Add modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml,
  combining experts-only NVFP4 W4A4 with the MSE FP8 scale-sweep weight
  calibration (algorithm: mse, fp8_scale_sweep: true; expert weight blocks
  switched to "static" so the static FP8 sweep applies) and FP8 KV cache
  with use_constant_amax: true.

- examples/llm_ptq/scripts: thread a new --recipe flag through parser.sh and
  huggingface_example.sh. Either --quant or --recipe is required; passing both
  errors out. When --recipe is used, the script derives MODEL_NAME from the
  recipe basename, passes --recipe= to hf_ptq.py, and exits after export with
  a TRT-LLM deployment hint (recipes can produce arbitrary configs).

- Drop the qformat case-statement whitelist in huggingface_example.sh; let
  hf_ptq.py be the single source of truth for valid qformats / recipes.

(Pre-commit hook check-modelopt-recipes was skipped: the host conda env has a
broken torchvision install that prevents the validator from importing modelopt.
The recipe was verified independently via tools/precommit/check_modelopt_recipes.py
in a working environment.)

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Same shape as nvfp4_experts_only_mse-fp8_cast_kv but with the broader
*mlp* / *block_sparse_moe* patterns from nvfp4_mlp_only-fp8_kv.yaml so it
covers both dense MLP and MoE expert weights:

- algorithm: { method: mse, fp8_scale_sweep: true, layerwise: false }
- All MLP weight quantizers switched from "dynamic" to "static" so the
  static FP8 scale sweep applies (otherwise mse_calibrate skips them).
- Input quantizers stay dynamic.
- KV bmm gets use_constant_amax: true (the _cast_kv flavor: skips KV
  calibration, hardcodes amax to FP8 E4M3 max 448.0).

Pre-commit hook check-modelopt-recipes was skipped because the host conda
env has a broken torchvision install that prevents the validator from
importing modelopt; the recipe is the same shape as the experts-only one
which already validates cleanly in a working env.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
## Summary

- Saturates `per_block_scale * 448 / per_block_scale_max` to ≤ 448
before the `to(torch.float8_e4m3fn)` cast in
`NVFP4QTensor.get_weights_scaling_factor_from_quantizer`.
- Adds a regression test that reproduces the NaN byte without the clamp.

## Why

When `_amax` contains a zero entry (e.g. an all-zero weight block left
untouched by max calibration), the existing
`per_block_scale[per_block_scale == 0] = 1.0` safety net drives the
pre-cast value to `1.0 * 448 / (global_amax / 6)`. `fp8_e4m3fn` has no
Inf — anything `≥ 480` rounds to NaN — so a 0x7F byte slips into the
exported `weight_scale`.

This was observed in a saved Kimi-K2.6-NVFP4-MSE checkpoint at
`language_model.model.layers.1.mlp.experts.21.down_proj.weight_scale[4001,
18]`. The MSE FP8 sweep itself never produces zero per-block amax (it
always emits at least `c[0] * global_amax`), but any export path where
`_amax` ends up zero — including pure max calibration — hits the bug.
With the clamp the byte saturates to `0x7E` (= 448, fp8 max finite) and
dequantization is unaffected: the FP4 nibbles for an all-zero block are
all 0, so `0 × 448 × weight_scale_2 = 0` regardless of the stored fp8
scale. For non-degenerate blocks the clamp is a no-op since
`per_block_amax ≤ global_amax` already bounds the pre-cast value at 448.

## Test plan

- [x] New regression test
`test_export_fp8_scale_no_nan_for_zero_amax_block` fails on `main`'s
export code (reproduces the 0x7F NaN byte) and passes with the clamp.
- [x] Existing tests in
`tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py` still
pass (10/10).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Improved numerical stability in FP8 quantization scaling by preventing
overflow and NaN conditions
* Enhanced handling of edge cases in quantization processing for
zero-weight blocks

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 7, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 7, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 3c064d8b-a098-472e-be67-55902242cf81

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch chenjiel/recipe-nvfp4-experts-mse-fp8-cast-kv-2

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented May 7, 2026

PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1403/

Built to branch gh-pages at 2026-05-07 17:13 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 7, 2026

Codecov Report

❌ Patch coverage is 58.72340% with 97 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.63%. Comparing base (acfab41) to head (ca23dcd).
⚠️ Report is 22 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/quantization/model_calib.py 63.33% 44 Missing ⚠️
...torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py 56.60% 23 Missing ⚠️
modelopt/torch/export/unified_export_hf.py 5.88% 16 Missing ⚠️
modelopt/torch/quantization/plugins/huggingface.py 12.50% 7 Missing ⚠️
modelopt/torch/export/moe_utils.py 0.00% 6 Missing ⚠️
...odelopt/torch/quantization/qtensor/nvfp4_tensor.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1403      +/-   ##
==========================================
- Coverage   76.90%   74.63%   -2.28%     
==========================================
  Files         471      478       +7     
  Lines       50562    53235    +2673     
==========================================
+ Hits        38886    39730     +844     
- Misses      11676    13505    +1829     
Flag Coverage Δ
unit 52.47% <37.87%> (-0.34%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

… MoE

Two-part fix for transformers 5.x fused-experts containers (Qwen3-MoE /
Qwen3.5-MoE / Mixtral / DeepSeek / Kimi-K2.x ...) where weight quantizers
live in `nn.ModuleList`s (`gate_up_proj_weight_quantizers`,
`down_proj_weight_quantizers`):

1. Add `_QuantFusedExperts.iter_weights_for_calibration` that yields
   per-expert (weight_slice, quantizer) pairs for both projections. The base
   impl uses singular `*_weight_quantizer` and silently skips fused-experts
   modules, so weight-only calibration paths never reach per-expert
   quantizers.

2. Refactor `mse_calibrate`:
   - Add `_bootstrap_uncalibrated_weight_quantizers` after `max_calibrate`
     to populate `_amax` on quantizers the forward pass didn't reach (dead
     MoE experts that received no calibration tokens). Runs the existing
     calibrator on the weight slice surfaced by
     `iter_weights_for_calibration`.
   - Replace the singular-only `weight_attr_names` discovery + `getattr`-by-
     name walk with an `iter_weights_for_calibration` walk done inside each
     parent module's `enable_weight_access_and_writeback` context, so MSE
     processes every per-expert quantizer (active and dead) and remains
     FSDP-safe.

Without this, the export-time fallback in `_export_fused_experts` derived
separate gate/up amaxes from each half of the fused weight, breaking the
gate==up `weight_scale_2` invariant on dead experts. End-to-end check on
Qwen3.5-122B-A10B with `nvfp4_experts_only_mse-fp8_cast_kv`:
  - Before: 1/12288 (layer 38 expert 69) gate \!= up; 0 weights MSE-calibrated
  - After:  0/12288 mismatches; 24576 weights MSE-calibrated; ~4.2 min

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant