7 changes: 6 additions & 1 deletion examples/llm_ptq/hf_ptq.py
@@ -797,6 +797,11 @@ def pre_quantize(
preview_input_ids = next(iter(calib_dataloader))[
"input_features" if model_type == "whisper" else "input_ids"
][0:1]
# Strip leading padding tokens so the preview input shows real content
if model_type != "whisper" and tokenizer is not None and tokenizer.pad_token_id is not None:
first_non_pad = (preview_input_ids[0] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
if first_non_pad.numel() > 0:
preview_input_ids = preview_input_ids[:, first_non_pad[0] :]

# Generate preview before quantization
if args.skip_generate:
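For illustration, a tiny standalone run of the padding-strip logic added above, with made-up token values and an assumed pad_token_id of 0 (not part of the diff):

import torch

pad_token_id = 0
preview_input_ids = torch.tensor([[0, 0, 0, 17, 23, 42]])  # left-padded example row
first_non_pad = (preview_input_ids[0] != pad_token_id).nonzero(as_tuple=True)[0]
if first_non_pad.numel() > 0:
    preview_input_ids = preview_input_ids[:, first_non_pad[0] :]
print(preview_input_ids)  # tensor([[17, 23, 42]])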
@@ -897,7 +902,7 @@ def input_decode(input_ids):
if processor is not None and isinstance(processor, WhisperProcessor):
return first_text_speech_dataset
elif tokenizer is not None:
return tokenizer.batch_decode(input_ids)
return tokenizer.batch_decode(input_ids, skip_special_tokens=True)
else:
raise ValueError("The processor or tokenizer must be set")

85 changes: 78 additions & 7 deletions modelopt/torch/export/moe_utils.py
@@ -76,8 +76,12 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
)
i_quantizer = gate_up_input_q if is_gate_up else down_input_q

# gate/up share a weight quantizer — clone so each gets independent amax.
w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src
# gate/up share a quantizer — deepcopy so gate_proj and up_proj get
# independent quantizers that can hold different amax slices.
if is_gate_up:
Review comment (Contributor): nit: is_gate_or_up_proj?

w_quantizer = copy.deepcopy(w_quantizer_src)
else:
w_quantizer = w_quantizer_src

# For per-channel amax (dim >= 1), proportionally slice dim-0
# to match the split weight.
@@ -91,7 +95,7 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
if fused_total % amax_dim0 == 0:
slice_start = fused_start * amax_dim0 // fused_total
slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
w_quantizer.amax = amax[slice_start:slice_end].contiguous()
w_quantizer._amax = amax[slice_start:slice_end].contiguous()
Review comment (Contributor): why are you checking _amax now instead of amax?
else:
warnings.warn(
f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
Expand All @@ -100,20 +104,73 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
stacklevel=2,
)

# If the weight quantizer was never calibrated, compute amax from weights.
# Patch invalid per-block amax entries (NaN/inf/negative/zero/too-small/too-large)
# with weight-derived fallback values.

Review comment (Contributor): We should throw an error if there are experts with uncalibrated amax and suggest rerunning PTQ with more calibration samples/seq length. This is what we do in the MCore PTQ path -- because null amax in MCore causes a deadlock in distributed sync. For HF PTQ to have parity with MCore PTQ we should do the same (even though there is no dist sync in HF PTQ).

Review comment (Contributor): Patching is risky, as the patched amax could break the checkpoint. For non-null invalid amax, a warning/error should also be thrown.

Review comment (Collaborator): Should this be in the MSE calibrator?
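A minimal sketch of the hard-error behavior the first comment asks for; the variable names (w_quantizer, idx, proj_name) come from the surrounding diff, but the exact condition, message, and placement are assumptions rather than part of this PR:

# Hypothetical strict alternative to patching: fail fast when an expert's
# weight quantizer ends calibration without a usable amax.
amax = getattr(w_quantizer, "_amax", None)
if amax is None or not torch.isfinite(amax).all() or not (amax > 0).any():
    raise RuntimeError(
        f"Expert {idx} {proj_name}: weight quantizer has no valid amax after calibration. "
        "Rerun PTQ with more calibration samples or a longer sequence length so that "
        "all experts receive calibration data."
    )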
min_valid_amax = 2e-3 # floor matches FP8 E4M3FN minimum subnormal (2^-9 ≈ 0.00195)
max_valid_amax = 1e6
if (
hasattr(w_quantizer, "_amax")
and w_quantizer._amax is not None
and w_quantizer._amax.numel() > 1
and (getattr(w_quantizer, "block_sizes", None) or {}).get(-1) is not None
):
amax_cpu = w_quantizer._amax
invalid_mask = ~(
torch.isfinite(amax_cpu)
& (amax_cpu >= min_valid_amax)
& (amax_cpu <= max_valid_amax)
)
if invalid_mask.any():
_block_size = (getattr(w_quantizer, "block_sizes", None) or {}).get(-1, 16)
per_block_fallback = (
weight_slice.detach()
.reshape(-1, _block_size)
.abs()
.amax(dim=1, keepdim=True)
.cpu()
.float()
.clamp(min=2e-3)
.reshape(amax_cpu.shape)
)
amax_cpu[invalid_mask] = per_block_fallback[invalid_mask]
w_quantizer._amax = amax_cpu

# For uncalibrated experts (amax missing or invalid scalar), fall back to
# per-block amax from weights so the static export path can reshape it correctly.
# Only applies to per-block (NVFP4) quantizers — non-block quantizers have
# no block_sizes and should not be routed to the static NVFP4 export path.
if (
hasattr(w_quantizer, "is_enabled")
and w_quantizer.is_enabled
and (getattr(w_quantizer, "block_sizes", None) or {}).get(-1) is not None
and (
not hasattr(w_quantizer, "_amax")
or w_quantizer._amax is None
or torch.all(w_quantizer._amax == 0)
or (
w_quantizer._amax.numel() == 1
and not (
torch.isfinite(w_quantizer._amax)
and w_quantizer._amax >= min_valid_amax
and w_quantizer._amax <= max_valid_amax
)
)
)
):
w_quantizer.amax = weight_slice.abs().amax().to(torch.float32)
_block_size = (getattr(w_quantizer, "block_sizes", None) or {}).get(-1, 16)
Review comment (Contributor): the code in this file is too hard to read, too many hardcoded numbers everywhere.

fallback_per_block = (
weight_slice.detach()
.reshape(-1, _block_size)
.abs()
.amax(dim=1, keepdim=True)
.cpu()
.float()
.clamp(min=2e-3)
.reshape(*weight_slice.shape[:-1], weight_slice.shape[-1] // _block_size)
)
w_quantizer._amax = fallback_per_block
warnings.warn(
f"Expert {idx} {proj_name} weight quantizer was not calibrated "
f"(amax missing or zero). Using weight-derived amax as fallback. "
f"(amax missing or zero). Using weight-derived per-block amax as fallback. "
f"Consider using more calibration data to activate all experts.",
stacklevel=2,
)
@@ -123,6 +180,20 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
wrapper.weight_quantizer = w_quantizer
wrapper.input_quantizer = i_quantizer

# Set global_amax to route to the static NVFP4 export path (reads per-block _amax).
# Always recompute from the current (possibly patched) _amax — a stale zero
# global_amax causes division-by-zero in the per-block scale formula.
# Guard: only per-block (NVFP4) quantizers have block_sizes; skip for others.
wq = wrapper.weight_quantizer
if (
hasattr(wq, "_amax")
and wq._amax is not None
and wq._amax.numel() > 1
and (getattr(wq, "block_sizes", None) or {}).get(-1) is not None
):
wq._amax = wq._amax.to(weight_slice.device)
wq.global_amax = wq._amax.float().amax().clamp(min=2e-3)

_export_quantized_weight(wrapper, dtype)

proj = nn.Module()
67 changes: 62 additions & 5 deletions modelopt/torch/quantization/model_calib.py
@@ -20,7 +20,7 @@
import warnings
from collections.abc import Callable
from functools import partial
from typing import TypeAlias
from typing import Any, TypeAlias

import torch
import torch.distributed as dist
@@ -351,7 +351,7 @@ def mse_calibrate(

# Step 2: Replace calibrators with MseCalibrator for enabled quantizers
# and identify weight quantizers
weight_quantizers = []
weight_quantizers: list[tuple[Any, Any, TensorQuantizer]] = []
seen_modules = set()

for name, module in list(model.named_modules()):
@@ -410,7 +410,12 @@ def mse_calibrate(
quant_func=partial(_mse_quant_func, quantizer=module),
)

# Identify weight quantizers by checking if they have corresponding weight parameters
# Collect weight quantizers (standard + fused-experts per-expert lists).
Review comment (Contributor): what is a fused-experts per-expert list? That seems contradictory -- how can it be fused and per-expert at the same time?
try:
from modelopt.torch.quantization.plugins.huggingface import _QuantFusedExperts as _qfe_cls
except ImportError:
_qfe_cls = None # type: ignore[misc]

name_to_module = dict(model.named_modules())
for parent_module in name_to_module.values():
if parent_module in seen_modules:
@@ -421,8 +426,56 @@
if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
if getattr(weight_quantizer, "_calibrator", None) is not None:
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
# Enqueue per-expert quantizers from {param}_weight_quantizers ModuleLists.
if _qfe_cls is not None and isinstance(parent_module, _qfe_cls):
Review comment (Contributor): move this to a helper function.
for param_name, param in parent_module.named_parameters(recurse=False):
qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
if not isinstance(qlist, nn.ModuleList):
continue
if len(qlist) != param.shape[0]:
warnings.warn(
f"Skipping {param_name}_weight_quantizers: list length {len(qlist)} "
f"does not match parameter leading dimension {param.shape[0]}. "
"This may indicate a misconfigured fused-experts module.",
stacklevel=2,
)
continue
for expert_idx, wq in enumerate(qlist):
if isinstance(wq, TensorQuantizer) and wq.is_enabled:
if getattr(wq, "_calibrator", None) is not None:
weight_quantizers.append((parent_module, (param_name, expert_idx), wq))

Review comment (Contributor, on lines +429 to +447): can we have a helper method get_weight_quantizers(module) which can support both MoE and regular weight quantizers? This will help avoid the code branching here.
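A rough sketch of the helper shape this comment suggests; the name get_weight_quantizers comes from the comment, the "{param}_weight_quantizers" lookup mirrors the surrounding diff, and the single "weight_quantizer" attribute and the return format are assumptions:

def get_weight_quantizers(module: nn.Module) -> list[tuple[str | tuple[str, int], TensorQuantizer]]:
    """Collect enabled, calibrator-backed weight quantizers from one module.

    Returns (key, quantizer) pairs where key is a parameter name for regular
    modules, or (param_name, expert_idx) for fused-experts modules that keep
    one quantizer per expert in a "{param}_weight_quantizers" ModuleList.
    """
    found = []
    # Regular layout: a single TensorQuantizer attached as `weight_quantizer` (assumed name).
    wq = getattr(module, "weight_quantizer", None)
    if isinstance(wq, TensorQuantizer) and wq.is_enabled and getattr(wq, "_calibrator", None) is not None:
        found.append(("weight", wq))
    # Fused-experts layout: one quantizer per expert along the leading dim of the parameter.
    for param_name, param in module.named_parameters(recurse=False):
        qlist = getattr(module, f"{param_name}_weight_quantizers", None)
        if isinstance(qlist, nn.ModuleList) and len(qlist) == param.shape[0]:
            for expert_idx, q in enumerate(qlist):
                if isinstance(q, TensorQuantizer) and q.is_enabled and getattr(q, "_calibrator", None) is not None:
                    found.append(((param_name, expert_idx), q))
    return found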
seen_modules.add(parent_module)

# Warn about enabled weight quantizers that weren't scheduled for MSE calibration.
picked_ids = {id(wq) for _, _, wq in weight_quantizers}

def _is_active_unpicked(q: Any) -> bool:
return (
isinstance(q, TensorQuantizer)
and q.is_enabled
and getattr(q, "_calibrator", None) is not None
and id(q) not in picked_ids
)

missed: list[str] = []
for mod_name, module in name_to_module.items():
for attr_name, attr in module._modules.items():
if isinstance(attr, TensorQuantizer) and attr_name.endswith("weight_quantizer"):
if _is_active_unpicked(attr):
missed.append(f"{mod_name}.{attr_name}")
elif isinstance(attr, nn.ModuleList) and attr_name.endswith("_weight_quantizers"):
for i, wq in enumerate(attr):
if _is_active_unpicked(wq):
missed.append(f"{mod_name}.{attr_name}[{i}]")
if missed:
warnings.warn(
f"MSE weight calibration: {len(missed)} weight quantizer(s) are enabled but were "
f"not scheduled for calibration and will retain max-calibration amax values. "
f"First {min(5, len(missed))}: {missed[:5]}",
stacklevel=2,
)

# Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
# This prevents massive memory accumulation seen in large models
for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
Expand All @@ -432,7 +485,11 @@ def mse_calibrate(
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()
with enable_weight_access_and_writeback(parent_module, model, name_to_module):
weight = getattr(parent_module, weight_name)
if isinstance(weight_name, tuple):
param_name, expert_idx = weight_name
weight = getattr(parent_module, param_name)[expert_idx]
else:
weight = getattr(parent_module, weight_name)
weight_quantizer(weight)

# IMMEDIATELY compute amax and reset calibrator to free memory
Expand Down Expand Up @@ -778,7 +835,7 @@ def finish_stats_collection(model: nn.Module, method: str | None = None, **kwarg

cal = getattr(module, "_calibrator", None)
if cal and not getattr(module, "_dynamic", False):
if method in {"entropy"}:
if method == "entropy":
Review comment (Contributor): why is this needed?
if cal.compute_amax(method) is not None:
module.load_calib_amax("entropy", **kwargs)
elif cal.compute_amax(**kwargs) is not None:
1 change: 1 addition & 0 deletions modelopt/torch/quantization/model_quant.py
@@ -595,6 +595,7 @@ def print_quant_summary(model: nn.Module, output_dir: str | None = None):
lines.append(f"{len(lines)} TensorQuantizers found in model")

if output_dir:
os.makedirs(output_dir, exist_ok=True)
path = os.path.join(output_dir, ".quant_summary.txt")
with open(path, "w", encoding="utf-8") as f:
f.write("\n".join(lines) + "\n")
4 changes: 2 additions & 2 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -1112,7 +1112,7 @@ def forward(self, inputs):

return outputs

def _short_amax(self, fmt=".4f"):
def _short_amax(self, fmt=".2e"):
"""Short description of amax.

Returns:
@@ -1130,7 +1130,7 @@ def _short_amax(self, fmt=".4f"):
return "meta"
return self._short_tensor(self._amax, fmt)

def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"):
def _short_tensor(self, tensor: torch.Tensor, fmt=".2e"):
"""Short description of tensor."""
if tensor.numel() == 1:
return f"{tensor.item():{fmt}}"
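A quick plain-Python illustration of why the switch from ".4f" to ".2e" matters for the tiny amax magnitudes produced by FP8/NVFP4 scales (values are made up, not from the diff):

amax = 0.001953125  # 2**-9, the FP8 E4M3FN minimum-subnormal magnitude used elsewhere in this PR
print(f"{amax:.4f}")  # '0.0020' -- rounds away what distinguishes one small scale from another
print(f"{amax:.2e}")  # '1.95e-03' -- keeps the order of magnitude visible in the repr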
13 changes: 11 additions & 2 deletions modelopt/torch/quantization/qtensor/nvfp4_tensor.py
@@ -124,8 +124,11 @@ def get_weights_scaling_factor_from_quantizer(

# Quantize scales to FP8
if not keep_high_precision:
per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to(
torch.float8_e4m3fn
fp8_e4m3fn_min = 2**-9 # 0.001953125 — smallest positive subnormal
per_block_scale = (
(per_block_scale * 448.0 / per_block_scale_max)
.clamp(min=fp8_e4m3fn_min)
.to(torch.float8_e4m3fn)
)
return per_block_scale, weights_scaling_factor_2
else:

Review comment (Contributor), suggested change: replace (per_block_scale * 448.0 / per_block_scale_max) with (per_block_scale.float() * 448.0 / per_block_scale_max).

Review comment (jenchen13, Contributor, May 7, 2026): we also need a max clamp in line 130 to 448. I saw some NaNs in exported MSE no-sweep checkpoints due to overflow, e.g. during PTQ: weight_scale dtype=torch.float8_e4m3fn min=nan max=nan mean=nan. We also need a unit test for overflow.
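A small sketch combining the two fixes the comments above ask for (the .float() upcast and a ceiling at the FP8 E4M3FN maximum of 448.0); the constant names are assumptions and this is not code from the PR:

fp8_e4m3fn_min = 2**-9   # smallest positive E4M3FN subnormal, ~0.00195
fp8_e4m3fn_max = 448.0   # largest finite E4M3FN value
per_block_scale = (
    (per_block_scale.float() * 448.0 / per_block_scale_max)
    .clamp(min=fp8_e4m3fn_min, max=fp8_e4m3fn_max)
    .to(torch.float8_e4m3fn)
)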
@@ -167,6 +170,12 @@ def get_weights_scaling_factor(
per_block_scale[per_block_scale == 0] = 1.0
# Convert to torch.float8_e4m3fn
if not keep_high_precision:
# Clamp to the minimum positive FP8 E4M3FN subnormal (~0.00195 = 2^-9) before
# casting. Without this, blocks whose scale falls below the FP8 representable
# range silently underflow to 0, causing those blocks to produce zero output at
# inference even when the weights are non-trivial.
fp8_e4m3fn_min = 2**-9 # 0.001953125 — smallest positive subnormal
per_block_scale = per_block_scale.clamp(min=fp8_e4m3fn_min)
per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
return per_block_scale, weights_scaling_factor_2

Review comment (realAsma, Contributor, May 5, 2026, on lines +173 to 179): Can we create a helper method which does the FP8 quantization of per_tensor scale and use that here and here: https://github.com/NVIDIA/Model-Optimizer/pull/1382/changes#r3191334011
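One possible shape for the helper realAsma suggests, factoring the clamp-and-cast into a single function shared by both call sites; the function name and constant names are assumptions, not part of the PR:

FP8_E4M3FN_MIN = 2**-9  # smallest positive subnormal, ~0.00195
FP8_E4M3FN_MAX = 448.0  # largest finite E4M3FN value


def quantize_scale_to_fp8(per_block_scale: torch.Tensor) -> torch.Tensor:
    """Clamp a scale tensor into the representable FP8 E4M3FN range and cast it,
    so underflow-to-zero and overflow-to-NaN are handled the same way at every call site."""
    return (
        per_block_scale.float()
        .clamp(min=FP8_E4M3FN_MIN, max=FP8_E4M3FN_MAX)
        .to(torch.float8_e4m3fn)
    )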
