-
Notifications
You must be signed in to change notification settings - Fork 307
Add LoRA co-training support for HF EAGLE speculative decoding #1060
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0305087
86e7015
adf4cbc
281d092
4960ad0
e64d154
fbc0dc0
d132fea
e05fc20
3cba87e
e056392
4087c80
56f459f
696d251
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,2 +1,3 @@ | ||
| accelerate==1.12.0 | ||
| peft>=0.17.0 | ||
| transformers==5.0.0rc1 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -110,3 +110,37 @@ class EagleConfig(ModeloptBaseConfig): | |
| "Whether to mix hidden states of multiple TTT steps. It is a technique to reduce training cost." | ||
| ), | ||
| ) | ||
|
|
||
| eagle_base_lora: bool = ModeloptField( | ||
| default=False, | ||
| description=( | ||
| "Whether to add LoRA adapters to the base model for co-training with the EAGLE module. " | ||
| "Requires the `peft` library. Incompatible with eagle_offline=True." | ||
| ), | ||
| ) | ||
|
|
||
| eagle_base_lora_rank: int = ModeloptField( | ||
| default=64, | ||
| description="LoRA rank for the base model adapters.", | ||
| ) | ||
|
|
||
| eagle_base_lora_alpha: float = ModeloptField( | ||
| default=16.0, | ||
| description="LoRA alpha (scaling) for the base model adapters.", | ||
| ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Mutable default: `eagle_base_lora_target_modules` should not default to a mutable `list`; use `None` (interpreted as "use peft defaults") or a `default_factory` instead.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fixed in 5632aba — changed to `list | None` with `default=None`, so no mutable default list is shared across config instances. |
||
|
|
||
| eagle_base_lora_target_modules: list | None = ModeloptField( | ||
| default=None, | ||
| description=( | ||
| "List of module name patterns to apply LoRA to in the base model " | ||
| "(e.g. ['q_proj', 'v_proj']). None uses peft defaults." | ||
| ), | ||
| ) | ||
|
|
||
| eagle_base_lora_preservation_loss_weight: float = ModeloptField( | ||
| default=0.1, | ||
| description=( | ||
| "Weight for the preservation loss that minimizes the KL divergence between " | ||
| "the LoRA-adapted base model output and the original base model output." | ||
| ), | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -547,6 +547,42 @@ def _get_eagle_device(self): | |
| base_model_last_layer = self._base_model.layers[-1] | ||
| return next(base_model_last_layer.parameters()).device | ||
|
|
||
| def _inject_base_lora(self): | ||
| """Inject HF PEFT LoRA adapters into the base model in-place and unfreeze them.""" | ||
| from peft import LoraConfig | ||
| from peft.mapping import inject_adapter_in_model | ||
|
|
||
| target_modules = self.eagle_base_lora_target_modules or None | ||
| lora_config = LoraConfig( | ||
| r=self.eagle_base_lora_rank, | ||
| lora_alpha=self.eagle_base_lora_alpha, | ||
| target_modules=target_modules, | ||
| bias="none", | ||
| ) | ||
| inject_adapter_in_model(lora_config, self._base_model, adapter_name="default") | ||
| # Unfreeze only the LoRA parameters | ||
| for name, param in self._base_model.named_parameters(): | ||
| if "lora_" in name: | ||
| param.requires_grad = True | ||
|
|
||
| def _set_base_lora_enabled(self, enabled: bool) -> None: | ||
| """Enable or disable LoRA adapters in the base model.""" | ||
| from peft.tuners.lora import LoraLayer | ||
|
|
||
| for module in self._base_model.modules(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The docstring says "KL divergence" but this computes cross-entropy: `-(softmax(ref) * log_softmax(lora)).sum(dim=-1).mean()`, which omits the reference-entropy term of the full KL.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is logit KL divergence
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To expand on this: the implementation computes cross-entropy, which differs from KL(ref ‖ lora) only by the entropy of the detached reference distribution — a constant with respect to the trainable LoRA parameters, so the gradients are identical. |
||
| if isinstance(module, LoraLayer): | ||
| module.enable_adapters(enabled) | ||
|
|
||
| def _preservation_loss( | ||
| self, ref_logits: torch.Tensor, lora_logits: torch.Tensor | ||
| ) -> torch.Tensor: | ||
| """KL divergence encouraging LoRA output to stay close to the original base model. | ||
|
|
||
| KL(softmax(ref) || log_softmax(lora)) weighted by eagle_base_lora_preservation_loss_weight. | ||
| """ | ||
| loss = nn.Softmax(dim=-1)(ref_logits.detach()) * nn.LogSoftmax(dim=-1)(lora_logits) | ||
| return -loss.sum(dim=-1).mean() * self.eagle_base_lora_preservation_loss_weight | ||
|
|
||
| def modify( | ||
| self, | ||
| config, | ||
|
|
@@ -610,6 +646,12 @@ def modify( | |
| if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids: | ||
| layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in 6e709b5 — replaced with |
||
| # Inject HF PEFT LoRA adapters into the base model for co-training | ||
| if self.eagle_base_lora: | ||
| if self.eagle_offline: | ||
| raise ValueError("eagle_base_lora is incompatible with eagle_offline=True") | ||
| self._inject_base_lora() | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # delete base model layers for offline training | ||
| if self.eagle_offline: | ||
| self._base_model._modules.pop("layers") | ||
|
|
@@ -723,7 +765,9 @@ def _compute_ttt_attention_mask( | |
| ) -> BlockMask | torch.Tensor: | ||
| """Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl.""" | ||
| msk_func = get_ttt_msk_func(seq_length, ttt_step) | ||
| dtypemin = torch.finfo(self._base_llm_config.dtype).min | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The reference forward with LoRA disabled has no cleanup path: if it raises, the adapters stay disabled and `_aux_hidden_states` keeps the stale entries collected by the forward hooks. Wrap it in try/finally: self._set_base_lora_enabled(False)
try:
ref_logits = _run_forward(no_grad=True).logits
if hasattr(self, "_aux_hidden_states"):
self._aux_hidden_states.clear()
finally:
self._set_base_lora_enabled(True)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fixed in 6e709b5 — wrapped the reference forward in a `try/finally` that clears `_aux_hidden_states` and re-enables the LoRA adapters afterwards. |
||
| dtypemin = torch.finfo( | ||
| getattr(self._base_llm_config, "dtype", None) or torch.get_default_dtype() | ||
| ).min | ||
| q_len = seq_length | ||
| kv_len = seq_length * (1 + ttt_step) | ||
| if self.eagle_config._attn_implementation == "flex_attention": | ||
|
|
@@ -739,7 +783,10 @@ def _compute_ttt_attention_mask( | |
| torch.arange(kv_len).view(1, 1, 1, kv_len), | ||
| ).to(self.device) | ||
| tensor_mask = torch.full_like( | ||
| tensor_mask, 0, dtype=self._base_llm_config.dtype, device=self.device | ||
| tensor_mask, | ||
| 0, | ||
| dtype=getattr(self._base_llm_config, "dtype", None) or torch.get_default_dtype(), | ||
| device=self.device, | ||
| ).masked_fill(~tensor_mask, dtypemin) | ||
|
|
||
| # Note: (hg) repeat mask for kimi-k2 compatibility | ||
|
|
@@ -756,32 +803,48 @@ def _base_model_forward( | |
| labels, | ||
| **kwargs, | ||
| ): | ||
| with torch.no_grad() if freeze_base_model else contextlib.nullcontext(): | ||
| outputs = super().forward( | ||
| input_ids=input_ids, | ||
| attention_mask=attention_mask, | ||
| position_ids=position_ids, | ||
| past_key_values=past_key_values, | ||
| output_hidden_states=True, | ||
| **kwargs, | ||
| ) | ||
| past_key_values = getattr(outputs, "past_key_values", None) | ||
| base_input_embeds = outputs.hidden_states[0] | ||
| base_model_hidden_states = outputs.hidden_states[-1] | ||
| base_model_logits = outputs.logits | ||
| def _run_forward(no_grad): | ||
| with torch.no_grad() if no_grad else contextlib.nullcontext(): | ||
| return super(HFEagleModel, self).forward( | ||
| input_ids=input_ids, | ||
| attention_mask=attention_mask, | ||
| position_ids=position_ids, | ||
| past_key_values=past_key_values, | ||
| output_hidden_states=True, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| # Optionally, compute base model loss when we want to tune the base model. | ||
| # When using LoRA, run a reference forward with LoRA disabled to get the original | ||
| # model distribution for the preservation loss. | ||
| ref_logits = None | ||
| if self.eagle_base_lora: | ||
| self._set_base_lora_enabled(False) | ||
| try: | ||
| ref_logits = _run_forward(no_grad=True).logits | ||
| finally: | ||
| if hasattr(self, "_aux_hidden_states"): | ||
| self._aux_hidden_states.clear() | ||
| self._set_base_lora_enabled(True) | ||
|
|
||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # Main forward — LoRA params receive gradients when eagle_base_lora is True. | ||
| outputs = _run_forward(no_grad=freeze_base_model and not self.eagle_base_lora) | ||
| past_key_values = getattr(outputs, "past_key_values", None) | ||
| base_model_logits = outputs.logits | ||
|
|
||
| if ref_logits is not None: | ||
| base_model_loss = self._preservation_loss(ref_logits, base_model_logits) | ||
| elif not freeze_base_model and labels is not None: | ||
| loss_fct = CrossEntropyLoss() | ||
| base_model_loss = loss_fct( | ||
| base_model_logits.view(-1, base_model_logits.shape[-1]), labels.view(-1) | ||
| ) | ||
| else: | ||
| base_model_loss = None | ||
| if not freeze_base_model and labels is not None: # Base model loss | ||
| loss_fct = CrossEntropyLoss() | ||
| loss_logits = base_model_logits.view(-1, base_model_logits.shape[-1]) | ||
| labels = labels.view(-1) | ||
| base_model_loss = loss_fct(loss_logits, labels) | ||
|
|
||
| return EagleBaseModelOutput( | ||
| input_embeds=base_input_embeds, | ||
| input_embeds=outputs.hidden_states[0], | ||
| aux_hiddens=self.pop_and_gather_aux_hiddens(), | ||
| out_hiddens=base_model_hidden_states, | ||
| out_hiddens=outputs.hidden_states[-1], | ||
| logits=base_model_logits, | ||
| loss=base_model_loss, | ||
| ), past_key_values | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -467,7 +467,7 @@ def enable_cp_ttt_patch(): | |
| import modelopt.torch.speculative.plugins.transformers | ||
|
|
||
| modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True | ||
| with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Adding `SDPBackend.MATH` alongside `SDPBackend.CUDNN_ATTENTION` widens which SDPA kernel may be selected inside this context manager — is the math fallback intended here, or should it stay cuDNN-only?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is needed for tests
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To expand: the `MATH` backend is included as a fallback so the patched path still runs in environments where the cuDNN attention backend is unavailable (e.g. the CI test runs) — presumably CPU-only or older-GPU machines; confirm against the test environment. |
||
| with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]): | ||
| try: | ||
| yield | ||
| finally: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Unit tests for EAGLE + LoRA co-training (eagle_base_lora feature).""" | ||
|
|
||
| from copy import deepcopy | ||
|
|
||
| import pytest | ||
| import torch | ||
| from _test_utils.torch.transformers_models import get_tiny_llama | ||
| from peft.tuners.lora import LoraLayer | ||
|
|
||
| import modelopt.torch.speculative as mtsp | ||
| from modelopt.torch.speculative.eagle.default_config import default_eagle_config | ||
|
|
||
| TINY_EAGLE_CFG = { | ||
| "num_hidden_layers": 1, | ||
| "intermediate_size": 32, | ||
| "num_attention_heads": 16, | ||
| "num_key_value_heads": 16, | ||
| "head_dim": 2, | ||
| "use_last_layernorm": True, | ||
| "use_aux_hidden_state": False, | ||
| "eagle_aux_hidden_state_layer_ids": [], | ||
| } | ||
|
|
||
| EAGLE_LORA_CONFIG = { | ||
| "eagle_architecture_config": {**default_eagle_config, **TINY_EAGLE_CFG}, | ||
| "eagle_base_lora": True, | ||
| "eagle_base_lora_rank": 4, | ||
| "eagle_base_lora_alpha": 8.0, | ||
| "eagle_base_lora_target_modules": ["q_proj", "v_proj"], | ||
| "eagle_base_lora_preservation_loss_weight": 0.1, | ||
| } | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def lora_eagle_model(): | ||
| model = get_tiny_llama(num_hidden_layers=4) | ||
| mtsp.convert(model, mode=[("eagle", deepcopy(EAGLE_LORA_CONFIG))]) | ||
| return model | ||
|
|
||
|
|
||
| def test_lora_layers_injected(lora_eagle_model): | ||
| """LoRA adapters should be present in the base model after conversion.""" | ||
| lora_layers = [m for m in lora_eagle_model._base_model.modules() if isinstance(m, LoraLayer)] | ||
| assert len(lora_layers) > 0, "No LoRA layers found in base model" | ||
|
|
||
|
|
||
| def test_trainable_params(lora_eagle_model): | ||
| """Only LoRA and eagle_module params should be trainable; base model weights frozen.""" | ||
| for name, param in lora_eagle_model.named_parameters(): | ||
| is_lora = "lora_" in name | ||
| is_eagle = "eagle_module" in name | ||
| if is_lora or is_eagle: | ||
| assert param.requires_grad, f"Expected {name} to be trainable" | ||
| else: | ||
| assert not param.requires_grad, f"Expected {name} to be frozen" | ||
|
|
||
|
|
||
| def test_forward_returns_loss(lora_eagle_model): | ||
| """Forward pass should return a scalar loss containing preservation + eagle components.""" | ||
| lora_eagle_model.train() | ||
| seq_len = 8 | ||
| input_ids = torch.randint(0, lora_eagle_model.config.vocab_size, (1, seq_len)) | ||
| output = lora_eagle_model(input_ids=input_ids, labels=input_ids) | ||
| assert output.loss is not None | ||
| assert output.loss.ndim == 0, "Loss should be a scalar" | ||
| assert output.loss.item() > 0 | ||
|
|
||
|
|
||
| def test_eagle_offline_incompatible(): | ||
| """eagle_base_lora=True should raise when combined with eagle_offline=True.""" | ||
| model = get_tiny_llama(num_hidden_layers=4) | ||
| config = deepcopy(EAGLE_LORA_CONFIG) | ||
| config["eagle_offline"] = True | ||
| with pytest.raises(ValueError, match="eagle_base_lora is incompatible with eagle_offline"): | ||
| mtsp.convert(model, mode=[("eagle", config)]) | ||
|
|
||
|
|
||
| def test_export_lora_artifacts(lora_eagle_model, tmp_path): | ||
| """export() should produce lora_adapter_model.safetensors and lora_adapter_config.json.""" | ||
| export_dir = tmp_path / "eagle_export" | ||
| lora_eagle_model.get_exporter().export(export_dir) | ||
|
|
||
| assert (export_dir / "model.safetensors").exists(), "Eagle model weights missing" | ||
| assert (export_dir / "lora_adapter_model.safetensors").exists(), "LoRA weights missing" | ||
| assert (export_dir / "lora_adapter_config.json").exists(), "LoRA config missing" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LoRA key filtering (
"lora_A" in k or "lora_B" in k) is fragile. PEFT may use other key patterns (lora_embedding_A,lora_magnitude_vector, etc.). Consider using peft's own utilities to identify adapter parameters, or add a warning if zero LoRA tensors are found.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 5632aba — tightened the filter to
.lora_A./.lora_B.(dot-bounded) and added aRuntimeErrorif no LoRA tensors are found.