fixes for fused moe (qwen3.6, GLM5.1) + MSE calibration #1382
base: main
@@ -76,8 +76,12 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             )
             i_quantizer = gate_up_input_q if is_gate_up else down_input_q

-            # gate/up share a weight quantizer — clone so each gets independent amax.
-            w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src
+            # gate/up share a quantizer — deepcopy so gate_proj and up_proj get
+            # independent quantizers that can hold different amax slices.
+            if is_gate_up:
+                w_quantizer = copy.deepcopy(w_quantizer_src)
+            else:
+                w_quantizer = w_quantizer_src
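A minimal sketch of why the deepcopy matters, using a hypothetical stand-in class rather than the real TensorQuantizer: without the copy, gate and up alias the same quantizer object, so the second amax assignment silently overwrites the first.

import copy

class FakeQuantizer:  # hypothetical stand-in for TensorQuantizer
    def __init__(self):
        self._amax = None

shared = FakeQuantizer()
gate_q = up_q = shared                       # aliased: one object behind two names
gate_q._amax, up_q._amax = 1.0, 2.0
print(shared._amax)                          # 2.0, the gate value was clobbered

gate_q, up_q = copy.deepcopy(shared), copy.deepcopy(shared)
gate_q._amax, up_q._amax = 1.0, 2.0
print(gate_q._amax, up_q._amax)              # 1.0 2.0, independent copies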
             # For per-channel amax (dim >= 1), proportionally slice dim-0
             # to match the split weight.
@@ -91,7 +95,7 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             if fused_total % amax_dim0 == 0:
                 slice_start = fused_start * amax_dim0 // fused_total
                 slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
-                w_quantizer.amax = amax[slice_start:slice_end].contiguous()
+                w_quantizer._amax = amax[slice_start:slice_end].contiguous()
Contributor: why are you checking …
             else:
                 warnings.warn(
                     f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
@@ -100,20 +104,73 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
                     stacklevel=2,
                 )

-            # If the weight quantizer was never calibrated, compute amax from weights.
+            # Patch invalid per-block amax entries (NaN/inf/negative/zero/too-small/too-large)
Contributor: We should throw an error if there are experts with uncalibrated amax and suggest rerunning PTQ with more calibration samples/seq length. This is what we do in the MCore PTQ path, because null amax in MCore causes a deadlock in distributed sync. For HF PTQ to have parity with MCore PTQ we should do the same (even though there is no dist sync in HF PTQ).

Contributor: Patching is risky, as the patched amax could break the checkpoint. For non-null invalid amax, a warning/error should also be thrown.

Collaborator: Should this be in the MSE calibrator?
+            # with weight-derived fallback values.
+            min_valid_amax = 2e-3  # floor matches FP8 E4M3FN minimum subnormal (2^-9 ≈ 0.00195)
+            max_valid_amax = 1e6
+            if (
+                hasattr(w_quantizer, "_amax")
+                and w_quantizer._amax is not None
+                and w_quantizer._amax.numel() > 1
+                and (getattr(w_quantizer, "block_sizes", None) or {}).get(-1) is not None
+            ):
+                amax_cpu = w_quantizer._amax
+                invalid_mask = ~(
+                    torch.isfinite(amax_cpu)
+                    & (amax_cpu >= min_valid_amax)
+                    & (amax_cpu <= max_valid_amax)
+                )
+                if invalid_mask.any():
+                    _block_size = (getattr(w_quantizer, "block_sizes", None) or {}).get(-1, 16)
+                    per_block_fallback = (
+                        weight_slice.detach()
+                        .reshape(-1, _block_size)
+                        .abs()
+                        .amax(dim=1, keepdim=True)
+                        .cpu()
+                        .float()
+                        .clamp(min=2e-3)
+                        .reshape(amax_cpu.shape)
+                    )
+                    amax_cpu[invalid_mask] = per_block_fallback[invalid_mask]
+                    w_quantizer._amax = amax_cpu
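A quick demonstration of the validity mask used above, with the same hard-coded bounds (2e-3 and 1e6):

import torch

amax = torch.tensor([float("nan"), float("inf"), -1.0, 0.0, 1e-8, 0.5, 2e7])
min_valid_amax, max_valid_amax = 2e-3, 1e6
invalid_mask = ~(torch.isfinite(amax) & (amax >= min_valid_amax) & (amax <= max_valid_amax))
print(invalid_mask)  # every entry except 0.5 is flagged as invalid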
+            # For uncalibrated experts (amax missing or invalid scalar), fall back to
+            # per-block amax from weights so the static export path can reshape it correctly.
+            # Only applies to per-block (NVFP4) quantizers — non-block quantizers have
+            # no block_sizes and should not be routed to the static NVFP4 export path.
+            if (
+                hasattr(w_quantizer, "is_enabled")
+                and w_quantizer.is_enabled
+                and (getattr(w_quantizer, "block_sizes", None) or {}).get(-1) is not None
+                and (
+                    not hasattr(w_quantizer, "_amax")
+                    or w_quantizer._amax is None
+                    or torch.all(w_quantizer._amax == 0)
+                    or (
+                        w_quantizer._amax.numel() == 1
+                        and not (
+                            torch.isfinite(w_quantizer._amax)
+                            and w_quantizer._amax >= min_valid_amax
+                            and w_quantizer._amax <= max_valid_amax
+                        )
+                    )
+                )
+            ):
-                w_quantizer.amax = weight_slice.abs().amax().to(torch.float32)
+                _block_size = (getattr(w_quantizer, "block_sizes", None) or {}).get(-1, 16)
+                fallback_per_block = (
Contributor: the code in this file is too hard to read, too many hardcoded numbers everywhere.
+                    weight_slice.detach()
+                    .reshape(-1, _block_size)
+                    .abs()
+                    .amax(dim=1, keepdim=True)
+                    .cpu()
+                    .float()
+                    .clamp(min=2e-3)
+                    .reshape(*weight_slice.shape[:-1], weight_slice.shape[-1] // _block_size)
+                )
+                w_quantizer._amax = fallback_per_block
                 warnings.warn(
                     f"Expert {idx} {proj_name} weight quantizer was not calibrated "
-                    f"(amax missing or zero). Using weight-derived amax as fallback. "
+                    f"(amax missing or zero). Using weight-derived per-block amax as fallback. "
                     f"Consider using more calibration data to activate all experts.",
                     stacklevel=2,
                 )
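A minimal sketch of the weight-derived per-block fallback above, assuming the default block size of 16 along the last dimension (the real value comes from the quantizer's block_sizes):

import torch

block_size = 16
weight_slice = torch.randn(8, 64)  # (out_features, in_features) of one expert projection
per_block_amax = (
    weight_slice.detach()
    .reshape(-1, block_size)
    .abs()
    .amax(dim=1, keepdim=True)
    .float()
    .clamp(min=2e-3)
    .reshape(*weight_slice.shape[:-1], weight_slice.shape[-1] // block_size)
)
print(per_block_amax.shape)  # torch.Size([8, 4]), one amax per 16-element block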
@@ -123,6 +180,20 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             wrapper.weight_quantizer = w_quantizer
             wrapper.input_quantizer = i_quantizer

+            # Set global_amax to route to the static NVFP4 export path (reads per-block _amax).
+            # Always recompute from the current (possibly patched) _amax — a stale zero
+            # global_amax causes division-by-zero in the per-block scale formula.
+            # Guard: only per-block (NVFP4) quantizers have block_sizes; skip for others.
+            wq = wrapper.weight_quantizer
+            if (
+                hasattr(wq, "_amax")
+                and wq._amax is not None
+                and wq._amax.numel() > 1
+                and (getattr(wq, "block_sizes", None) or {}).get(-1) is not None
+            ):
+                wq._amax = wq._amax.to(weight_slice.device)
+                wq.global_amax = wq._amax.float().amax().clamp(min=2e-3)

             _export_quantized_weight(wrapper, dtype)

             proj = nn.Module()
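A schematic of the failure mode the comment above describes; the arithmetic below only illustrates why a zero global_amax is fatal and is not the exporter's exact scale formula.

import torch

per_block_amax = torch.tensor([0.5, 1.0, 2.0])

global_amax = torch.tensor(0.0)                      # stale, never recomputed
second_level_scale = global_amax / (448.0 * 6.0)
print(per_block_amax / 6.0 / second_level_scale)     # tensor([inf, inf, inf])

global_amax = per_block_amax.amax().clamp(min=2e-3)  # recomputed from the patched _amax
second_level_scale = global_amax / (448.0 * 6.0)
print(per_block_amax / 6.0 / second_level_scale)     # finite per-block scales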
@@ -20,7 +20,7 @@
 import warnings
 from collections.abc import Callable
 from functools import partial
-from typing import TypeAlias
+from typing import Any, TypeAlias

 import torch
 import torch.distributed as dist
@@ -351,7 +351,7 @@ def mse_calibrate(
     # Step 2: Replace calibrators with MseCalibrator for enabled quantizers
     # and identify weight quantizers
-    weight_quantizers = []
+    weight_quantizers: list[tuple[Any, Any, TensorQuantizer]] = []
     seen_modules = set()

     for name, module in list(model.named_modules()):
@@ -410,7 +410,12 @@ def mse_calibrate(
                 quant_func=partial(_mse_quant_func, quantizer=module),
             )

-    # Identify weight quantizers by checking if they have corresponding weight parameters
+    # Collect weight quantizers (standard + fused-experts per-expert lists).
Contributor: what is a …
+    try:
+        from modelopt.torch.quantization.plugins.huggingface import _QuantFusedExperts as _qfe_cls
+    except ImportError:
+        _qfe_cls = None  # type: ignore[misc]

     name_to_module = dict(model.named_modules())
     for parent_module in name_to_module.values():
         if parent_module in seen_modules:
@@ -421,8 +426,56 @@ def mse_calibrate(
         if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
             if getattr(weight_quantizer, "_calibrator", None) is not None:
                 weight_quantizers.append((parent_module, weight_name, weight_quantizer))
+        # Enqueue per-expert quantizers from {param}_weight_quantizers ModuleLists.
+        if _qfe_cls is not None and isinstance(parent_module, _qfe_cls):
Contributor: move this to a helper function.
+            for param_name, param in parent_module.named_parameters(recurse=False):
+                qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
+                if not isinstance(qlist, nn.ModuleList):
+                    continue
+                if len(qlist) != param.shape[0]:
+                    warnings.warn(
+                        f"Skipping {param_name}_weight_quantizers: list length {len(qlist)} "
+                        f"does not match parameter leading dimension {param.shape[0]}. "
+                        "This may indicate a misconfigured fused-experts module.",
+                        stacklevel=2,
+                    )
+                    continue
+                for expert_idx, wq in enumerate(qlist):
+                    if isinstance(wq, TensorQuantizer) and wq.is_enabled:
+                        if getattr(wq, "_calibrator", None) is not None:
+                            weight_quantizers.append((parent_module, (param_name, expert_idx), wq))
Comment on lines +429 to +447. Contributor: can we have a helper method …
         seen_modules.add(parent_module)

+    # Warn about enabled weight quantizers that weren't scheduled for MSE calibration.
+    picked_ids = {id(wq) for _, _, wq in weight_quantizers}
+
+    def _is_active_unpicked(q: Any) -> bool:
+        return (
+            isinstance(q, TensorQuantizer)
+            and q.is_enabled
+            and getattr(q, "_calibrator", None) is not None
+            and id(q) not in picked_ids
+        )
+
+    missed: list[str] = []
+    for mod_name, module in name_to_module.items():
+        for attr_name, attr in module._modules.items():
+            if isinstance(attr, TensorQuantizer) and attr_name.endswith("weight_quantizer"):
+                if _is_active_unpicked(attr):
+                    missed.append(f"{mod_name}.{attr_name}")
+            elif isinstance(attr, nn.ModuleList) and attr_name.endswith("_weight_quantizers"):
+                for i, wq in enumerate(attr):
+                    if _is_active_unpicked(wq):
+                        missed.append(f"{mod_name}.{attr_name}[{i}]")
+    if missed:
+        warnings.warn(
+            f"MSE weight calibration: {len(missed)} weight quantizer(s) are enabled but were "
+            f"not scheduled for calibration and will retain max-calibration amax values. "
+            f"First {min(5, len(missed))}: {missed[:5]}",
+            stacklevel=2,
+        )
     # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
     # This prevents massive memory accumulation seen in large models
     for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
@@ -432,7 +485,11 @@ def mse_calibrate(
         weight_quantizer.disable_quant()
         weight_quantizer.enable_calib()
         with enable_weight_access_and_writeback(parent_module, model, name_to_module):
-            weight = getattr(parent_module, weight_name)
+            if isinstance(weight_name, tuple):
+                param_name, expert_idx = weight_name
+                weight = getattr(parent_module, param_name)[expert_idx]
+            else:
+                weight = getattr(parent_module, weight_name)
             weight_quantizer(weight)
         # IMMEDIATELY compute amax and reset calibrator to free memory
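To make the (param_name, expert_idx) convention concrete, here is a hypothetical fused-experts layout with a stacked 3-D parameter and a parallel <param>_weight_quantizers ModuleList; the class and attribute names are illustrative, not the plugin's real _QuantFusedExperts.

import torch
import torch.nn as nn

class ToyFusedExperts(nn.Module):  # hypothetical stand-in for _QuantFusedExperts
    def __init__(self, num_experts=4, out_features=8, in_features=16):
        super().__init__()
        self.gate_up_proj = nn.Parameter(torch.randn(num_experts, out_features, in_features))
        self.gate_up_proj_weight_quantizers = nn.ModuleList(
            nn.Identity() for _ in range(num_experts)  # stand-ins for per-expert TensorQuantizers
        )

module = ToyFusedExperts()
for param_name, param in module.named_parameters(recurse=False):
    qlist = getattr(module, f"{param_name}_weight_quantizers", None)
    assert isinstance(qlist, nn.ModuleList) and len(qlist) == param.shape[0]
    for expert_idx, wq in enumerate(qlist):
        weight = getattr(module, param_name)[expert_idx]  # the tuple-style lookup used above
        print(param_name, expert_idx, tuple(weight.shape))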
@@ -778,7 +835,7 @@ def finish_stats_collection(model: nn.Module, method: str | None = None, **kwargs):
         cal = getattr(module, "_calibrator", None)
         if cal and not getattr(module, "_dynamic", False):
-            if method in {"entropy"}:
+            if method == "entropy":
Contributor: why is this needed?
                 if cal.compute_amax(method) is not None:
                     module.load_calib_amax("entropy", **kwargs)
             elif cal.compute_amax(**kwargs) is not None:
@@ -124,8 +124,11 @@ def get_weights_scaling_factor_from_quantizer(
     # Quantize scales to FP8
     if not keep_high_precision:
-        per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to(
-            torch.float8_e4m3fn
+        fp8_e4m3fn_min = 2**-9  # 0.001953125 — smallest positive subnormal
+        per_block_scale = (
+            (per_block_scale * 448.0 / per_block_scale_max)
Contributor: we also need a max clamp in line 130 to 448. I saw some NaNs in exported MSE no-sweep checkpoints due to overflow, e.g. during PTQ.
+            .clamp(min=fp8_e4m3fn_min)
+            .to(torch.float8_e4m3fn)
+        )
         return per_block_scale, weights_scaling_factor_2
     else:
@@ -167,6 +170,12 @@ def get_weights_scaling_factor(
     per_block_scale[per_block_scale == 0] = 1.0
     # Convert to torch.float8_e4m3fn
     if not keep_high_precision:
+        # Clamp to the minimum positive FP8 E4M3FN subnormal (~0.00195 = 2^-9) before
+        # casting. Without this, blocks whose scale falls below the FP8 representable
+        # range silently underflow to 0, causing those blocks to produce zero output at
+        # inference even when the weights are non-trivial.
+        fp8_e4m3fn_min = 2**-9  # 0.001953125 — smallest positive subnormal
+        per_block_scale = per_block_scale.clamp(min=fp8_e4m3fn_min)
         per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
Comment on lines +173 to 179. Contributor: Can we create a helper method which does the FP8 quantization of the per_tensor scale and use that here and here: https://github.com/NVIDIA/Model-Optimizer/pull/1382/changes#r3191334011
     return per_block_scale, weights_scaling_factor_2
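A quick check of the underflow this clamp prevents; only the cast itself is shown, not the surrounding export code.

import torch

scales = torch.tensor([1e-4, 1.5e-3, 0.25])
print(scales.to(torch.float8_e4m3fn).float())                   # smallest entry underflows to 0.0
print(scales.clamp(min=2**-9).to(torch.float8_e4m3fn).float())  # all entries stay positive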
nit: is_gate_or_up_proj?