-
Notifications
You must be signed in to change notification settings - Fork 243
Latent MOE & Repeated MTP support; fix KV cache quant export #768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
jenchen13
wants to merge
7
commits into
main
Choose a base branch
from
jennifchen/superv3
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
9bda954
support latent moe import and fix local experts sync
jenchen13 5da17f4
patch TransformerLayer forward
jenchen13 4471c03
fix bug of duplicate forward
jenchen13 db1892f
fix kv bmm export
jenchen13 f26bf3c
small fixes
jenchen13 2cc4acc
mtp import fixes
jenchen13 3d0a31a
enable TELinear quant
jenchen13 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,7 +51,9 @@ | |
| from .plugins.megatron_importer import GPTModelImporter | ||
| from .quant_utils import ( | ||
| get_activation_scaling_factor, | ||
| get_kv_cache_scaling_factor, | ||
| get_kv_cache_dtype, | ||
| get_quant_config, | ||
| get_quantization_format, | ||
| get_scaling_factor, | ||
| get_weight_block_size, | ||
|
|
@@ -86,33 +88,6 @@ | |
| ] | ||
|
|
||
|
|
||
| # This path uses output_quantizer for KV cache quantization. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I hope to see more proof before we remove this. Also TRTLLM right now by default still uses 1 as the kv cache scale [ignoring the values we set here.] |
||
| # The function below is the old version of get_kv_cache_scaling_factor which is now refactored to handle bmm_quantizer. | ||
| def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor: | ||
| """Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default.""" | ||
| scaling_factor = ( | ||
| get_scaling_factor(kv_module.output_quantizer) | ||
| if hasattr(kv_module, "output_quantizer") | ||
| else None | ||
| ) | ||
|
|
||
| if not scaling_factor: | ||
| return None | ||
|
|
||
| # For FP8, we recommend default kv cache scaling factor to be 1. | ||
| if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: | ||
| if scaling_factor.item() > 0.5: | ||
| warn( | ||
| f"!!!!Large KV activations detected: {scaling_factor.item()}, " | ||
| "Quantized KV cache may lead to higher accuracy drop.\n!!!!" | ||
| ) | ||
| scaling_factor = torch.max( | ||
| scaling_factor, | ||
| torch.tensor([1.0], dtype=torch.float, device=scaling_factor.device), | ||
| ) | ||
| return scaling_factor | ||
|
|
||
|
|
||
| class GPTModelExporter: | ||
| """Megatron Core GPTModel Exporter. | ||
|
|
||
|
|
@@ -281,11 +256,6 @@ def save_pretrained( | |
| elif quantization_format == QUANTIZATION_NVFP4: | ||
| quantization = "NVFP4" | ||
|
|
||
| kv_cache_quantization = None | ||
| kv_cache_dtype = get_kv_cache_dtype(self.model) | ||
| if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): | ||
| # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM | ||
| kv_cache_quantization = kv_cache_dtype | ||
| # We use the last PP rank and the 1st EP rank to write the config because | ||
| # medusa_heads and eagle_module only exist in the last stage. | ||
| if is_last_stage_main_rank: | ||
|
|
@@ -320,17 +290,22 @@ def save_pretrained( | |
| pass | ||
|
|
||
| if is_last_stage_main_rank and quantization is not None: | ||
| # TODO refactor to use mte.quant_utils.get_quant_config | ||
| # except layer names are different in MCore and HF | ||
| hf_quant_config = { | ||
| "producer": { | ||
| "name": "modelopt", | ||
| "version": __version__, | ||
| }, | ||
| "quantization": { | ||
| "quant_algo": quantization, | ||
| "kv_cache_quant_algo": kv_cache_quantization, | ||
| "exclude_modules": ["lm_head"], | ||
| "exclude_modules": ["lm_head"], # TODO update this dynamically | ||
| }, | ||
| } | ||
| if quantization == "NVFP4": # update block size | ||
| hf_quant_config["quantization"]["group_size"] = 16 | ||
| if hasattr(self, "kv_cache_dtype"): | ||
| hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype | ||
| with open(save_directory + "/hf_quant_config.json", "w") as f: | ||
| json.dump(hf_quant_config, f, indent=4) | ||
|
|
||
|
|
@@ -473,6 +448,7 @@ def _custom_mapping_to_lambda(mapping): | |
| method_map = { | ||
| "name_remapping": self._name_remapping, | ||
| "qkv_slicing": self._qkv_slicing, | ||
| "self_attention_scaling": self._self_attention_scaling, | ||
| "gated_mlp_slicing": self._gated_mlp_slicing, | ||
| "pack_name_remapping": self._pack_name_remapping, | ||
| "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss, | ||
|
|
@@ -541,12 +517,8 @@ def _get_quantized_state( | |
| # TODO (chenhany): support AWQ with pre_quant_scale | ||
| if hasattr(module.input_quantizer, "_pre_quant_scale"): | ||
| raise ValueError("Detect pre_quant_scale! SmoothQuant/AWQ are not yet supported!") | ||
|
|
||
| if hasattr(module, "output_quantizer"): | ||
| output_scale = get_kv_cache_scaling_factor(module) | ||
| if output_scale is not None: | ||
| name_to_value["output_scale"] = output_scale | ||
|
|
||
|
|
||
|
|
||
| return name_to_value, qformat, block_size | ||
|
|
||
| def _get_quantization_format(self, module: torch.nn.Module): | ||
|
|
@@ -674,9 +646,7 @@ def _qkv_slicing( | |
| q_proj_name="q_proj", | ||
| k_proj_name="k_proj", | ||
| v_proj_name="v_proj", | ||
| k_scale_name="k_scale", | ||
| v_scale_name="v_scale", | ||
| ): | ||
| ): | ||
| name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype) | ||
|
|
||
| q_proj_prefix = prefix + q_proj_name + "." | ||
|
|
@@ -756,7 +726,7 @@ def _qkv_slicing( | |
| quantized_weight = to_quantized_weight( | ||
| weight, | ||
| scale, | ||
| qformat, | ||
| qformat, | ||
| weight_scale_2, | ||
| block_size, | ||
| ) | ||
|
|
@@ -774,10 +744,7 @@ def _qkv_slicing( | |
| q_proj_key = q_proj_prefix + key | ||
| k_proj_key = k_proj_prefix + key | ||
| v_proj_key = v_proj_prefix + key | ||
| if key == "output_scale": | ||
| self._state_dict[prefix + k_scale_name] = val.detach().clone() | ||
| self._state_dict[prefix + v_scale_name] = val.detach().clone() | ||
| elif key == "bias": | ||
| if key == "bias": | ||
| # Slice bias similar to weight | ||
| bias = val.detach().clone() | ||
| bias = bias.reshape([qkv_total_dim, head_size]) | ||
|
|
@@ -790,6 +757,21 @@ def _qkv_slicing( | |
| self._state_dict[k_proj_key] = val.detach().clone() | ||
| self._state_dict[v_proj_key] = val.detach().clone() | ||
|
|
||
| def _self_attention_scaling(self, module, prefix, k_scale_name="k_scale", v_scale_name="v_scale"): | ||
| """KV cache scaling for CoreAttention module.""" | ||
| k_scale_key = prefix + k_scale_name | ||
| v_scale_key = prefix + v_scale_name | ||
| if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): | ||
| kv_scales = get_kv_cache_scaling_factor(module) | ||
| if all(s is not None for s in kv_scales): | ||
| self._state_dict[k_scale_key] = kv_scales[0] | ||
| self._state_dict[v_scale_key] = kv_scales[1] | ||
|
|
||
| kv_cache_dtype = get_kv_cache_dtype(module) | ||
| if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): | ||
| # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM | ||
| self.kv_cache_dtype = kv_cache_dtype | ||
|
|
||
| def _pack_name_remapping(self, module, prefix, layer_type=None): | ||
| """Pack name remapping into one tensor.""" | ||
| weight_list = [] | ||
|
|
@@ -1149,6 +1131,8 @@ def _get_state_dict(self): | |
| self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id) | ||
| self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) | ||
| self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) | ||
| if hasattr(layer.self_attention, "core_attention"): | ||
| self.rules["core_attention"](layer.self_attention.core_attention, layer_id) | ||
| self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) | ||
| if ( | ||
| getattr(layer.self_attention.core_attention, "softmax_offset", None) | ||
|
|
@@ -1166,6 +1150,10 @@ def _get_state_dict(self): | |
| self.rules["router"]( | ||
| layer.mlp.router, layer_id, dtype=self.moe_router_dtype | ||
| ) | ||
| if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: | ||
| self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) | ||
| if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: | ||
| self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) | ||
| if ( | ||
| hasattr(layer.mlp, "shared_experts") | ||
| and layer.mlp.shared_experts is not None | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doublecheck that this is only needed for export