Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/speculative_decoding/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
accelerate==1.12.0
peft>=0.17.0
transformers==5.0.0rc1
40 changes: 38 additions & 2 deletions modelopt/torch/export/plugins/hf_spec_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,17 @@ def _check_valid_sd(self, export_sd: dict):
"llama": LLAMA_EAGLE_SINGLE_LAYER,
"kimik2": KIMIK2_EAGLE_SINGLE_LAYER,
}[self.eagle_decoder_type]
# fc and hidden_norm are only present when use_aux_hidden_state=True
use_aux = getattr(self.model.eagle_config, "use_aux_hidden_state", False)
aux_only_keys = {"fc", "layers.0.hidden_norm"}
required_keys = set(expected_keys_single_layer["required"])
if not use_aux:
required_keys -= aux_only_keys
# Check that export sd has required keys
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LoRA key filtering ("lora_A" in k or "lora_B" in k) is fragile. PEFT may use other key patterns (lora_embedding_A, lora_magnitude_vector, etc.). Consider using peft's own utilities to identify adapter parameters, or add a warning if zero LoRA tensors are found.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 5632aba — tightened the filter to .lora_A. / .lora_B. (dot-bounded) and added a RuntimeError if no LoRA tensors are found.

for key in expected_keys_single_layer["required"]:
for key in required_keys:
assert f"{key}.weight" in export_sd, f"Missing required key: {key}.weight"
for i in range(1, self.num_hidden_layers):
for key in expected_keys_single_layer["required"] - {
for key in required_keys - {
"layers.0.hidden_norm",
"layers.0.input_layernorm",
"norm",
Expand Down Expand Up @@ -185,6 +191,32 @@ def _get_config_from_draft_or_base(key: str, model: nn.Module):

return template_config

def _export_lora(self, export_dir: Path, full_sd: dict):
    """Export base model LoRA adapter weights alongside the eagle module artifacts."""
    from peft import LoraConfig

    def _is_lora_key(name: str) -> bool:
        # Dot-bounded match so only true adapter tensors are selected.
        return ".lora_A." in name or ".lora_B." in name

    adapter_sd = {name: tensor for name, tensor in full_sd.items() if _is_lora_key(name)}
    if not adapter_sd:
        raise RuntimeError(
            "No LoRA adapter tensors found in the model state dict. "
            "Ensure eagle_base_lora=True and the model was converted with LoRA adapters."
        )
    save_file(adapter_sd, export_dir / "lora_adapter_model.safetensors")

    adapter_config = LoraConfig(
        r=self.model.eagle_base_lora_rank,
        lora_alpha=self.model.eagle_base_lora_alpha,
        target_modules=self.model.eagle_base_lora_target_modules or None,
        bias="none",
    )

    def _jsonable(obj):
        # LoraConfig.to_dict() may contain sets; serialize them as sorted lists.
        return sorted(obj) if isinstance(obj, set) else obj

    with open(export_dir / "lora_adapter_config.json", "w") as cfg_file:
        json.dump(adapter_config.to_dict(), cfg_file, indent=4, default=_jsonable)

def export(self, export_dir: Path | str, dtype: torch.dtype | None = None):
"""Export the model to the deployment format."""
# Make export dir
Expand Down Expand Up @@ -215,6 +247,10 @@ def export(self, export_dir: Path | str, dtype: torch.dtype | None = None):
with open(f"{export_dir}/hf_quant_config.json", "w") as file:
json.dump(hf_quant_config, file, indent=4)

# Export LoRA adapter weights separately
if getattr(self.model, "eagle_base_lora", False):
self._export_lora(export_dir, full_sd)


class EagleMedusaExporter(EagleExporter):
"""Draft model exporter for EagleMedusa."""
Expand Down
34 changes: 34 additions & 0 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,37 @@ class EagleConfig(ModeloptBaseConfig):
"Whether to mix hidden states of multiple TTT steps. It is a technique to reduce training cost."
),
)

# Master switch: attach PEFT LoRA adapters to the (otherwise frozen) base model
# so they can be co-trained alongside the EAGLE draft module.
eagle_base_lora: bool = ModeloptField(
    default=False,
    description=(
        "Whether to add LoRA adapters to the base model for co-training with the EAGLE module. "
        "Requires the `peft` library. Incompatible with eagle_offline=True."
    ),
)

# Adapter rank passed through to peft's LoraConfig.
eagle_base_lora_rank: int = ModeloptField(
    default=64,
    description="LoRA rank for the base model adapters.",
)

# Adapter scaling factor passed through to peft's LoraConfig.
eagle_base_lora_alpha: float = ModeloptField(
    default=16.0,
    description="LoRA alpha (scaling) for the base model adapters.",
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mutable default: default=[] means all config instances share the same list object. Use default_factory=list or default=None with a note that None uses peft defaults.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 5632aba — changed to list | None = ModeloptField(default=None). All existing usages already used or None so nothing breaks.


# None (the default) delegates module selection to peft's own defaults rather
# than sharing a mutable default list across config instances.
eagle_base_lora_target_modules: list | None = ModeloptField(
    default=None,
    description=(
        "List of module name patterns to apply LoRA to in the base model "
        "(e.g. ['q_proj', 'v_proj']). None uses peft defaults."
    ),
)

# Scales the preservation loss term used when eagle_base_lora is enabled.
eagle_base_lora_preservation_loss_weight: float = ModeloptField(
    default=0.1,
    description=(
        "Weight for the preservation loss that minimizes the KL divergence between "
        "the LoRA-adapted base model output and the original base model output."
    ),
)
7 changes: 7 additions & 0 deletions modelopt/torch/speculative/eagle/eagle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,10 @@ def modify(
self.eagle_decoder_type = config.eagle_decoder_type
self.eagle_ttt_steps = config.eagle_ttt_steps
self.eagle_mix_hidden_states = config.eagle_mix_hidden_states
self.eagle_base_lora = config.eagle_base_lora
self.eagle_base_lora_rank = config.eagle_base_lora_rank
self.eagle_base_lora_alpha = config.eagle_base_lora_alpha
self.eagle_base_lora_target_modules = config.eagle_base_lora_target_modules
self.eagle_base_lora_preservation_loss_weight = (
config.eagle_base_lora_preservation_loss_weight
)
109 changes: 86 additions & 23 deletions modelopt/torch/speculative/plugins/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,42 @@ def _get_eagle_device(self):
base_model_last_layer = self._base_model.layers[-1]
return next(base_model_last_layer.parameters()).device

def _inject_base_lora(self):
    """Inject HF PEFT LoRA adapters into the base model in-place and unfreeze them.

    The base model's own weights stay frozen; only the injected ``lora_*``
    parameters are made trainable so they can be co-trained with the EAGLE module.
    """
    # Import from the public `peft` namespace: `peft.mapping` is an internal
    # module path and has moved between peft releases, while
    # `inject_adapter_in_model` is part of peft's documented top-level API.
    from peft import LoraConfig, inject_adapter_in_model

    lora_config = LoraConfig(
        r=self.eagle_base_lora_rank,
        lora_alpha=self.eagle_base_lora_alpha,
        # None falls back to peft's per-architecture default target modules.
        target_modules=self.eagle_base_lora_target_modules or None,
        bias="none",
    )
    inject_adapter_in_model(lora_config, self._base_model, adapter_name="default")
    # Unfreeze only the LoRA parameters; everything else remains frozen.
    for name, param in self._base_model.named_parameters():
        if "lora_" in name:
            param.requires_grad = True

def _set_base_lora_enabled(self, enabled: bool) -> None:
"""Enable or disable LoRA adapters in the base model."""
from peft.tuners.lora import LoraLayer

for module in self._base_model.modules():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring says "KL divergence" but this computes cross-entropy: -softmax(ref) * log_softmax(lora). The missing entropy term is constant w.r.t. LoRA params so gradients are correct, but the naming is misleading. Either rename or add a comment clarifying this is KL up to a constant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is logit KL divergence

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To expand on this: the implementation computes cross-entropy H(ref, lora) = -softmax(ref) · log_softmax(lora), which equals KL(ref ∥ lora) + H(ref). Since H(ref) is constant w.r.t. LoRA parameters, the gradients are identical to true KL divergence — so the optimization objective is equivalent. The docstring has been updated to clarify this: it now reads "KL(softmax(ref) || log_softmax(lora))" and notes that the entropy term of the reference is constant and dropped.

if isinstance(module, LoraLayer):
module.enable_adapters(enabled)

def _preservation_loss(
    self, ref_logits: torch.Tensor, lora_logits: torch.Tensor
) -> torch.Tensor:
    """KL divergence encouraging LoRA output to stay close to the original base model.

    Computes the cross-entropy -softmax(ref) * log_softmax(lora), which equals
    KL(ref || lora) plus the entropy of ref; the entropy term is constant w.r.t.
    the LoRA parameters, so gradients match true KL. The result is scaled by
    eagle_base_lora_preservation_loss_weight.
    """
    ref_probs = torch.softmax(ref_logits.detach(), dim=-1)
    lora_log_probs = torch.log_softmax(lora_logits, dim=-1)
    cross_entropy = -(ref_probs * lora_log_probs).sum(dim=-1).mean()
    return cross_entropy * self.eagle_base_lora_preservation_loss_weight

def modify(
self,
config,
Expand Down Expand Up @@ -610,6 +646,12 @@ def modify(
if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert not self.eagle_offline can be optimized out with python -O. Use if self.eagle_offline: raise ValueError(...) for runtime config validation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 6e709b5 — replaced with if self.eagle_offline: raise ValueError(...). Also updated the test to expect ValueError instead of AssertionError.

# Inject HF PEFT LoRA adapters into the base model for co-training
if self.eagle_base_lora:
if self.eagle_offline:
raise ValueError("eagle_base_lora is incompatible with eagle_offline=True")
self._inject_base_lora()

# delete base model layers for offline training
if self.eagle_offline:
self._base_model._modules.pop("layers")
Expand Down Expand Up @@ -723,7 +765,9 @@ def _compute_ttt_attention_mask(
) -> BlockMask | torch.Tensor:
"""Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl."""
msk_func = get_ttt_msk_func(seq_length, ttt_step)
dtypemin = torch.finfo(self._base_llm_config.dtype).min
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reference forward with LoRA disabled has no try/finally. If the forward throws, LoRA stays disabled for all subsequent calls:

self._set_base_lora_enabled(False)
try:
    ref_logits = _run_forward(no_grad=True).logits
    if hasattr(self, "_aux_hidden_states"):
        self._aux_hidden_states.clear()
finally:
    self._set_base_lora_enabled(True)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 6e709b5 — wrapped the reference forward in try/finally so _set_base_lora_enabled(True) is guaranteed to run even if the forward throws.

dtypemin = torch.finfo(
getattr(self._base_llm_config, "dtype", None) or torch.get_default_dtype()
).min
q_len = seq_length
kv_len = seq_length * (1 + ttt_step)
if self.eagle_config._attn_implementation == "flex_attention":
Expand All @@ -739,7 +783,10 @@ def _compute_ttt_attention_mask(
torch.arange(kv_len).view(1, 1, 1, kv_len),
).to(self.device)
tensor_mask = torch.full_like(
tensor_mask, 0, dtype=self._base_llm_config.dtype, device=self.device
tensor_mask,
0,
dtype=getattr(self._base_llm_config, "dtype", None) or torch.get_default_dtype(),
device=self.device,
).masked_fill(~tensor_mask, dtypemin)

# Note: (hg) repeat mask for kimi-k2 compatibility
Expand All @@ -756,32 +803,48 @@ def _base_model_forward(
labels,
**kwargs,
):
with torch.no_grad() if freeze_base_model else contextlib.nullcontext():
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_hidden_states=True,
**kwargs,
)
past_key_values = getattr(outputs, "past_key_values", None)
base_input_embeds = outputs.hidden_states[0]
base_model_hidden_states = outputs.hidden_states[-1]
base_model_logits = outputs.logits
def _run_forward(no_grad):
with torch.no_grad() if no_grad else contextlib.nullcontext():
return super(HFEagleModel, self).forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_hidden_states=True,
**kwargs,
)

# Optionally, compute base model loss when we want to tune the base model.
# When using LoRA, run a reference forward with LoRA disabled to get the original
# model distribution for the preservation loss.
ref_logits = None
if self.eagle_base_lora:
self._set_base_lora_enabled(False)
try:
ref_logits = _run_forward(no_grad=True).logits
finally:
if hasattr(self, "_aux_hidden_states"):
self._aux_hidden_states.clear()
self._set_base_lora_enabled(True)

# Main forward — LoRA params receive gradients when eagle_base_lora is True.
outputs = _run_forward(no_grad=freeze_base_model and not self.eagle_base_lora)
past_key_values = getattr(outputs, "past_key_values", None)
base_model_logits = outputs.logits

if ref_logits is not None:
base_model_loss = self._preservation_loss(ref_logits, base_model_logits)
elif not freeze_base_model and labels is not None:
loss_fct = CrossEntropyLoss()
base_model_loss = loss_fct(
base_model_logits.view(-1, base_model_logits.shape[-1]), labels.view(-1)
)
else:
base_model_loss = None
if not freeze_base_model and labels is not None: # Base model loss
loss_fct = CrossEntropyLoss()
loss_logits = base_model_logits.view(-1, base_model_logits.shape[-1])
labels = labels.view(-1)
base_model_loss = loss_fct(loss_logits, labels)

return EagleBaseModelOutput(
input_embeds=base_input_embeds,
input_embeds=outputs.hidden_states[0],
aux_hiddens=self.pop_and_gather_aux_hiddens(),
out_hiddens=base_model_hidden_states,
out_hiddens=outputs.hidden_states[-1],
logits=base_model_logits,
loss=base_model_loss,
), past_key_values
Expand Down
2 changes: 1 addition & 1 deletion modelopt/torch/speculative/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def enable_cp_ttt_patch():
import modelopt.torch.speculative.plugins.transformers

modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding SDPBackend.MATH fallback and the dtype getattr changes in transformers.py look unrelated to LoRA co-training. Consider splitting into a separate PR or calling them out in the description.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed for tests

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To expand: the SDPBackend.MATH fallback in enable_cp_ttt_patch() is required for CPU unit tests. CUDNN_ATTENTION is only available on GPU, so without MATH as a fallback the test environment raises an error when no supported SDPA backend is found. This is directly exercised by the LoRA co-training forward pass test (test_forward_returns_loss), which runs the TTT attention path on CPU. The change is scoped to the enable_cp_ttt_patch() context manager and doesn't affect production GPU paths.

with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]):
try:
yield
finally:
Expand Down
100 changes: 100 additions & 0 deletions tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for EAGLE + LoRA co-training (eagle_base_lora feature)."""

from copy import deepcopy

import pytest
import torch
from _test_utils.torch.transformers_models import get_tiny_llama
from peft.tuners.lora import LoraLayer

import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.eagle.default_config import default_eagle_config

# Minimal single-layer eagle head so the unit tests stay fast on CPU.
TINY_EAGLE_CFG = {
    "num_hidden_layers": 1,
    "intermediate_size": 32,
    "num_attention_heads": 16,
    "num_key_value_heads": 16,
    "head_dim": 2,
    "use_last_layernorm": True,
    "use_aux_hidden_state": False,
    "eagle_aux_hidden_state_layer_ids": [],
}

# Eagle conversion config with base-model LoRA co-training enabled (tiny rank for speed).
EAGLE_LORA_CONFIG = {
    "eagle_architecture_config": {**default_eagle_config, **TINY_EAGLE_CFG},
    "eagle_base_lora": True,
    "eagle_base_lora_rank": 4,
    "eagle_base_lora_alpha": 8.0,
    "eagle_base_lora_target_modules": ["q_proj", "v_proj"],
    "eagle_base_lora_preservation_loss_weight": 0.1,
}


@pytest.fixture
def lora_eagle_model():
    """Tiny llama converted to an EAGLE model with base-model LoRA co-training enabled."""
    base = get_tiny_llama(num_hidden_layers=4)
    mtsp.convert(base, mode=[("eagle", deepcopy(EAGLE_LORA_CONFIG))])
    return base


def test_lora_layers_injected(lora_eagle_model):
    """LoRA adapters should be present in the base model after conversion."""
    has_lora = any(isinstance(m, LoraLayer) for m in lora_eagle_model._base_model.modules())
    assert has_lora, "No LoRA layers found in base model"


def test_trainable_params(lora_eagle_model):
    """Only LoRA and eagle_module params should be trainable; base model weights frozen."""
    for name, param in lora_eagle_model.named_parameters():
        if "lora_" in name or "eagle_module" in name:
            assert param.requires_grad, f"Expected {name} to be trainable"
        else:
            assert not param.requires_grad, f"Expected {name} to be frozen"


def test_forward_returns_loss(lora_eagle_model):
    """Forward pass should return a scalar loss containing preservation + eagle components."""
    lora_eagle_model.train()
    tokens = torch.randint(0, lora_eagle_model.config.vocab_size, (1, 8))
    output = lora_eagle_model(input_ids=tokens, labels=tokens)
    loss = output.loss
    assert loss is not None
    assert loss.ndim == 0, "Loss should be a scalar"
    assert loss.item() > 0


def test_eagle_offline_incompatible():
    """eagle_base_lora=True should raise when combined with eagle_offline=True."""
    cfg = {**deepcopy(EAGLE_LORA_CONFIG), "eagle_offline": True}
    tiny_model = get_tiny_llama(num_hidden_layers=4)
    with pytest.raises(ValueError, match="eagle_base_lora is incompatible with eagle_offline"):
        mtsp.convert(tiny_model, mode=[("eagle", cfg)])


def test_export_lora_artifacts(lora_eagle_model, tmp_path):
    """export() should produce lora_adapter_model.safetensors and lora_adapter_config.json."""
    export_dir = tmp_path / "eagle_export"
    lora_eagle_model.get_exporter().export(export_dir)

    expected_artifacts = {
        "model.safetensors": "Eagle model weights missing",
        "lora_adapter_model.safetensors": "LoRA weights missing",
        "lora_adapter_config.json": "LoRA config missing",
    }
    for artifact, message in expected_artifacts.items():
        assert (export_dir / artifact).exists(), message
Loading