Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,9 @@ def _process_quantized_modules(
):
sub_module.unpack_weight()
if get_quantization_format(sub_module) != QUANTIZATION_NONE:
# Skip QuantMoELinear - it's handled separately in _reconstruct_step3p5_moe_linear
if type(sub_module).__name__ == "QuantMoELinear":
continue
if is_quantlinear(sub_module):
try:
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
Expand Down Expand Up @@ -670,6 +673,46 @@ def _process_quantized_modules(
_export_quantized_weight(sub_module, dtype, weight_name)


def _reconstruct_step3p5_moe_linear(model: nn.Module) -> None:
"""Reconstruct QuantMoELinear per-expert weights back to original 3D MoELinear format.

After _process_quantized_modules, each expert's nn.Linear inside QuantMoELinear has:
- weight: fp4-quantized tensor [out_features, in_features]
- weight_scale, weight_scale_2: per-block / global scales
- input_scale: activation scale (if calibrated)

This stacks them back into the original MoELinear layout so the exported state_dict
uses the original key names (e.g. moe.up_proj.weight with shape [N, out, in]).

Note: QuantMoELinear is the dynamically generated class name (Quant + MoELinear),
not _QuantMoELinear which is the implementation class.
"""
for name, module in model.named_modules():
# Match QuantMoELinear (dynamically generated name) not _QuantMoELinear (implementation class)
if type(module).__name__ != "QuantMoELinear":
continue

n = module.num_experts
experts = module.experts

# Reconstruct 3D weight: [num_experts, out_features, in_features]
module.weight = nn.Parameter(
torch.stack([experts[i].weight.data for i in range(n)]),
requires_grad=False,
)

# Stack per-expert scales back under the original attribute names
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
if hasattr(experts[0], attr):
module.register_buffer(
attr,
torch.stack([getattr(experts[i], attr) for i in range(n)]),
)

# Remove expanded experts — the reconstructed 3D tensors replace them
del module.experts


def _export_transformers_checkpoint(
model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs
) -> tuple[dict[str, Any], dict[str, Any]]:
Expand Down Expand Up @@ -791,6 +834,9 @@ def _export_transformers_checkpoint(
# Process all quantized modules and export weights
_process_quantized_modules(model, dtype, is_modelopt_qlora)

# Reconstruct Step3p5 MoELinear: per-expert _QuantLinear weights → original 3D format
_reconstruct_step3p5_moe_linear(model)

if accelerator is not None:
# Gather state_dict from all ranks
quantized_state_dict = accelerator.get_state_dict(model)
Expand Down
7 changes: 7 additions & 0 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,13 @@
"*mlp*input_quantizer": _nvfp4_quantizer,
"*block_sparse_moe*weight_quantizer": _nvfp4_quantizer,
"*block_sparse_moe*input_quantizer": _nvfp4_quantizer,
# Step3p5 MoE experts: MoELinear lives at *.moe.{up,gate,down}_proj
"*moe*weight_quantizer": _nvfp4_quantizer,
"*moe*input_quantizer": _nvfp4_quantizer,
    # Disable quantization for the MoE router gate (*moe.gate.*)
"*moe.gate.*": {"enable": False},
# Disable share_expert (dense MLP alongside MoE, not in MLP-only quant scope)
"*share_expert*": {"enable": False},
**_default_disabled_quantizer_cfg,
}

Expand Down
64 changes: 64 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,10 +1468,74 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
is_homogeneous_hf_model, get_homogeneous_hf_decoder_layers
)


class _QuantMoELinear(QuantModule):
    """Quantization wrapper for Step3p5 MoELinear modules (fused expert weights).

    MoELinear stores a fused weight of shape [num_experts, out_features,
    in_features] and exposes ``forward(x, expert_id)``. ``_setup`` splits the
    fused tensor into one ``nn.Linear`` per expert so each expert receives its
    own weight_quantizer and input_quantizer, calibrated only on the tokens
    actually routed to it.

    On export, ``_reconstruct_step3p5_moe_linear()`` stacks the per-expert
    quantized weights and scales back into the original 3D format.
    """

    def _setup(self):
        from accelerate import init_empty_weights

        weight_dtype = self.weight.dtype
        weight_device = self.weight.device

        # Allocate the per-expert linears as meta tensors first; real storage
        # is materialized below, one expert slice at a time.
        with init_empty_weights():
            expert_linears = nn.ModuleList(
                [
                    nn.Linear(self.in_features, self.out_features, bias=False)
                    for _ in range(self.num_experts)
                ]
            )

        for idx, expert in enumerate(expert_linears):
            expert.to_empty(device=weight_device)
            with torch.no_grad():
                expert.weight.data = (
                    self.weight[idx].detach().to(dtype=weight_dtype, device=weight_device)
                )

        # The fused 3D weight is fully copied out; drop it so the per-expert
        # linears own the parameters from here on.
        delattr(self, "weight")
        self.experts = expert_linears

    def forward(self, x, expert_id):
        # experts[expert_id] is a _QuantLinear after quantization wrapping,
        # providing per-expert input_quantizer and weight_quantizer. Cast the
        # input to the expert weight dtype before the linear op, then cast the
        # output to float32 to match the original MoELinear forward behavior.
        chosen = self.experts[expert_id]
        return chosen(x.to(chosen.weight.dtype)).float()


def register_step3p5_moe_on_the_fly(model):
    """Register Step3p5 MoELinear for quantization.

    Step3p5 uses a custom MoELinear class (loaded via trust_remote_code) with
    weight shape [num_experts, out_features, in_features] and
    forward(x, expert_id). We detect the model by its class name, then grab
    the MoELinear type from the first MoE layer we encounter.
    """
    if type(model).__name__ not in ("Step3p5ForCausalLM", "Step3p5Model"):
        return
    for sub_module in model.modules():
        if type(sub_module).__name__ != "Step3p5MoEMLP":
            continue
        # All MoE layers share the same MoELinear class; the first one is
        # enough to register the quantized wrapper for the whole model.
        moe_linear_cls = type(sub_module.up_proj)
        if QuantModuleRegistry.get(moe_linear_cls) is None:
            QuantModuleRegistry.register({moe_linear_cls: f"hf.{moe_linear_cls.__name__}"})(
                _QuantMoELinear
            )
        break


CUSTOM_MODEL_PLUGINS.update(
[
register_falcon_linears_on_the_fly,
register_dbrx_moe_on_the_fly,
register_step3p5_moe_on_the_fly,
register_sparse_moe_on_the_fly,
register_hf_attentions_on_the_fly,
convert_hf_parallel_linears_on_the_fly,
Expand Down
Loading