9 changes: 9 additions & 0 deletions modelopt/torch/export/plugins/mcore_custom.py
@@ -126,6 +126,15 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any]
func_kwargs=func_kwargs,
)

class SelfAttentionScaling(CustomModuleMapping):
"""A custom module mapping that scales self attention."""
def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}):
"""Create a custom module mapping that scales self attention."""
super().__init__(
func_name="self_attention_scaling",
target_name_or_prefix=target_name_or_prefix,
func_kwargs=func_kwargs,
)

class GatedMLPSlicing(CustomModuleMapping):
"""A custom module mapping that slices gate_proj and up_proj."""
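For orientation, here is a minimal sketch of how a mapping such as SelfAttentionScaling is consumed on the export side. The dispatch mirrors the method_map added to _custom_mapping_to_lambda in unified_export_megatron.py below; the simplified CustomModuleMapping stand-in and the apply_mapping helper are assumptions for illustration, not the actual classes in mcore_custom.py.

from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class CustomModuleMapping:
    # Simplified stand-in for the real class in mcore_custom.py (fields assumed).
    func_name: str = ""
    target_name_or_prefix: str = ""
    func_kwargs: dict[str, Any] = field(default_factory=dict)


def apply_mapping(
    mapping: CustomModuleMapping,
    method_map: dict[str, Callable],
    module: Any,
    layer_id: int,
) -> None:
    # Look up the handler bound to the mapping's func_name, e.g.
    # "self_attention_scaling" -> GPTModelExporter._self_attention_scaling.
    handler = method_map[mapping.func_name]
    # Format the per-layer prefix, e.g. "backbone.layers.{}.mixer." -> "backbone.layers.3.mixer.".
    prefix = mapping.target_name_or_prefix.format(layer_id)
    handler(module, prefix, **mapping.func_kwargs)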
22 changes: 22 additions & 0 deletions modelopt/torch/export/plugins/mcore_nemotron.py
@@ -26,6 +26,7 @@
NameRemapping,
QKVMerging,
QKVSlicing,
SelfAttentionScaling,
)

# Example on adding a new CausalLM.
@@ -81,8 +82,19 @@
"shared_experts.linear_fc2": NameRemapping(
"backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP
),
# Latent MoE
"fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE),
"fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE),
# MTP
"mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", REPLICATE),
"mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", REPLICATE),
"mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", REPLICATE),
"mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", REPLICATE),


}

# TODO: add MTP export

nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = {
"word_embeddings": NameRemapping("backbone.embeddings."),
@@ -101,6 +113,7 @@
"input_layernorm": NameRemapping("backbone.layers.{}.norm."),
"linear_qkv": QKVSlicing("backbone.layers.{}.mixer."),
"linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj."),
"core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."),
Contributor Author: double-check that this is only needed for export.

# MLP
"pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."),
"linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."),
@@ -115,4 +128,13 @@
"shared_experts.linear_fc2": NameRemapping(
"backbone.layers.{}.mixer.shared_experts.down_proj."
),
# Latent MoE
"fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."),
"fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."),
# MTP
"mtp.enorm": NameRemapping("mtp.layers.{}.enorm."),
"mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm."),
"mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj."),
"mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm."),

}
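To make the new core_attention rule concrete: assuming decoder layer 3 and the _self_attention_scaling handler shown later in unified_export_megatron.py, the HF state dict gains two per-layer entries. The snippet below only illustrates the key construction; the actual scale values come from the k/v BMM quantizers.

# Illustrative only: how the "core_attention" prefix resolves for layer 3.
prefix = "backbone.layers.{}.mixer.".format(3)
expected_keys = [prefix + "k_scale", prefix + "v_scale"]
# -> ["backbone.layers.3.mixer.k_scale", "backbone.layers.3.mixer.v_scale"]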
333 changes: 187 additions & 146 deletions modelopt/torch/export/plugins/megatron_importer.py

Large diffs are not rendered by default.

56 changes: 25 additions & 31 deletions modelopt/torch/export/quant_utils.py
@@ -332,7 +332,7 @@ def get_prequant_scaling_factor(module: nn.Module) -> torch.Tensor:
if prequant_scaling_factor is not None:
assert torch.all(prequant_scaling_factor > 0), (
f"prequant scaling factor {prequant_scaling_factor} not positive."
)
)
return prequant_scaling_factor


@@ -344,32 +344,22 @@ def get_kv_cache_bias(kv_module: nn.Module) -> list[torch.Tensor]:
kv_bias.append(getattr(quantizer_module, "_bias_value", None))
return kv_bias


def get_kv_cache_scaling_factor(kv_module: nn.Module) -> list[torch.Tensor]:
"""Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default."""
if not hasattr(kv_module, "k_bmm_quantizer") or not hasattr(kv_module, "v_bmm_quantizer"):
def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> list[torch.Tensor | None]:
"""Returns the k and v BMM scaling factors if BMM quantizers are set in the self-attention module.

Returns [None, None] otherwise.
"""
if not hasattr(self_attention_module, "k_bmm_quantizer") or not hasattr(self_attention_module, "v_bmm_quantizer"):
return [None, None]

scaling_factors = [
get_scaling_factor(getattr(kv_module, quantizer))
get_scaling_factor(getattr(self_attention_module, quantizer))
for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer")
]

# For FP8, we recommend default kv cache scaling factor to be 1.
if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8:
for i, factor in enumerate(scaling_factors):
if factor.item() > 0.5:
warn(
f"Warning: Large KV activation detected: {factor.item()}, "
"Quantized KV cache may lead to higher accuracy drop."
)
scaling_factors[i] = torch.max(
factor, torch.tensor([1.0], dtype=torch.float, device=factor.device)
)

return scaling_factors



def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None:
"""Returns the kv_cache dtype.

@@ -397,6 +387,22 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None:
num_bits_list.append(quantizer_attr.num_bits)
is_affine &= hasattr(quantizer_attr, "_bias_value")

return _compute_kv_cache_dtype(num_bits_list)

def _compute_kv_cache_dtype(num_bits_list: list[int | tuple[int, int]]) -> str | None:
"""Returns the kv_cache dtype inferred from the collected quantizer num_bits values.

If (4, 3) is present in num_bits_list, returns FP8; if 8 is present, returns int8;
otherwise returns None.

Args:
num_bits_list: The num_bits values gathered from the KV cache quantizers.

Returns:
The kv_cache dtype.
"""
is_affine = True

if (4, 3) in num_bits_list:
return KV_CACHE_FP8
elif 8 in num_bits_list:
@@ -920,18 +926,6 @@ def postprocess_state_dict(

value = value.float() / maxbound

# Warn if scale exceeds threshold
if quantization == KV_CACHE_FP8 and value.item() > 0.5:
logger.warning(
"Large KV activations detected. Quantized KV cache may lead to higher accuracy drop. "
"Setting KV cache scaling factor to at least 1."
)

# Ensure scale is at least 1 for KV_CACHE_FP8
# We export real value for KV_CACHE_NVFP4
if quantization == KV_CACHE_FP8:
value.clamp_(min=1.0)

post_state_dict[prefix + new_suffix] = value
break

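A short usage sketch for the refactored helpers. The import path and the KV_CACHE_FP8 constant location are assumptions based on the names visible in this diff; the int8 branch is omitted because its constant is not shown here.

import torch.nn as nn

from modelopt.torch.export.quant_utils import (  # assumed import locations
    KV_CACHE_FP8,
    _compute_kv_cache_dtype,
    get_kv_cache_scaling_factor,
)

# _compute_kv_cache_dtype is a pure function of the collected num_bits values.
assert _compute_kv_cache_dtype([(4, 3), (4, 3)]) == KV_CACHE_FP8  # FP8 KV cache
assert _compute_kv_cache_dtype([16, 16]) is None  # no KV cache quantization (per the docstring)

# get_kv_cache_scaling_factor degrades gracefully when no BMM quantizers are attached.
assert get_kv_cache_scaling_factor(nn.Identity()) == [None, None]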
84 changes: 36 additions & 48 deletions modelopt/torch/export/unified_export_megatron.py
@@ -51,7 +51,9 @@
from .plugins.megatron_importer import GPTModelImporter
from .quant_utils import (
get_activation_scaling_factor,
get_kv_cache_scaling_factor,
get_kv_cache_dtype,
get_quant_config,
get_quantization_format,
get_scaling_factor,
get_weight_block_size,
@@ -86,33 +88,6 @@
]


# This path uses output_quantizer for KV cache quantization.
Collaborator: I hope to see more proof before we remove this. Also, TRTLLM right now by default still uses 1 as the KV cache scale [ignoring the values we set here].

# The function below is the old version of get_kv_cache_scaling_factor which is now refactored to handle bmm_quantizer.
def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor:
"""Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default."""
scaling_factor = (
get_scaling_factor(kv_module.output_quantizer)
if hasattr(kv_module, "output_quantizer")
else None
)

if not scaling_factor:
return None

# For FP8, we recommend default kv cache scaling factor to be 1.
if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8:
if scaling_factor.item() > 0.5:
warn(
f"!!!!Large KV activations detected: {scaling_factor.item()}, "
"Quantized KV cache may lead to higher accuracy drop.\n!!!!"
)
scaling_factor = torch.max(
scaling_factor,
torch.tensor([1.0], dtype=torch.float, device=scaling_factor.device),
)
return scaling_factor


class GPTModelExporter:
"""Megatron Core GPTModel Exporter.

@@ -281,11 +256,6 @@ def save_pretrained(
elif quantization_format == QUANTIZATION_NVFP4:
quantization = "NVFP4"

kv_cache_quantization = None
kv_cache_dtype = get_kv_cache_dtype(self.model)
if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4):
# FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM
kv_cache_quantization = kv_cache_dtype
# We use the last PP rank and the 1st EP rank to write the config because
# medusa_heads and eagle_module only exist in the last stage.
if is_last_stage_main_rank:
@@ -320,17 +290,22 @@
pass

if is_last_stage_main_rank and quantization is not None:
# TODO: refactor to use mte.quant_utils.get_quant_config,
# except that layer names differ between MCore and HF
hf_quant_config = {
"producer": {
"name": "modelopt",
"version": __version__,
},
"quantization": {
"quant_algo": quantization,
"kv_cache_quant_algo": kv_cache_quantization,
"exclude_modules": ["lm_head"],
"exclude_modules": ["lm_head"], # TODO update this dynamically
},
}
if quantization == "NVFP4": # update block size
hf_quant_config["quantization"]["group_size"] = 16
if hasattr(self, "kv_cache_dtype"):
hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype
with open(save_directory + "/hf_quant_config.json", "w") as f:
json.dump(hf_quant_config, f, indent=4)

@@ -473,6 +448,7 @@ def _custom_mapping_to_lambda(mapping):
method_map = {
"name_remapping": self._name_remapping,
"qkv_slicing": self._qkv_slicing,
"self_attention_scaling": self._self_attention_scaling,
"gated_mlp_slicing": self._gated_mlp_slicing,
"pack_name_remapping": self._pack_name_remapping,
"pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss,
@@ -541,12 +517,8 @@ def _get_quantized_state(
# TODO (chenhany): support AWQ with pre_quant_scale
if hasattr(module.input_quantizer, "_pre_quant_scale"):
raise ValueError("Detect pre_quant_scale! SmoothQuant/AWQ are not yet supported!")

if hasattr(module, "output_quantizer"):
output_scale = get_kv_cache_scaling_factor(module)
if output_scale is not None:
name_to_value["output_scale"] = output_scale



return name_to_value, qformat, block_size

def _get_quantization_format(self, module: torch.nn.Module):
@@ -674,9 +646,7 @@ def _qkv_slicing(
q_proj_name="q_proj",
k_proj_name="k_proj",
v_proj_name="v_proj",
k_scale_name="k_scale",
v_scale_name="v_scale",
):
):
name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype)

q_proj_prefix = prefix + q_proj_name + "."
@@ -756,7 +726,7 @@ def _qkv_slicing(
quantized_weight = to_quantized_weight(
weight,
scale,
qformat,
qformat,
weight_scale_2,
block_size,
)
@@ -774,10 +744,7 @@
q_proj_key = q_proj_prefix + key
k_proj_key = k_proj_prefix + key
v_proj_key = v_proj_prefix + key
if key == "output_scale":
self._state_dict[prefix + k_scale_name] = val.detach().clone()
self._state_dict[prefix + v_scale_name] = val.detach().clone()
elif key == "bias":
if key == "bias":
# Slice bias similar to weight
bias = val.detach().clone()
bias = bias.reshape([qkv_total_dim, head_size])
@@ -790,6 +757,21 @@
self._state_dict[k_proj_key] = val.detach().clone()
self._state_dict[v_proj_key] = val.detach().clone()

def _self_attention_scaling(self, module, prefix, k_scale_name="k_scale", v_scale_name="v_scale"):
"""KV cache scaling for CoreAttention module."""
k_scale_key = prefix + k_scale_name
v_scale_key = prefix + v_scale_name
if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
kv_scales = get_kv_cache_scaling_factor(module)
if all(s is not None for s in kv_scales):
self._state_dict[k_scale_key] = kv_scales[0]
self._state_dict[v_scale_key] = kv_scales[1]

kv_cache_dtype = get_kv_cache_dtype(module)
if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4):
# FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM
self.kv_cache_dtype = kv_cache_dtype

def _pack_name_remapping(self, module, prefix, layer_type=None):
"""Pack name remapping into one tensor."""
weight_list = []
@@ -1149,6 +1131,8 @@ def _get_state_dict(self):
self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id)
self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id)
self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id)
if hasattr(layer.self_attention, "core_attention"):
self.rules["core_attention"](layer.self_attention.core_attention, layer_id)
self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id)
if (
getattr(layer.self_attention.core_attention, "softmax_offset", None)
@@ -1166,6 +1150,10 @@
self.rules["router"](
layer.mlp.router, layer_id, dtype=self.moe_router_dtype
)
if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None:
self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id)
if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None:
self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id)
if (
hasattr(layer.mlp, "shared_experts")
and layer.mlp.shared_experts is not None
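Tying the exporter changes together, the sketch below shows roughly what hf_quant_config.json would look like for an NVFP4 checkpoint whose core attention modules carried FP8 BMM quantizers. The structure follows the dict built in save_pretrained above; the "FP8" string for the KV cache algo and the producer version are illustrative assumptions.

# Illustrative hf_quant_config.json contents (values other than the keys are assumptions).
hf_quant_config = {
    "producer": {"name": "modelopt", "version": "<installed modelopt version>"},
    "quantization": {
        "quant_algo": "NVFP4",
        "exclude_modules": ["lm_head"],
        "group_size": 16,  # only added for NVFP4
        "kv_cache_quant_algo": "FP8",  # only added if _self_attention_scaling set self.kv_cache_dtype
    },
}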
18 changes: 9 additions & 9 deletions modelopt/torch/quantization/model_calib.py
@@ -81,6 +81,11 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
forward_loop(model)
finish_stats_collection(model)

# Step 1: Sync amax across local experts in a SequentialMLP
for name, module in model.named_modules():
if hasattr(module, "sync_moe_local_experts_amax"):
module.sync_moe_local_experts_amax()

if not distributed_sync:
return

@@ -95,13 +100,14 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
# TODO: create sync_bias_across_distributed_group

# Step 1:Sync amax across data parallelism

# Step 2: Sync amax across data parallelism
for name, module in model.named_modules():
if isinstance(module, QuantModule):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
# TP sync:
# Step 3: TP sync
# Objective: quantization parameters should be identical when going TP=8 -> TP=4 -> TP=8

# ColumnParallel: X @ [A_1, A_2] (weights split along Cout)
@@ -156,7 +162,6 @@ def sync_quantizer_amax_across_tp(
axes_for_sync=[None, -1],
parallel_state=module.parallel_state,
)

sync_quantizer_amax_across_tp(
module.weight_quantizer,
name,
@@ -182,10 +187,7 @@
parallel_state=module.parallel_state,
)

# MOE Quantization
if hasattr(module, "sync_moe_local_experts_amax"):
module.sync_moe_local_experts_amax()


# KV Cache Quantization
if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
# We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
@@ -278,8 +280,6 @@ def quant_func(x, amax, quantizer=module):
# Step 4: Compute optimal amax and load it
finish_stats_collection(model, method="mse")

# TODO: Sync amax across distributed processes


def enable_stats_collection(model: nn.Module):
"""Enable stats collection for all quantizers in the model."""
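For clarity, the amax-sync ordering in max_calibrate after this change can be summarized as follows; this is a condensed sketch with the DP/EP and TP loop bodies elided, not the literal implementation.

def max_calibrate_sync_order(model, distributed_sync=True):
    # Step 1: sync amax across local experts of each SequentialMLP.
    # This now runs before any distributed sync and also when distributed_sync=False.
    for _, module in model.named_modules():
        if hasattr(module, "sync_moe_local_experts_amax"):
            module.sync_moe_local_experts_amax()

    if not distributed_sync:
        return

    # Step 2: sync quantizer amax across data-parallel / expert-parallel groups.
    # Step 3: sync across tensor-parallel ranks, so that TP=8 -> TP=4 -> TP=8
    #         round-trips yield identical quantization parameters.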