Merge mbridge distillation for any_model #1036
Open — danielkorzekwa wants to merge 97 commits into feature/puzzletron from dkorzekwa/anymodel_mbridgedist.
+690 −0

Changes from all commits (97 commits):
e82164f  Add anymodel directories to feature/puzzletron
2099df3  Make any_model conversion working.
eb5cf8a  Update child_init.py with anymodel version
c9de41c  fix attention pruning
3c1bc1f  Add trust_remote_code to load_model_config (default to false)
8357136  Make activation scoring working
6cc2194  Comment all tested models aside of llama_3_1_8b_instruct
ee4e1e3  Delete not needed decilm test
449b523  Fix broken tests
fb27bba  Update puzzletron_nas_pluging to any_model version
b350f82  Correct test resources used by tests.
fafe5a3  Disable puzzletron tests (will be enabled after all any_model logic i…
e988248  Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
c717852  Comment out not implemented models.
030f126  format python docs
8dcdfbf  Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
70df0df  Use trust_remote_code in force_cache_dynamic_modules()
bb56662  Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
ecd953e  Fix anymodel pruning
ee8f538  Fix buid docs issue.
c9b76a1  Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
6e3af61  Merge branch 'dkorzekwa/anymodel_activation_scoring' into dkorzekwa/a…
0ad6d92  Merging build_library_and_stats
995eb1a  Merging anymodel: calc_one_block_scores
34081c9  Mering any_model: calc_one_block_scores
ed5c00f  merge any_model: mip_and_realize_models
993b5ec  Add all anymodel models but gptoss
6e9f03b  Make nemotron-nano-12b-v2 to work (set trust_remote_code=true)
e8b7a7d  merge anymodel for nemotron-3-nano-30b-a3b-base-bf16
47414d5  Clarify readme and avoid reusing the same reference in llama_converter.
a8305d8  Fix tied-embedding handling before writing the safetensors index.
68421a5  Fix NaN ranking currently selects NaNs as “best” experts by default.
d6b8028  Code clean up.
ecd2341  Code clean up.
f9d845d  code clean up
d171b01  Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
722da90  Merge branch 'dkorzekwa/anymodel_activation_scoring' into dkorzekwa/a…
934ab2f  code clean up
0f14ec3  Merge branch 'dkorzekwa/anymodel_pruning' into dkorzekwa/anymodel_bui…
dcb9e02  remove not needed comment
0c9ea5d  Merge branch 'dkorzekwa/anymodel_build_library_and_stats' into dkorze…
5b310e2  Merge branch 'dkorzekwa/any_model_calc_one_block_scores' into dkorzek…
4f82b1c  Merge branch 'dkorzekwa/mip_and_realize_models' into dkorzekwa/any_mo…
176a435  Fix a broken test_puzzletron test on 2 gpus.
02e2c9b  Merge branch 'dkorzekwa/anymodel_activation_scoring' into dkorzekwa/a…
92c4419  Merge branch 'dkorzekwa/anymodel_pruning' into dkorzekwa/anymodel_bui…
aa1eb3e  Merge branch 'dkorzekwa/anymodel_build_library_and_stats' into dkorze…
2b84a96  Merge branch 'dkorzekwa/any_model_calc_one_block_scores' into dkorzek…
fb838c0  Merge branch 'dkorzekwa/mip_and_realize_models' into dkorzekwa/any_mo…
13378ff  Add gpt-oss model
47ca0e3  Add comments about a broken test
96112f7  Fix a broken gptoss test
cb6b182  Add mamba to puzzletron dependencies.
670bb34  Update mamba-ssm and casual-conv1d dependences (remove pinpoint versi…
0e1b591  Install mamba-ssm and causal-conv1d in testenv:cuda13-gpu-puzzletron
ca845ec  Fix installing dependencies in testenv:cuda13-gpu-puzzletron
be825bc  Fix anymodel for qwen3 8B in 2 gpus
7fd1afa  Fix pipeline parallelism issue for wen3-vl-30b-a3b-instruct-qwen3_vl-…
7d7b609  Fix multi-gpu issue for nemotron-nano-12b-v2
249af9d  Fix no_op in any_model
b80583c  Merge branch 'feature/puzzletron' into dkorzekwa/any_model_other_models
88b1b13  Merge any_model tutorial
c0da9c0  Merge mbridge distillation for any_model
1dd742e  Fix nemotron_h_model_descriptor.
4a6ebbe  Fix tox -e build-docs
585f0ed  pin mamba/casual-conv1d versions to fix failing assertion for test_pu…
7fb5d9a  Fix for installing mamba-ssm
75d3d69  Fix broken test for nemotron-3-nano-30b-a3b-base-bf16
0e5722d  code clean up
2dd9735  Make test_puzzletron test deterministic
3561de5  Comment out all models but nemotron-3-nano-30b-a3b-base-bf16 to check…
27866de  Implement Qwen3VLRemoveExpertsIndependentHook
a012fe6  Remove not needed nvidia licence header
52922a4  # Initialize weights to ensure all parameters are properly initialized
c234fb4  Fix non-deterministic test_puzzletron test
53dcd10  Fix for unsetting CUDA_VISIBLE_DEVICES
69d9648  increase numeric tolerance for test_puzzletron.py
4a692dc  Disable lm_loss assertion for nemotron-3-nano-30b-a3b-base-bf16 (not …
e795f0c  Removing incorrect licence file. gpt_oss_pruned_to_mxfp4.py was not a…
631306c  Fix hardcoded trust_remote_code
dc77be2  Merge branch 'dkorzekwa/any_model_other_models' into dkorzekwa/anymod…
b76e0ef  Merge branch 'dkorzekwa/anymodel_gptoss' into dkorzekwa/anymodel_tuto…
109b185  Merge branch 'dkorzekwa/anymodel_tutorial' into dkorzekwa/anymodel_mb…
5cadc65  Merge branch 'feature/puzzletron' into dkorzekwa/anymodel_gptoss
151081c  Delete not needed yaml files for test_puzzletron.
36daa6d  Delete not needed mypy exclusion for removed hf_configs files.
960b8ce  Merge branch 'dkorzekwa/anymodel_gptoss' into dkorzekwa/anymodel_tuto…
854d96b  Merge branch 'dkorzekwa/anymodel_tutorial' into dkorzekwa/anymodel_mb…
b47f846  Merge branch 'feature/puzzletron' into dkorzekwa/anymodel_tutorial
13f5edc  Merge branch 'dkorzekwa/anymodel_tutorial' into dkorzekwa/anymodel_mb…
f2c1578  Fix a broken mbridge distillation test for anymodel
3592eec  Code clean up.
f06cb20  Use all available GPUs for test_distill_hf
ad31b09  use extend_cmd_parts
0505916  code clean up.
7016857  Improve naming of --hf_export_path and --hf_export_path
7ede076  Merge branch 'feature/puzzletron' into dkorzekwa/anymodel_mbridgedist
New file (+35 lines):

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron-Bridge adapters for Puzzletron AnyModel checkpoints.

This module provides bridges for converting Puzzletron AnyModel checkpoints
(heterogeneous layer architectures) to Megatron-Core format via Megatron-Bridge.
"""

# Import to register bridges (side effect)
from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin
from modelopt.torch.puzzletron.export.mbridge.llama import (  # noqa: F401
    PuzzletronLlamaAnyModelBridge,
)
from modelopt.torch.puzzletron.export.mbridge.qwen3 import (  # noqa: F401
    PuzzletronQwen3AnyModelBridge,
)

__all__ = [
    "HeterogeneousBridgeMixin",
    "PuzzletronLlamaAnyModelBridge",
    "PuzzletronQwen3AnyModelBridge",
]
```
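The `# noqa: F401` imports in this `__init__` module exist purely for their side effect of registering the bridge classes. A minimal sketch of that register-on-import pattern, with hypothetical names (`BRIDGE_REGISTRY` and `register_bridge` are illustrative, not actual Megatron-Bridge symbols):

```python
# Hypothetical sketch of register-on-import; BRIDGE_REGISTRY and register_bridge
# are illustrative names, not the real Megatron-Bridge registry API.
BRIDGE_REGISTRY: dict[str, type] = {}


def register_bridge(cls: type) -> type:
    """Class decorator: record the bridge class so it can be looked up by name."""
    BRIDGE_REGISTRY[cls.__name__] = cls
    return cls


@register_bridge
class DemoAnyModelBridge:
    """Registered as a side effect of the class statement executing on import."""


# Importing the defining module runs the class statement and populates the
# registry, which is why the package __init__ imports the bridge classes and
# silences the unused-import warning with `# noqa: F401`.
```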
New file (+142 lines):

```python
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Mixin class for bridges that support heterogeneous layer architectures.

This module provides a mixin class for converting models with block_configs
(heterogeneous layer configurations) to Megatron-Core format via Megatron-Bridge.
"""

import dataclasses
import json
from collections.abc import Callable
from dataclasses import dataclass, fields

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
from megatron.bridge.models.transformer_config import HeterogeneousTransformerConfig
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
    get_gpt_heterogeneous_layer_spec,
)
from megatron.core.transformer.spec_utils import ModuleSpec


def heterogeneous_layer_spec(config) -> ModuleSpec:
    """Get the GPT heterogeneous layer spec using Transformer Engine."""
    return get_gpt_heterogeneous_layer_spec(config, use_te=True)


@dataclass
class GenericHeterogeneousProvider(GPTModelProvider, HeterogeneousTransformerConfig):
    """Generic provider for AnyModel checkpoints with block_configs."""

    # Heterogeneous configuration fields
    heterogeneous_layers_config_path: str | None = None
    heterogeneous_layers_config_encoded_json: str = ""
    transformer_layer_spec: ModuleSpec | Callable = heterogeneous_layer_spec

    def __getattr__(self, name: str):
        """Handle missing attributes for OmegaConf compatibility.

        Returns an empty list for per_block_parameters if it is not yet initialized
        (i.e., before finalize()). This allows OmegaConf to serialize/deserialize
        configs without errors. Actual usage should call finalize() first to set
        per_block_parameters as a real attribute.
        """
        if name == "per_block_parameters":
            # Return the existing attribute if set, otherwise [] for OmegaConf compatibility
            try:
                return object.__getattribute__(self, name)
            except AttributeError:
                return []
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")


class HeterogeneousBridgeMixin:
    """Mixin for bridges supporting heterogeneous layer architectures (block_configs).

    Must be used with multiple inheritance alongside a model-specific bridge.
    Example: class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge)
    """

    def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider:
        """Convert an HF AnyModel config to a Megatron GPTModelProvider.

        This method:
        1. Calls the parent bridge's provider_bridge() to get a GPTModelProvider with all
           model-specific settings (e.g., LlamaBridge sets normalization="RMSNorm", etc.)
        2. Converts the provider to a dict and filters it to only the fields accepted by
           GenericHeterogeneousProvider (which inherits from GPTModelProvider, so all
           valid GPTModelProvider fields are preserved)
        3. Adds the heterogeneous configuration and returns a GenericHeterogeneousProvider

        All parameters from the parent bridge (e.g., LlamaBridge) are maintained because
        GenericHeterogeneousProvider inherits from GPTModelProvider, which includes all
        the fields that the parent bridge sets.
        """
        parent_provider = super().provider_bridge(hf_pretrained)  # type: ignore[misc]

        provider_kwargs = dataclasses.asdict(parent_provider)

        # Filter to only fields that GenericHeterogeneousProvider accepts.
        # GenericHeterogeneousProvider inherits from GPTModelProvider, so it includes all
        # GPTModelProvider fields. Model-specific fields from subclasses (e.g.,
        # MistralModelProvider, GPTOSSModelProvider) are filtered out because
        # GenericHeterogeneousProvider only inherits from GPTModelProvider, not from
        # model-specific subclasses.
        #
        # Note: This logic may not work for bridges like MistralBridge or GPTOSSBridge if
        # they use model-specific parameters not supported by GenericHeterogeneousProvider
        # (e.g., scale_factor, yarn_rotary_scaling_factor, moe_* parameters). In such cases,
        # create a model-specific heterogeneous provider that inherits from the
        # model-specific provider.
        valid_fields = {f.name for f in fields(GenericHeterogeneousProvider)}

        # Only keep kwargs that are valid fields
        provider_kwargs = {k: v for k, v in provider_kwargs.items() if k in valid_fields}

        provider_kwargs["heterogeneous_layers_config_encoded_json"] = (
            self._build_heterogeneous_config_json(hf_pretrained.config)
        )
        return GenericHeterogeneousProvider(**provider_kwargs)

    def _build_heterogeneous_config_json(self, hf_config) -> str:
        """Build the heterogeneous layers config JSON from the HF config."""
        hf_config_dict = json.loads(hf_config.to_json_string())

        mcore_block_configs = [
            self._convert_block_config(block) for block in hf_config_dict["block_configs"]
        ]
        return json.dumps({"block_configs": mcore_block_configs}, ensure_ascii=False)

    def _convert_block_config(self, block: dict) -> dict:
        """Convert a single block config from HF format to MCore format."""
        return {
            "attention": self._convert_attention_config(block["attention"]),
            "ffn": self._convert_ffn_config(block["ffn"]),
        }

    def _convert_attention_config(self, attention_config: dict) -> dict:
        """Convert an attention config from HF format to MCore format."""
        attention_config = attention_config.copy()
        attention_config["num_query_groups"] = attention_config.pop("num_key_value_heads")
        return attention_config

    def _convert_ffn_config(self, ffn_config: dict) -> dict:
        """Convert an FFN/MLP config from HF format to MCore format."""
        ffn_config = ffn_config.copy()
        ffn_config["ffn_hidden_size"] = ffn_config.pop("intermediate_size")
        return ffn_config
```
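The `_convert_*` helpers in the mixin are plain dictionary key renames followed by JSON encoding. A standalone sketch of the same HF-to-MCore conversion on a toy config (the field values here are made up purely for illustration):

```python
import json


def convert_attention(cfg: dict) -> dict:
    # HF names the GQA group count num_key_value_heads; MCore calls it num_query_groups.
    out = cfg.copy()
    out["num_query_groups"] = out.pop("num_key_value_heads")
    return out


def convert_ffn(cfg: dict) -> dict:
    # HF's intermediate_size becomes MCore's ffn_hidden_size.
    out = cfg.copy()
    out["ffn_hidden_size"] = out.pop("intermediate_size")
    return out


def build_heterogeneous_json(hf_config: dict) -> str:
    # One entry per layer: this is what makes the architecture heterogeneous.
    blocks = [
        {"attention": convert_attention(b["attention"]), "ffn": convert_ffn(b["ffn"])}
        for b in hf_config["block_configs"]
    ]
    return json.dumps({"block_configs": blocks}, ensure_ascii=False)


# Toy per-layer config: one block with 8 KV heads, one with a wider FFN.
hf = {
    "block_configs": [
        {"attention": {"num_key_value_heads": 8}, "ffn": {"intermediate_size": 14336}},
        {"attention": {"num_key_value_heads": 4}, "ffn": {"intermediate_size": 28672}},
    ]
}
encoded = build_heterogeneous_json(hf)
```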
modelopt/torch/puzzletron/export/mbridge/distillation_provider.py (+190 lines):

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: Upstream this fix to Megatron-Bridge and remove this local copy.

import logging
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Any, Optional

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider
from megatron.bridge.models.transformer_config import TransformerConfig
from megatron.core.models.gpt import GPTModel as MCoreGPTModel

import modelopt.torch.distill as mtd
import modelopt.torch.distill.plugins.megatron as mtd_mcore

if TYPE_CHECKING:
    from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig


logger = logging.getLogger(__name__)


@dataclass
class DistillationProvider(TransformerConfig):
    """Provider for Megatron Core GPT models in distillation mode.

    Please use `convert_to_distillation_provider()` to create an instance of this class.
    """

    teacher: Optional[GPTModelProvider | MambaModelProvider] = None
    kd_config: Optional["ModelOptDistillConfig"] = None

    def __init__(self, *args, **kwargs):
        raise NotImplementedError(
            "Use `convert_to_distillation_provider()` to create an instance of this class."
        )

    def __post_init__(self):
        assert getattr(self, "teacher", None) is not None, "Teacher model must be provided."

        shared_attrs = [
            "tensor_model_parallel_size",
            "pipeline_model_parallel_size",
            "context_parallel_size",
            "seq_length",
            "pipeline_dtype",
        ]
        for attr in shared_attrs:
            if getattr(self, attr) != getattr(self.teacher, attr):
                raise ValueError(f"Student and teacher providers must have the same {attr}.")

        # Logits are overwritten in-place when TE cross-entropy loss is enabled,
        # so switch back to the native version.
        self.cross_entropy_fusion_impl = "native"

        # Hack to dynamically subclass other providers and still use their methods
        self._super_class = self.__class__.__bases__[0]

    def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel:
        """Configure and instantiate a ModelOpt DistillationModel based on this configuration.

        Args:
            pre_process: Whether to include pre-processing in the model;
                defaults to the first pipeline stage
            post_process: Whether to include post-processing in the model;
                defaults to the last pipeline stage
            vp_stage: Virtual pipeline stage

        Returns:
            MCoreGPTModel: Configured ModelOpt DistillationModel instance
        """
        if vp_stage is not None:
            raise ValueError("ModelOpt KD currently does not support virtual-pipeline parallel.")

        assert self.teacher is not None, "Teacher model must be provided."
        student_model = self._super_class.provide(self, pre_process, post_process, vp_stage)  # type: ignore[attr-defined]

        # Finalize the teacher provider before creating the model (required for
        # heterogeneous models).
        #
        # per_block_parameters is an attribute of HeterogeneousTransformerConfig (defined in
        # MCoreHeterogeneousTransformerConfig, heterogeneous_config.py:197). It is created
        # during provider creation (bridge.to_megatron_provider()), but finalize() ensures it
        # is consistent with the current parallelism settings and distributed context. Student
        # model creation (above) initializes parallel_state (process groups, TP/PP config),
        # which weight loading/scatter requires. During teacher model creation,
        # get_config_for_layer() is called (transformer_block.py:341) for each layer; it uses
        # per_block_parameters and the current tensor_model_parallel_size to determine the
        # layer architecture. Without finalize() in this context, the architecture
        # expectations don't match the checkpoint weights, causing:
        #     ValueError: ProcessGroupNCCL::scatter: invalid tensor size at index 0
        #     (expected (2880, 4096), got (3584, 4096))
        #
        # Note: This explanation has yet to be confirmed.
        self.teacher.finalize()

        # Hack to get the teacher's pre-wrap hooks called, to potentially load HF weights
        teacher_model = self.teacher.provide_distributed_model(
            wrap_with_ddp=False, mixed_precision_wrapper=None
        )[0]

        kd_cfg = mtd_mcore.setup_distillation_config(
            self.kd_config, student_model.config, teacher_model.config
        )
        modelopt_cfg = {
            "teacher_model": teacher_model,
            "criterion": kd_cfg.criterion,
            "loss_balancer": kd_cfg.loss_balancer,
        }
        kd_model = mtd.convert(student_model, mode=[("kd_loss", modelopt_cfg)])
        mtd_mcore.adjust_distillation_model_for_mcore(kd_model, kd_cfg)

        return kd_model

    def to_cfg_dict(self) -> dict[str, Any]:
        """Custom method to save the equivalent of the original provider class.

        Used by `_ConfigContainerBase` to serialize the main `ConfigContainer` to YAML.
        There is no need to restore a `DistillationProvider` from the run config file, as
        it can always be re-converted using the original student provider.

        Returns:
            Dictionary representation of this provider class
        """
        from megatron.bridge.training.utils.config_utils import _ConfigContainerBase

        result = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"}

        # Include all fields from the original provider class (self._super_class), not just
        # DistillationProvider. This ensures fields like
        # heterogeneous_layers_config_encoded_json are preserved.
        excluded_fields = {"teacher", "kd_config"}
        for field in fields(self._super_class):
            if field.name.startswith("_") or field.name in excluded_fields:
                continue
            # Only include the field if it exists on this instance (it should, since we
            # converted from the original provider)
            if hasattr(self, field.name):
                result[field.name] = _ConfigContainerBase._convert_value_to_dict(
                    getattr(self, field.name)
                )

        # Also include any additional fields from DistillationProvider itself (if any)
        for field in fields(self):
            if field.name.startswith("_") or field.name in excluded_fields:
                continue
            # Skip fields already included from _super_class
            if field.name not in result:
                result[field.name] = _ConfigContainerBase._convert_value_to_dict(
                    getattr(self, field.name)
                )

        return result

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        # Mirror the attribute to the teacher if it has one by that name
        if hasattr(self.teacher, name):
            setattr(self.teacher, name, value)


def convert_to_distillation_provider(
    student_provider: GPTModelProvider | MambaModelProvider,
    teacher_provider: GPTModelProvider | MambaModelProvider,
    kd_config: Optional["ModelOptDistillConfig"] = None,
) -> "DistillationProvider":
    """Convert a given model provider to a DistillationProvider."""
    assert isinstance(student_provider, (GPTModelProvider, MambaModelProvider)), (
        "Student provider must be a subclass of GPTModelProvider or MambaModelProvider."
    )
    assert isinstance(teacher_provider, (GPTModelProvider, MambaModelProvider)), (
        "Teacher provider must be a subclass of GPTModelProvider or MambaModelProvider."
    )

    # Dynamically re-parent DistillationProvider onto the student's concrete class,
    # then upgrade the student instance in place.
    DistillationProvider.__bases__ = (type(student_provider),)
    student_provider.__class__ = DistillationProvider

    student_provider.teacher = teacher_provider
    student_provider.kd_config = kd_config
    student_provider.__post_init__()

    return student_provider
```
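`convert_to_distillation_provider()` relies on rewriting `DistillationProvider.__bases__` and the instance's `__class__` at runtime, so a single wrapper class can sit on top of either provider family while still delegating to its methods via `super()`. A self-contained sketch of that dynamic re-parenting trick on toy classes (all names here are illustrative only):

```python
class GPTProvider:
    """Toy stand-in for GPTModelProvider."""

    def provide(self) -> str:
        return "gpt-model"


class MambaProvider:
    """Toy stand-in for MambaModelProvider."""

    def provide(self) -> str:
        return "mamba-model"


class DistillWrapper(GPTProvider):  # placeholder base; swapped per conversion
    def provide(self) -> str:
        # super() resolves against the *current* MRO, so after re-parenting this
        # delegates to whichever provider class the student actually was.
        return f"distill({super().provide()})"


def convert(student):
    DistillWrapper.__bases__ = (type(student),)  # re-parent the wrapper class
    student.__class__ = DistillWrapper  # upgrade the instance in place
    return student


model = convert(MambaProvider())
```

A caveat of this pattern, visible even in the toy version: the wrapper class is mutated globally, so only one student class can be "active" at a time.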
Review conversation:

Comment: Pretty much everything in this PR seems like something we should instead merge into M-Bridge. Are we confident enough to upstream these changes?

Reply: We are not confident; for example, we would first need to talk to the Megatron-Bridge/Megatron-LM people and align with their plans for heterogeneous support. Let's think about it once Puzzletron is in main. We also still have to add support for gpt-oss and mamba, so it is not the best time to merge this into MCore.

Reply: The nemo:26.04 container code freeze is in 2 weeks. Let's make sure we raise a PR for the required changes to M-Bridge before then, so we can see what can and cannot be upstreamed.

Reply: We are unlikely to have time for it in the next 2 weeks.