Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion modelopt/torch/export/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,19 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
and w_quantizer._amax.dim() >= 1
):
amax = w_quantizer._amax
# Per-block _amax (NVFP4 static) collapses the row axis we want
# to slice on; restore it so dim-0 slicing splits gate/up.
if amax.numel() != fused_total and amax.numel() % fused_total == 0:
amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)
amax_dim0 = amax.shape[0]
if fused_total % amax_dim0 == 0:
slice_start = fused_start * amax_dim0 // fused_total
slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
w_quantizer.amax = amax[slice_start:slice_end].contiguous()
sliced = amax[slice_start:slice_end].contiguous()
# The amax setter refuses shape changes; drop _amax first.
if hasattr(w_quantizer, "_amax"):
delattr(w_quantizer, "_amax")
w_quantizer.amax = sliced
else:
warnings.warn(
f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
Expand Down
15 changes: 15 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,19 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None:
mod.revert_weight_conversion = original


def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None:
"""Force ``do_sample=True`` when generation_config has ``top_k``/``top_p`` set.

Newer transformers reject ``do_sample=False`` mixed with sampling attrs in
``save_pretrained``'s strict validate.
"""
gc = getattr(model, "generation_config", None)
if gc is None:
return
if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None:
gc.do_sample = True

Comment thread
coderabbitai[bot] marked this conversation as resolved.

def export_speculative_decoding(
model: torch.nn.Module,
dtype: torch.dtype | None = None,
Expand Down Expand Up @@ -1228,6 +1241,8 @@ def export_hf_checkpoint(
# modeling_utils does `from core_model_loading import revert_weight_conversion`.
_patches = _patch_revert_weight_conversion()

_sanitize_generation_config_for_save(model)

try:
model.save_pretrained(
export_dir,
Expand Down
205 changes: 142 additions & 63 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
promote_nvfp4_static_quantizers,
quantizer_attr_names,
reduce_amax,
weight_attr_names,
)
from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper

Expand All @@ -64,8 +63,100 @@
"max_calibrate",
"smoothquant",
"svdquant",
"sync_grouped_weight_global_amax",
]


# Sibling groups that share an FP8 scale-of-scales: members feed the same input
# (Q/K/V) or get fused at deployment (gate/up), so divergent global_amax would
# split their FP8 grids.
# Each inner tuple lists attribute names looked up on a common parent module;
# a group only takes effect when at least two siblings carry calibrated
# NVFP4-static weight quantizers.
_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = (
    ("q_proj", "k_proj", "v_proj"),
    ("gate_proj", "up_proj"),  # Llama/Qwen/Mistral
    ("w1", "w3"),  # Mixtral
)


def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
    """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers."""
    wq_attr = quantizer_attr_names("weight").weight_quantizer

    def _has_calibrated_nvfp4(child) -> bool:
        # Keep only siblings whose weight quantizer is enabled, NVFP4-static,
        # and already holds a calibrated _amax.
        wq = getattr(child, wq_attr, None) if child is not None else None
        return (
            isinstance(wq, TensorQuantizer)
            and not wq._disabled
            and wq.is_nvfp4_static
            and getattr(wq, "_amax", None) is not None
        )

    groups: list[list[nn.Module]] = []
    for parent in model.modules():
        for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS:
            members = [
                child
                for child in (getattr(parent, name, None) for name in sibling_names)
                if _has_calibrated_nvfp4(child)
            ]
            # A singleton is not a group: unification needs >= 2 live members.
            if len(members) >= 2:
                groups.append(members)
    return groups


@torch.no_grad()
def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
    """Populate ``_amax`` from weights for quantizers the forward pass didn't reach.

    Dead MoE experts that received no tokens are otherwise skipped by
    ``mse_calibrate``, leaving export to derive separate per-half amax for
    gate/up and break the gate==up ``weight_scale_2`` invariant.

    Returns:
        Number of quantizers whose ``_amax`` was bootstrapped.
    """
    n = 0
    # Resolve the name->module map once; enable_weight_access_and_writeback
    # needs it to materialize sharded/offloaded weights (FSDP/HF-TP) before
    # we read them — calling q(weight) on a bare shard would calibrate only
    # the local slice. Call pattern matches its use in mse_calibrate.
    name_to_module = dict(model.named_modules())
    for module in model.modules():
        if not isinstance(module, QuantModule):
            continue
        with enable_weight_access_and_writeback(module, model, name_to_module):
            for weight, q in module.iter_weights_for_calibration():
                if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
                    continue
                if q._calibrator is None:
                    continue
                # Already calibrated (non-zero amax): leave it untouched.
                if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0):
                    continue
                # Max-style single-pass collection over the weight itself.
                q.disable_quant()
                q.enable_calib()
                q(weight)
                if q._calibrator.compute_amax() is not None:
                    q.load_calib_amax()
                q.enable_quant()
                q.disable_calib()
                # Free calibrator buffers immediately to avoid accumulation.
                if hasattr(q._calibrator, "reset"):
                    q._calibrator.reset()
                n += 1
    return n
Comment on lines +102 to +131
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Run dead-expert bootstrap under enable_weight_access_and_writeback().

This helper reads weight slices and calibrates them before entering any weight-access context. On FSDP/HF-TP/offloaded modules that can either calibrate only the local shard or hit an access failure that gets swallowed by the blanket except, leaving the “dead” expert unbootstrapped. That recreates the missing-_amax path this PR is trying to eliminate in distributed/export flows.

Suggested adjustment
 `@torch.no_grad`()
 def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
     """Run a max-style amax collection on weight quantizers whose ``_amax`` is missing."""
     n = 0
+    name_to_module = dict(model.named_modules())
     for module in model.modules():
         if not isinstance(module, QuantModule):
             continue
-        try:
-            pairs = list(module.iter_weights_for_calibration())
-        except Exception:
-            continue
-        for weight, q in pairs:
-            if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
-                continue
-            if q._calibrator is None:
-                continue
-            if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
-                continue
-            q.disable_quant()
-            q.enable_calib()
-            q(weight)
-            if q._calibrator.compute_amax() is not None:
-                q.load_calib_amax()
-            q.enable_quant()
-            q.disable_calib()
-            if hasattr(q._calibrator, "reset"):
-                q._calibrator.reset()
-            n += 1
+        with enable_weight_access_and_writeback(module, model, name_to_module):
+            for weight, q in module.iter_weights_for_calibration():
+                if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
+                    continue
+                if q._calibrator is None:
+                    continue
+                if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
+                    continue
+                q.disable_quant()
+                q.enable_calib()
+                q(weight)
+                if q._calibrator.compute_amax() is not None:
+                    q.load_calib_amax()
+                q.enable_quant()
+                q.disable_calib()
+                if hasattr(q._calibrator, "reset"):
+                    q._calibrator.reset()
+                n += 1
     return n
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/model_calib.py` around lines 121 - 162, The
bootstrap loop in _bootstrap_uncalibrated_weight_quantizers must run weight
reads inside enable_weight_access_and_writeback() so FSDP/HF-TP/offload sharded
modules perform proper local access/writeback instead of triggering an access
failure swallowed by the blanket except; wrap the per-module calibration work
(the call to module.iter_weights_for_calibration() and the q(weight) calibration
call inside the loop) with with enable_weight_access_and_writeback(module):,
remove or narrow the broad try/except that currently hides access errors so
genuine access failures surface, and keep the rest of the logic
(q.disable_quant(), q.enable_calib(), q(weight), q.load_calib_amax(),
q.enable_quant(), q.disable_calib(), q._calibrator.reset()) unchanged.



@torch.no_grad()
def sync_grouped_weight_global_amax(model: nn.Module) -> int:
    """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.

    Reuses ``preprocess_linear_fusion`` (which performs the same unification at
    export time) to keep the FP8 scale-of-scales consistent across siblings
    during MSE / local-Hessian search. Must run after ``max_calibrate``.

    Returns:
        Number of sibling groups processed.
    """
    # Local import: quant_utils imports enable_stats_collection /
    # finish_stats_collection / svd from this module, so a top-level import
    # would create a circular dependency.
    from modelopt.torch.export.quant_utils import preprocess_linear_fusion

    wq_attr = quantizer_attr_names("weight").weight_quantizer
    groups = _collect_grouped_linears(model)
    for group in groups:
        for child in group:
            quantizer = getattr(child, wq_attr)
            # Promote plain TensorQuantizers in place first so every group
            # member carries a global_amax for preprocess_linear_fusion.
            if not isinstance(quantizer, NVFP4StaticQuantizer):
                global_amax = reduce_amax(quantizer._amax, axis=None)
                NVFP4StaticQuantizer.from_tensor_quantizer(quantizer, global_amax=global_amax)
        preprocess_linear_fusion(group)
    return len(groups)


CalibratorFactory: TypeAlias = Callable[
[torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator
]
Expand Down Expand Up @@ -346,32 +437,25 @@ def mse_calibrate(
See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
details on the remaining arguments.
"""
# Step 1: First get initial amax using max calibration
# Step 1: max calibration; then populate _amax for dead experts so step 3
# doesn't skip them, and unify NVFP4 global_amax across Q/K/V and gate/up
# siblings so MSE searches against a consistent FP8 grid.
max_calibrate(model, forward_loop, distributed_sync)
_bootstrap_uncalibrated_weight_quantizers(model)
sync_grouped_weight_global_amax(model)

# Step 2: Replace calibrators with MseCalibrator for enabled quantizers
# and identify weight quantizers
weight_quantizers = []
seen_modules = set()

# Step 2: replace calibrators with MseCalibrator for enabled quantizers.
for name, module in list(model.named_modules()):
if isinstance(module, TensorQuantizer) and not module._disabled:
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
# Get the initial amax from max calibration
initial_amax = module._amax.clone().detach()
is_nvfp4_static = module.is_nvfp4_static

is_nvfp4_static = (
module.is_static_block_quant
and module._num_bits == (2, 1)
and module._block_sizes is not None
and module._block_sizes.get("scale_bits") == (4, 3)
)

if is_nvfp4_static:
# Compute and set global_amax
# sync_grouped_weight_global_amax may have already promoted +
# unified global_amax across the sibling group; only promote
# standalone (non-grouped) NVFP4-static quantizers here.
if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer):
global_amax = reduce_amax(initial_amax, axis=None)

# Convert to NVFP4StaticQuantizer in-place
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)

if fp8_scale_sweep:
Expand Down Expand Up @@ -412,52 +496,50 @@ def mse_calibrate(
quant_func=partial(_mse_quant_func, quantizer=module),
)

# Identify weight quantizers by checking if they have corresponding weight parameters
# Step 3: calibrate weight quantizers via iter_weights_for_calibration.
# The fused-experts override yields one pair per expert per projection, so
# every per-expert quantizer is MSE-calibrated (not just routed ones).
name_to_module = dict(model.named_modules())
seen_modules: set[int] = set()
pbar = tqdm(desc="MSE weight calibration")
n_calibrated = 0
for parent_module in name_to_module.values():
if parent_module in seen_modules:
if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule):
continue
for weight_name in weight_attr_names(parent_module):
weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
if getattr(weight_quantizer, "_calibrator", None) is not None:
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
seen_modules.add(parent_module)

# Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
# This prevents massive memory accumulation seen in large models
for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
tqdm(weight_quantizers, desc="MSE weight calibration")
):
# Enable calibration mode for the weight quantizer
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()
seen_modules.add(id(parent_module))
with enable_weight_access_and_writeback(parent_module, model, name_to_module):
weight = getattr(parent_module, weight_name)
weight_quantizer(weight)
for weight, weight_quantizer in parent_module.iter_weights_for_calibration():
if not (
isinstance(weight_quantizer, TensorQuantizer)
and weight_quantizer.is_enabled
and getattr(weight_quantizer, "_calibrator", None) is not None
):
continue
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()
weight_quantizer(weight)

# IMMEDIATELY compute amax and reset calibrator to free memory
cal = getattr(weight_quantizer, "_calibrator", None)
if cal is not None and cal.compute_amax() is not None:
weight_quantizer.load_calib_amax()
cal = weight_quantizer._calibrator
if cal.compute_amax() is not None:
weight_quantizer.load_calib_amax()

weight_quantizer.enable_quant()
weight_quantizer.disable_calib()
weight_quantizer.enable_quant()
weight_quantizer.disable_calib()

# Synchronize ALL CUDA devices before resetting to ensure all async operations complete
# This is critical for multi-GPU setups where tensors may be on different devices
if torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
if torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))

if cal is not None and hasattr(cal, "reset"):
cal.reset()
if hasattr(cal, "reset"):
cal.reset()

if (idx + 1) % 10 == 0 and torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
torch.cuda.empty_cache()
pbar.update(1)
n_calibrated += 1
if n_calibrated % 10 == 0 and torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
torch.cuda.empty_cache()
pbar.close()

if torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
Expand Down Expand Up @@ -612,6 +694,8 @@ def forward(self, input, *args, **kwargs):
print_rank_0("local_hessian: Running max calibration for all quantizers...")
max_calibrate(model, forward_loop, distributed_sync)

sync_grouped_weight_global_amax(model)

# Setup helpers for all quantized linear modules
name_to_module = dict(model.named_modules())
weight_quantizers_info = []
Expand Down Expand Up @@ -666,14 +750,9 @@ def quant_func(x, amax, quantizer=weight_quantizer):

return xq

is_nvfp4_static = (
weight_quantizer.is_static_block_quant
and weight_quantizer._num_bits == (2, 1)
and weight_quantizer._block_sizes is not None
and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
)
is_nvfp4_static = weight_quantizer.is_nvfp4_static

if is_nvfp4_static:
if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer):
global_amax = reduce_amax(initial_amax, axis=None)
NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax)

Expand Down
10 changes: 10 additions & 0 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,16 @@ def is_mx_format(self):
and self.block_sizes.get("scale_bits", None) == (8, 0)
)

@property
def is_nvfp4_static(self):
"""True for E2M1 weights + E4M3 per-block scales in static layout (format-only check)."""
return (
self.is_static_block_quant
and self._num_bits == (2, 1)
and self._block_sizes is not None
and self._block_sizes.get("scale_bits") == (4, 3)
)

def is_mxfp(self, bits):
"""Check if is MXFP4/MXFP6/MXFP8."""
if bits == 4:
Expand Down
18 changes: 18 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,24 @@ def forward(self, *args, **kwargs):
self._down_proj_linear = False
return super().forward(*args, **kwargs)

def iter_weights_for_calibration(self):
"""Yield ``(weight_slice, quantizer)`` per-expert pairs.

The base impl uses singular ``*_weight_quantizer`` and skips fused-
experts modules, so weight-only calibration never reaches per-expert
quantizers without this override.
"""
for weight_name, quantizers_name in (
("gate_up_proj", "gate_up_proj_weight_quantizers"),
("down_proj", "down_proj_weight_quantizers"),
):
weight = getattr(self, weight_name, None)
quantizers = getattr(self, quantizers_name, None)
if weight is None or quantizers is None:
continue
for idx, q in enumerate(quantizers):
yield weight[idx], q

def fold_weight(self, keep_attrs: bool = False):
"""Fold per-expert weight quantizers into the fused 3-D weights.

Expand Down
8 changes: 1 addition & 7 deletions modelopt/torch/quantization/utils/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,13 +957,7 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int:
for _name, module in list(model.named_modules()):
if isinstance(module, TensorQuantizer) and not module._disabled:
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
is_nvfp4_static = (
module.is_static_block_quant
and module._num_bits == (2, 1)
and module._block_sizes is not None
and module._block_sizes.get("scale_bits") == (4, 3)
)
if is_nvfp4_static:
if module.is_nvfp4_static:
initial_amax = module._amax.clone().detach()
global_amax = reduce_amax(initial_amax, axis=None)
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
Expand Down
Loading