Add LoRA co-training support for HF EAGLE speculative decoding #1060

Open
yeyu-nvidia wants to merge 5 commits into main from yeyu/speculative-lora-cotrain

Conversation


@yeyu-nvidia yeyu-nvidia commented Mar 17, 2026

What does this PR do?

Type of change: New feature

Adds LoRA co-training support for HF EAGLE speculative decoding. When
eagle_base_lora=True, HF PEFT LoRA adapters are injected into the base
model and co-trained alongside the EAGLE draft module in a single online
training pass. A preservation loss (KL divergence between the original
frozen base model output and the LoRA-adapted output) is added to prevent
the base model from drifting during training. LoRA adapter weights are
exported separately alongside the EAGLE draft model artifacts.
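The preservation term is plain KL divergence between the two output distributions, scaled by a configurable weight. A minimal pure-Python sketch of the math for a single token position (the actual implementation operates on batched torch logits; `preservation_loss` and `softmax` here are illustrative names, not the module's API):

```python
import math

def softmax(xs):
    # Numerically stable softmax over one token's logits.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def preservation_loss(ref_logits, lora_logits, weight=0.1):
    # KL(ref || lora) with the frozen base model's distribution as the
    # reference; penalizes drift of the LoRA-adapted distribution.
    ref_p = softmax(ref_logits)
    lora_logp = [math.log(p) for p in softmax(lora_logits)]
    kl = sum(p * (math.log(p) - lp) for p, lp in zip(ref_p, lora_logp))
    return weight * kl
```

Identical logits give zero loss; any divergence of the adapted distribution is penalized in proportion to the weight.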

Usage

import modelopt.torch.speculative as mtsp

# Convert model to EAGLE with LoRA co-training enabled
mtsp.convert(model, mode=[("eagle", {
    "eagle_architecture_config": eagle_arch_cfg,
    "eagle_base_lora": True,
    "eagle_base_lora_rank": 64,
    "eagle_base_lora_alpha": 16.0,
    "eagle_base_lora_target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "eagle_base_lora_preservation_loss_weight": 0.1,
})])

# Train as usual — LoRA params and eagle_module params are trainable,
# base model weights are frozen. Total loss = eagle_loss + preservation_loss.
output = model(input_ids=input_ids, labels=labels)
output.loss.backward()

# Export: eagle draft weights + LoRA adapter weights saved separately
model.get_exporter().export("./export_dir")
# export_dir/
#   model.safetensors            <- EAGLE draft module
#   config.json                  <- EAGLE config
#   lora_adapter_model.safetensors  <- LoRA adapter weights
#   lora_adapter_config.json        <- LoRA config (rank, alpha, target_modules)
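The split between draft and adapter artifacts comes down to filtering state-dict keys by the lora_A/lora_B naming that PEFT gives injected adapter matrices. A sketch of that filter (`split_lora_keys` is an illustrative helper name, not the exporter's actual method):

```python
def split_lora_keys(full_sd):
    # PEFT names injected adapter matrices with "lora_A"/"lora_B";
    # everything else belongs in the base/draft checkpoint.
    lora_sd = {k: v for k, v in full_sd.items() if "lora_A" in k or "lora_B" in k}
    rest_sd = {k: v for k, v in full_sd.items() if k not in lora_sd}
    return lora_sd, rest_sd
```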

Testing

Added tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py with 5 unit tests:
- test_lora_layers_injected verifies LoRA layers are present in the base model after conversion
- test_trainable_params verifies only lora_* and eagle_module params have requires_grad=True
- test_forward_returns_loss verifies the forward pass returns a non-zero scalar loss
- test_eagle_offline_incompatible verifies eagle_base_lora=True + eagle_offline=True raises an error
- test_export_lora_artifacts verifies export() produces the expected LoRA files
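The freezing behavior that test_trainable_params checks reduces to a name-based rule: a parameter trains only if it belongs to a LoRA adapter or to the eagle_module. A hypothetical sketch of that rule (`is_trainable` is illustrative; the real code sets requires_grad on torch parameters):

```python
def is_trainable(param_name):
    # Only LoRA adapter weights and eagle draft-module weights train;
    # all original base-model weights stay frozen.
    return "lora_" in param_name or param_name.startswith("eagle_module.")
```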

Before your PR is "Ready for review"

- Is this change backward compatible?: ✅ — new config fields all have defaults; existing EAGLE workflows are unaffected.
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ — uses peft>=0.17.0 which is already listed in the [hf] optional extra.
- Did you write any new necessary tests?: ✅
- Did you update Changelog?: ✅

Additional Information

This feature is intended for online HF training only (eagle_offline=True is explicitly blocked). The LoRA adapters are applied to the base model via peft.inject_adapter_in_model (in-place, no wrapper), keeping the existing HFEagleModel structure intact.
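Because the adapters live in place inside the base model, the preservation loss needs two forwards through the same module: one with adapters disabled to get the frozen reference logits, one with them enabled. A sketch of that toggle pattern with a stand-in model (`base_lora_disabled`, `ToyModel`, and `set_lora_enabled` are hypothetical names mirroring the PR's `_set_base_lora_enabled`; try/finally guarantees re-enabling even if the forward raises):

```python
from contextlib import contextmanager

@contextmanager
def base_lora_disabled(model):
    # Turn adapters off for the reference forward and guarantee they
    # come back on, even if the forward raises.
    model.set_lora_enabled(False)
    try:
        yield model
    finally:
        model.set_lora_enabled(True)

class ToyModel:
    # Stand-in for the LoRA-injected base model.
    def __init__(self):
        self.lora_enabled = True

    def set_lora_enabled(self, flag):
        self.lora_enabled = flag
```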



## Summary by CodeRabbit

* **New Features**
* LoRA support for EAGLE speculative decoding with configurable rank, alpha, target modules, and a preservation-loss to retain base-model behavior
* Export capability to save model and LoRA adapter artifacts and adapter configuration

* **Chores**
* Added peft dependency (minimum version specified) for LoRA support
* Broadened backend usage for a runtime patch

* **Tests**
* New tests covering LoRA injection, parameter freezing, forward/loss behavior, compatibility checks, and export artifacts

yeyu-nvidia and others added 2 commits March 17, 2026 11:43
Introduces eagle_base_lora training mode where HF PEFT LoRA adapters are
injected into the base model and co-trained with the EAGLE draft module.
A preservation loss (KL divergence between original and LoRA-adapted base
model outputs) is added to prevent the base model from drifting during
training. LoRA adapter weights are exported separately alongside the EAGLE
draft model artifacts.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
@yeyu-nvidia yeyu-nvidia requested review from a team as code owners March 17, 2026 18:56
@yeyu-nvidia yeyu-nvidia requested a review from ChenhanYu March 17, 2026 18:56

coderabbitai bot commented Mar 17, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6c7520bd-3c08-44be-9c04-1485930054f9

📥 Commits

Reviewing files that changed from the base of the PR and between 9415c07 and 7fae8b1.

📒 Files selected for processing (2)
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/speculative/utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/plugins/hf_spec_export.py

📝 Walkthrough

Walkthrough

Adds configurable PEFT LoRA support for the EAGLE base model: new config fields, runtime LoRA adapter injection/toggling, a KL-based preservation loss, export of LoRA artifacts, dependency addition, and unit tests covering injection, training constraints, loss, incompatibility, and export.

Changes

- Dependency (examples/speculative_decoding/requirements.txt): Adds peft>=0.17.0.
- Config (modelopt/torch/speculative/config.py): Adds five EagleConfig fields: eagle_base_lora, eagle_base_lora_rank, eagle_base_lora_alpha, eagle_base_lora_target_modules, eagle_base_lora_preservation_loss_weight.
- Runtime LoRA integration (modelopt/torch/speculative/plugins/transformers.py): Injects and manages PEFT LoRA adapters for the base model: adds _inject_base_lora, _set_base_lora_enabled, _preservation_loss, and updates modify() and _base_model_forward() to run reference (LoRA-disabled) and LoRA-enabled forwards and compute preservation loss.
- Eagle model wiring (modelopt/torch/speculative/eagle/eagle_model.py): Assigns new LoRA-related config fields to EagleModel attributes during modify().
- Export (modelopt/torch/export/plugins/hf_spec_export.py): Adds _export_lora() to export LoRA adapter weights (lora_adapter_model.safetensors) and lora_adapter_config.json when eagle_base_lora is present; integrated into the export flow and adjusts state-key validation.
- Tests (tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py): New test suite and fixture lora_eagle_model: verifies LoRA injection, trainable-parameter freezing, forward loss, eagle_offline incompatibility, and export artifacts (model.safetensors, lora_adapter_model.safetensors, lora_adapter_config.json).
- Other runtime change (modelopt/torch/speculative/utils.py): Modifies enable_cp_ttt_patch to call sdpa_kernel with the backends list [CUDNN_ATTENTION, MATH] instead of a single backend.

Sequence Diagram(s)

sequenceDiagram
    participant Trainer
    participant HFEagleModel
    participant BaseModel
    participant LoRAAdapter
    participant Exporter

    Trainer->>HFEagleModel: init(config with eagle_base_lora)
    HFEagleModel->>LoRAAdapter: _inject_base_lora() (create LoraConfig, inject adapters)
    LoRAAdapter-->>HFEagleModel: adapters installed, LoRA params unfrozen

    Trainer->>HFEagleModel: training step
    HFEagleModel->>BaseModel: _set_base_lora_enabled(False)
    HFEagleModel->>BaseModel: forward -> ref_logits
    BaseModel-->>HFEagleModel: ref_logits

    HFEagleModel->>BaseModel: _set_base_lora_enabled(True)
    HFEagleModel->>BaseModel: forward -> lora_logits, hidden_states
    BaseModel-->>HFEagleModel: lora_logits, hidden_states

    HFEagleModel->>HFEagleModel: _preservation_loss(ref_logits, lora_logits)
    HFEagleModel-->>Trainer: combined loss

    Trainer->>Exporter: export request
    Exporter->>LoRAAdapter: export weights -> `lora_adapter_model.safetensors`
    Exporter->>Exporter: write `lora_adapter_config.json`
    Exporter-->>Trainer: exported artifacts

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
- Description Check: Passed (check skipped; CodeRabbit's high-level summary is enabled).
- Title Check: Passed. The title accurately summarizes the main change: adding LoRA co-training support for HF EAGLE speculative decoding, which is reflected across all modified files.
- Docstring Coverage: Passed. Docstring coverage is 82.35%, above the required 80.00% threshold.
- Security Anti-Patterns: Passed. No critical security anti-patterns introduced: no torch.load with weights_only=False, numpy.load with allow_pickle=True, hardcoded trust_remote_code=True, eval()/exec() builtins, or nosec comments found.


coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (3)
modelopt/torch/speculative/config.py (1)

122-145: Strengthen LoRA config typing and value bounds.

Consider constraining invalid user input at config parse time (e.g., rank <= 0, negative preservation weight) and avoid mutable list defaults.

Proposed patch
-    eagle_base_lora_rank: int = ModeloptField(
+    eagle_base_lora_rank: int = ModeloptField(
         default=64,
+        ge=1,
         description="LoRA rank for the base model adapters.",
     )

     eagle_base_lora_alpha: float = ModeloptField(
         default=16.0,
+        gt=0.0,
         description="LoRA alpha (scaling) for the base model adapters.",
     )

-    eagle_base_lora_target_modules: list = ModeloptField(
-        default=[],
+    eagle_base_lora_target_modules: tuple[str, ...] = ModeloptField(
+        default=(),
         description=(
             "List of module name patterns to apply LoRA to in the base model "
             "(e.g. ['q_proj', 'v_proj']). Empty list uses peft defaults."
         ),
     )

     eagle_base_lora_preservation_loss_weight: float = ModeloptField(
         default=0.1,
+        ge=0.0,
         description=(
             "Weight for the preservation loss that minimizes the KL divergence between "
             "the LoRA-adapted base model output and the original base model output."
         ),
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/config.py` around lines 122 - 145, The config
fields eagle_base_lora_rank, eagle_base_lora_alpha,
eagle_base_lora_target_modules, and eagle_base_lora_preservation_loss_weight use
permissive types and a mutable list default; update their ModeloptField
definitions to enforce proper typing and validate bounds at parse/validation
time: require eagle_base_lora_rank to be an int > 0, eagle_base_lora_alpha to be
a float >= 0, eagle_base_lora_preservation_loss_weight to be a float >= 0
(reject negatives), and replace the mutable default for
eagle_base_lora_target_modules with an immutable default (e.g., None or tuple)
and coerce/validate it into a list of strings; implement these checks using the
config validation hook or the ModeloptField's validator callbacks so invalid
inputs raise a clear parsing error.
tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py (1)

93-100: Strengthen export test by validating LoRA config contents, not just file existence.

Existence checks can pass with malformed config. Assert expected r, lora_alpha, and target_modules in lora_adapter_config.json.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py` around
lines 93 - 100, In test_export_lora_artifacts, after exporting via
lora_eagle_model.get_exporter().export(export_dir), open and parse export_dir /
"lora_adapter_config.json" as JSON and assert the config contains the keys "r",
"lora_alpha", and "target_modules"; further validate that "r" and "lora_alpha"
are positive integers and that "target_modules" is a non-empty list of strings
(or matches the expected module names for this model), so the test checks
semantic correctness not just file existence.
modelopt/torch/speculative/plugins/transformers.py (1)

581-582: Prefer F.kl_div for preservation loss clarity/stability.

Current expression is a manual cross-entropy form; F.kl_div makes intent explicit and is less error-prone to maintain.

Proposed patch
+import torch.nn.functional as F
...
-        loss = nn.Softmax(dim=-1)(ref_logits.detach()) * nn.LogSoftmax(dim=-1)(lora_logits)
-        return -loss.sum(dim=-1).mean() * self.eagle_base_lora_preservation_loss_weight
+        ref_prob = F.softmax(ref_logits.detach(), dim=-1)
+        lora_log_prob = F.log_softmax(lora_logits, dim=-1)
+        kl = F.kl_div(lora_log_prob, ref_prob, reduction="batchmean")
+        return kl * self.eagle_base_lora_preservation_loss_weight
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 581 - 582,
Replace the manual cross-entropy-like expression with torch.nn.functional.kl_div
to make intent explicit and numerically stable: compute log-probs from
lora_logits with F.log_softmax, compute target probs from ref_logits.detach()
with F.softmax, call F.kl_div(log_probs, target_probs, reduction='none'), sum
over the last dim, take the mean, and multiply by
self.eagle_base_lora_preservation_loss_weight (no leading negative). Update the
expression that currently uses nn.Softmax/nn.LogSoftmax and returns
-loss.sum(...).mean()*self.eagle_base_lora_preservation_loss_weight to the
F.kl_div-based sequence using the same dimensions and detachment of ref_logits.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 66fcf495-7fbb-405d-9f5d-1206155ab766

📥 Commits

Reviewing files that changed from the base of the PR and between 00fa5bd and ebdbf65.

📒 Files selected for processing (5)
  • examples/speculative_decoding/requirements.txt
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/plugins/transformers.py
  • tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py

Comment on lines +191 to +193
lora_sd = {k: v for k, v in full_sd.items() if "lora_A" in k or "lora_B" in k}
save_file(lora_sd, export_dir / "lora_adapter_model.safetensors")


⚠️ Potential issue | 🟠 Major

Fail fast when no LoRA tensors are found during export.

If LoRA injection/reg-key filtering regresses, this currently emits an empty adapter file and still reports success. Add an explicit guard.

Proposed patch
     def _export_lora(self, export_dir: Path, full_sd: dict):
         """Export base model LoRA adapter weights alongside the eagle module artifacts."""
         lora_sd = {k: v for k, v in full_sd.items() if "lora_A" in k or "lora_B" in k}
+        if not lora_sd:
+            raise RuntimeError("No LoRA adapter tensors found in state_dict; refusing empty export.")
         save_file(lora_sd, export_dir / "lora_adapter_model.safetensors")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 191 - 193, The
export currently builds lora_sd = {k: v for k, v in full_sd.items() if "lora_A"
in k or "lora_B" in k} and calls save_file(...) even if lora_sd is empty; add a
guard after constructing lora_sd in the hf_spec_export export routine to fail
fast: check if lora_sd is empty and if so, raise a clear exception (or call
processLogger.error and raise RuntimeError) indicating no LoRA tensors found
instead of writing an empty file; reference the lora_sd variable, full_sd
source, save_file call and export_dir/"lora_adapter_model.safetensors" target so
the change is applied in the right spot.

Comment on lines +648 to +650
if self.eagle_base_lora:
    assert not self.eagle_offline, "eagle_base_lora is incompatible with eagle_offline=True"
    self._inject_base_lora()

⚠️ Potential issue | 🟠 Major

Do not use assert for runtime config validation.

assert can be optimized out; use an explicit exception so the incompatibility check always executes.

Proposed patch
         # Inject HF PEFT LoRA adapters into the base model for co-training
         if self.eagle_base_lora:
-            assert not self.eagle_offline, "eagle_base_lora is incompatible with eagle_offline=True"
+            if self.eagle_offline:
+                raise ValueError("eagle_base_lora is incompatible with eagle_offline=True")
             self._inject_base_lora()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 648 - 650,
The code currently uses an assert to enforce that eagle_base_lora and
eagle_offline are not both set (in the block that calls
self._inject_base_lora()); replace the assert with an explicit runtime exception
(e.g., raise ValueError or RuntimeError) so the check always runs in production.
Locate the conditional that checks self.eagle_base_lora and the incompatible
flag self.eagle_offline, and throw a clear exception with a descriptive message
instead of using assert before calling self._inject_base_lora().

Comment on lines +812 to 818
if self.eagle_base_lora:
    self._set_base_lora_enabled(False)
    ref_logits = _run_forward(no_grad=True).logits
    if hasattr(self, "_aux_hidden_states"):
        self._aux_hidden_states.clear()
    self._set_base_lora_enabled(True)


⚠️ Potential issue | 🟠 Major

Restore LoRA adapter state with try/finally around the reference forward.

If the reference forward raises, adapters stay disabled and subsequent training behavior becomes incorrect.

Proposed patch
         ref_logits = None
         if self.eagle_base_lora:
-            self._set_base_lora_enabled(False)
-            ref_logits = _run_forward(no_grad=True).logits
-            if hasattr(self, "_aux_hidden_states"):
-                self._aux_hidden_states.clear()
-            self._set_base_lora_enabled(True)
+            self._set_base_lora_enabled(False)
+            try:
+                ref_logits = _run_forward(no_grad=True).logits
+                if hasattr(self, "_aux_hidden_states"):
+                    self._aux_hidden_states.clear()
+            finally:
+                self._set_base_lora_enabled(True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 812 - 818,
The block that disables LoRA adapters using self._set_base_lora_enabled(False)
before calling _run_forward can leave adapters disabled if _run_forward raises;
wrap the reference forward in a try/finally so that
self._set_base_lora_enabled(True) always executes, still clearing
self._aux_hidden_states when present and returning/using ref_logits from
_run_forward; specifically, call _set_base_lora_enabled(False), run ref_logits =
_run_forward(no_grad=True).logits inside try, then in finally re-enable via
_set_base_lora_enabled(True) and clear self._aux_hidden_states if present.


h-guo18 commented Mar 17, 2026

How would the base model quality and AL look with this LoRA co-training?

The LoRA co-training config fields (eagle_base_lora, eagle_base_lora_rank,
eagle_base_lora_alpha, eagle_base_lora_target_modules,
eagle_base_lora_preservation_loss_weight) were defined in the config but
never assigned in EagleModel.modify(), causing DynamicModule.__getattr__
to raise AttributeError when HFEagleModel.modify() accessed them.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
@yeyu-nvidia
Contributor Author

How would the base model quality and AL look with this LoRA co-training?

Haven't tested. Will report later.

yeyu-nvidia and others added 2 commits March 17, 2026 14:00
Set num_key_value_heads=16 (matching num_attention_heads) to avoid GQA,
which triggers enable_gqa=True in SDPA — unsupported on CPU backends.
Set use_last_layernorm=True so the norm layer is created and norm.weight
is present in the export state dict as required by the export validator.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
- enable_cp_ttt_patch: add SDPBackend.MATH alongside CUDNN_ATTENTION so
  the math kernel is available as fallback on CPU (fixes test_forward_returns_loss)
- _check_valid_sd: skip fc/hidden_norm from required keys when
  use_aux_hidden_state=False, as these layers only exist in EAGLE-3
  (fixes test_export_lora_artifacts)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>