38 changes: 20 additions & 18 deletions modelopt/torch/export/unified_export_megatron.py
@@ -419,16 +419,25 @@ def _get_state_dict(self):
         if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights:
             self.rules["output_layer"](model.output_layer)

+    def _get_fused_norm_weight(self, module):
+        """Return ``module.layer_norm_weight`` when TE fuses the norm into a linear layer.
+
+        Returns ``None`` when the ``"fused_norm"`` rule is absent or the module has no
+        ``layer_norm_weight`` attribute (or its value is ``None``).
+        """
+        if "fused_norm" not in self.rules:
+            return None
+        return getattr(module, "layer_norm_weight", None)
+
     def _get_transformer_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.input_layernorm, IdentityOp):
             self.rules["input_layernorm"](layer.input_layernorm, layer_id)
         elif (
-            hasattr(layer.self_attention, "linear_qkv")
-            and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight")
-            and layer.self_attention.linear_qkv.layer_norm_weight is not None
-            and "fused_norm" in self.rules
-        ):
-            self.rules["fused_norm"](layer.self_attention.linear_qkv.layer_norm_weight, layer_id)
+            norm_weight := self._get_fused_norm_weight(
+                getattr(layer.self_attention, "linear_qkv", None)
+            )
+        ) is not None:
+            self.rules["fused_norm"](norm_weight, layer_id)

         if not isinstance(layer.self_attention, IdentityOp):
             if "MLASelfAttention" in str(type(layer.self_attention)):
@@ -470,12 +479,10 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
         elif (
             not isinstance(layer.mlp, IdentityOp)
             and "MoE" not in str(type(layer.mlp))
-            and hasattr(layer.mlp, "linear_fc1")
-            and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
-            and layer.mlp.linear_fc1.layer_norm_weight is not None
-            and "fused_norm" in self.rules
+            and (norm_weight := self._get_fused_norm_weight(getattr(layer.mlp, "linear_fc1", None)))
+            is not None
         ):
-            self.rules["fused_norm"](layer.mlp.linear_fc1.layer_norm_weight, layer_id)
+            self.rules["fused_norm"](norm_weight, layer_id)

         if not isinstance(layer.mlp, IdentityOp):
             if "MoE" in str(type(layer.mlp)):
@@ -555,14 +562,9 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
     def _get_mamba_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.norm, IdentityOp):
             self.rules["norm"](layer.norm, layer_id)
-        elif (
-            isinstance(layer.norm, IdentityOp)
-            and hasattr(layer.mixer.in_proj, "layer_norm_weight")
-            and layer.mixer.in_proj.layer_norm_weight is not None
-            and "fused_norm" in self.rules
-        ):
+        elif (norm_weight := self._get_fused_norm_weight(layer.mixer.in_proj)) is not None:
             # TE spec: norm is fused into in_proj (QuantTELayerNormColumnParallelLinear).
-            self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id)
+            self.rules["fused_norm"](norm_weight, layer_id)

         self.rules["mixer_norm"](layer.mixer.norm, layer_id)
         self.rules["A_log"](layer.mixer.A_log, layer_id)