
[EAGLE] Optimize EAGLE Training#1044

Open
benchislett wants to merge 9 commits into main from bchislett/eagle-speedups-torch-compile

Conversation


@benchislett benchislett commented Mar 16, 2026

What does this PR do?

Type of change: Optimization

Changes:

  • Precompute base_model_logits.argmax() and base_model_logits.softmax() instead of recomputing in every call to _eagle_loss
  • Calculate per-prediction accuracy on the GPU and synchronize it to the host after running all TTT steps, to avoid CPU/GPU synchronization inside the TTT step loop.
  • Apply torch.compile to performance-critical training functions: input preparation, the eagle forward pass, and the eagle loss calculation. Omitted from the target model in the online case for now, since torch.compile may not be natively compatible with all architectures.
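To make the first two bullets concrete, here is a minimal pure-Python sketch of the pattern (illustrative only — `run_ttt` and its arguments are hypothetical, not the modelopt code): the teacher argmax/softmax are computed once before the TTT loop, and per-step accuracies are accumulated and read back only after all steps.

```python
import math


def softmax(logits):
    # Numerically stable softmax over a single position's logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def argmax(logits):
    return max(range(len(logits)), key=logits.__getitem__)


def run_ttt(base_logits_per_pos, draft_preds_per_step):
    # Precompute teacher targets and probabilities ONCE, outside the TTT loop.
    targets = [argmax(l) for l in base_logits_per_pos]
    probs = [softmax(l) for l in base_logits_per_pos]  # reused by every step's loss

    accs = []  # conceptually stays "on device" during the loop
    for step_preds in draft_preds_per_step:
        correct = sum(p == t for p, t in zip(step_preds, targets))
        accs.append(correct / len(targets))
    return accs, probs  # the single host read-back happens here, after all steps
```

In the real code the argmax/softmax are tensors on the GPU and the read-back is a single `.tolist()` after the loop, but the control flow is the same.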

Usage

No changes to external interfaces

Testing

Ran training commands for benchmarking. Did not do a full training run, did not validate correctness.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added runtime profiling ranges for speculative decoding to enable finer performance tracing.
  • Improvements

    • Lazy initialization for rotary embeddings in llama-style decoders for more reliable startup.
    • Speculative decoding now uses base-model predicted tokens and softmax probabilities for sampling and loss, improving stability and accuracy reporting.
    • Parallelism configuration is now conditional, avoiding unnecessary setup unless required.

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@benchislett benchislett requested a review from a team as a code owner March 16, 2026 05:18
@benchislett benchislett requested a review from ChenhanYu March 16, 2026 05:18
@coderabbitai
Contributor

coderabbitai bot commented Mar 16, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Parallelism config creation in the speculative decoding example is now conditional. The Transformers plugin adds NVTX profiling decorators, lazily initializes Llama rotary embeddings in EagleModule, and threads base-model predicted tokens and softmax logits through input preparation into the eagle forward/loss/accuracy flow.

Changes

  • Conditional Parallelism Configuration — examples/speculative_decoding/main.py: create parallelism_config only when cp_size > 1 or dp_shard_size > 1; keep ring-attention patching and set parallelism_config.sp_backend = None when cp_size > 1.
  • Profiling & EagleModule changes — modelopt/torch/speculative/plugins/transformers.py: add @nvtx.range decorators to _prepare_eagle_inputs, _base_model_forward, _eagle_forward, _eagle_loss; add _maybe_init_rope() for lazy Llama rotary init; _prepare_eagle_inputs now returns base_output_predict_tok and base_output_softmax_logits; propagate these into _eagle_forward/_eagle_loss and update loss/accuracy computations.
  • Manifest — pyproject.toml: minor manifest edits (small line changes).

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer
    participant BaseModel
    participant EagleModule
    participant Loss
    Trainer->>BaseModel: run base forward(inputs)
    BaseModel-->>Trainer: base_logits
    Trainer->>Trainer: softmax(base_logits) -> base_output_softmax_logits\nargmax(base_logits) -> base_output_predict_tok
    Trainer->>EagleModule: _prepare_eagle_inputs(..., base_output_predict_tok, base_output_softmax_logits)
    EagleModule->>EagleModule: _maybe_init_rope()\ncompute eagle_logits
    EagleModule-->>Trainer: eagle_logits
    Trainer->>Loss: _eagle_loss(eagle_logits, base_output_softmax_logits, base_output_predict_tok, loss_mask)
    Loss-->>Trainer: loss, accuracy
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Important

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error, 1 warning)

  • Security Anti-Patterns — ❌ Error: the NVTX import is not wrapped in try/except, breaking Windows CI where NVTX is unavailable; review feedback explicitly requests optional-dependency handling. Resolution: wrap the NVTX import in try/except with a _NVTX_AVAILABLE flag and a fallback decorator, and replace all @nvtx.range() uses with a conditional @nvtx_range() function.
  • Docstring Coverage — ⚠️ Warning: docstring coverage is 64.29%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: the title '[EAGLE] Optimize EAGLE Training' directly relates to the PR's main objective of optimizing EAGLE training through precomputing base model outputs, adding torch.compile annotations, and improving GPU/CPU synchronization.

```python
if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
    training_args.parallelism_config = ParallelismConfig(
        cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
    )
```
Contributor Author (benchislett):
Note, this is an unrelated bugfix related to #1045 (does not fully solve the issue, just a single-gpu workaround)

Contributor:
As discussed in Slack, this issue is due to a transformers version mismatch. It should be fixed after updating transformers.

coderabbitai bot left a comment:
Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 908-910: The code can raise a NameError when self.eagle_ttt_steps
== 0 because the loop that defines ttt_step never runs; update the logic in
modify() (or the surrounding block) to handle the zero-case explicitly: either
assert self.eagle_ttt_steps >= 1 at the start of modify() to make the invariant
explicit, or initialize ttt_step to a safe default (or skip code that uses
ttt_step) when eagle_ttt_steps == 0 and ensure train_accs =
torch.zeros(num_parallel, num_ttt, device=input_ids.device) is still valid;
reference symbols: self.eagle_ttt_steps, train_accs, ttt_step, and modify() when
applying the fix.
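The zero-step hazard the reviewer describes can be demonstrated standalone (plain Python, independent of the modelopt code; `ttt_accuracies` is a hypothetical stand-in for the reviewed logic):

```python
def ttt_accuracies(eagle_ttt_steps, num_parallel=2):
    # Mirrors the reviewed pattern: a [num_parallel, num_ttt] accuracy buffer
    # filled inside the TTT loop, then sliced to the steps actually executed.
    train_accs = [[0.0] * eagle_ttt_steps for _ in range(num_parallel)]
    executed_ttt_steps = 0  # safe default: defined even if the loop never runs
    for ttt_step in range(eagle_ttt_steps):
        for i in range(num_parallel):
            train_accs[i][ttt_step] = 1.0  # placeholder per-step accuracy
        executed_ttt_steps = ttt_step + 1
    # Slicing by the counter (not `ttt_step + 1`) avoids a NameError when
    # eagle_ttt_steps == 0, because `ttt_step` would never have been bound.
    return [row[:executed_ttt_steps] for row in train_accs]
```

With `eagle_ttt_steps == 0` this returns empty per-parallel rows instead of raising, which is the behavior the suggested guard aims for.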

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: de2f61b0-6ad0-43ef-b333-c5cd195b6a21

📥 Commits

Reviewing files that changed from the base of the PR and between 3d40373 and 0309c19.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/transformers.py

coderabbitai bot left a comment:

Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)

909-912: ⚠️ Potential issue | 🟠 Major

Guard zero-step TTT to avoid undefined ttt_step.

If self.eagle_ttt_steps == 0, the loop at Line 931 never runs, and Line 989 references ttt_step before assignment.

🔧 Proposed fix
```diff
-        train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)
+        train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)
+        executed_ttt_steps = 0
@@
-        for ttt_step in range(self.eagle_ttt_steps):
+        for ttt_step in range(self.eagle_ttt_steps):
@@
-                train_accs[i, ttt_step] = acc
+                train_accs[i, ttt_step] = acc
+            executed_ttt_steps = ttt_step + 1
             if not self.training:
                 break
@@
-        train_accs = train_accs[:, : ttt_step + 1].tolist()
+        train_accs = train_accs[:, :executed_ttt_steps].tolist()
```

Also applies to: 988-990

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 909 - 912,
Guard against zero TTT steps by checking self.eagle_ttt_steps before using
ttt_step or running the TTT loop: if self.eagle_ttt_steps == 0 skip the entire
TTT block (including the loop that populates train_accs and any later use of
ttt_step) or initialize a safe default for ttt_step and related tensors so they
are defined when eagle_ttt_steps is 0; update references around train_accs, the
loop that iterates over range(self.eagle_ttt_steps) and the later code that uses
ttt_step (the code near variables num_ttt, train_accs and where ttt_step is
referenced) to either early-return/skip or handle the zero-case explicitly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 928-929: The RoPE initializer is only called in
HFEagleModel.forward, but other entry points like pseudo_speculative_generate()
call _eagle_forward() and can trigger EagleModule.forward before rotary_emb
exists; update the code so every EAGLE entry path invokes the initializer: call
self.eagle_module._maybe_init_rope() at the start of EagleModule.forward and
also ensure _eagle_forward() (and/or pseudo_speculative_generate()) invokes
_maybe_init_rope() before any use of rotary_emb so RoPE is always initialized
regardless of which forward path is exercised.

---

Duplicate comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 909-912: Guard against zero TTT steps by checking
self.eagle_ttt_steps before using ttt_step or running the TTT loop: if
self.eagle_ttt_steps == 0 skip the entire TTT block (including the loop that
populates train_accs and any later use of ttt_step) or initialize a safe default
for ttt_step and related tensors so they are defined when eagle_ttt_steps is 0;
update references around train_accs, the loop that iterates over
range(self.eagle_ttt_steps) and the later code that uses ttt_step (the code near
variables num_ttt, train_accs and where ttt_step is referenced) to either
early-return/skip or handle the zero-case explicitly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d8dcdd21-580f-406c-9a0f-ff3ab7c80865

📥 Commits

Reviewing files that changed from the base of the PR and between 0309c19 and 44c00b8.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/transformers.py


codecov bot commented Mar 16, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.02%. Comparing base (1070d89) to head (67e4bee).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #1044      +/-   ##
==========================================
- Coverage   70.10%   70.02%   -0.09%
==========================================
  Files         221      221
  Lines       25541    25541
==========================================
- Hits        17905    17884      -21
- Misses       7636     7657      +21
```

☔ View full report in Codecov by Sentry.


h-guo18 commented Mar 16, 2026

Makes sense to me. Is there any perf comparison before/after the optimizations?

coderabbitai bot left a comment:
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)

720-725: Consider a size guard before caching full teacher softmax.

Caching base_output_softmax_logits as a full [B, S, V] tensor can materially increase peak memory (especially in large-vocab runs). A guarded fallback to per-slice softmax would keep this optimization safer across a wider range of configs.

💡 Example guard pattern
```diff
-        base_output_predict_tok = base_model_logits.argmax(dim=-1).detach()
-        base_output_softmax_logits = torch.softmax(base_model_logits, dim=2).detach()
+        base_output_predict_tok = base_model_logits.argmax(dim=-1).detach()
+        cache_softmax = base_model_logits.numel() <= getattr(
+            self.eagle_config, "max_cached_teacher_prob_elems", 0
+        )
+        base_output_softmax_logits = (
+            torch.softmax(base_model_logits, dim=2).detach() if cache_softmax else None
+        )
```
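For intuition about the reviewer's concern, here is a back-of-envelope estimate. All numbers are illustrative assumptions (batch 16 and ISL 2048 echo figures mentioned later in the thread; the Llama-3-style vocab size of 128256 and fp16 storage are guesses, not from the PR):

```python
def softmax_cache_bytes(batch, seq_len, vocab, bytes_per_elem=2):
    """Bytes needed to keep the full [B, S, V] teacher softmax resident."""
    return batch * seq_len * vocab * bytes_per_elem


# 16 x 2048 x 128256 fp16 probabilities held for the whole TTT loop:
cache_gib = softmax_cache_bytes(16, 2048, 128256) / 2**30
print(f"{cache_gib:.1f} GiB")  # roughly 7.8 GiB
```

At fp32 the figure doubles, which is why a size guard (or keeping only the argmax) matters for large-vocab configs.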
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 720 - 725,
The code currently caches base_output_softmax_logits as a full [B,S,V] tensor
which can blow up memory for large vocabularies; add a size guard using
eagle_config.draft_vocab_size and eagle_config.vocab_size (or an explicit
max_vocab_for_full_softmax threshold) and only compute/cache full softmax when
vocab_size*B*S is below the threshold, otherwise avoid storing
base_output_softmax_logits and compute softmax per-slice on demand (or keep only
argmax via base_output_predict_tok); update the block around base_model_logits,
base_output_predict_tok and base_output_softmax_logits to branch on this guard
and ensure downstream users handle the per-slice-compute path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 720-725: The code currently caches base_output_softmax_logits as a
full [B,S,V] tensor which can blow up memory for large vocabularies; add a size
guard using eagle_config.draft_vocab_size and eagle_config.vocab_size (or an
explicit max_vocab_for_full_softmax threshold) and only compute/cache full
softmax when vocab_size*B*S is below the threshold, otherwise avoid storing
base_output_softmax_logits and compute softmax per-slice on demand (or keep only
argmax via base_output_predict_tok); update the block around base_model_logits,
base_output_predict_tok and base_output_softmax_logits to branch on this guard
and ensure downstream users handle the per-slice-compute path.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f3113c0a-eb31-4a7d-b61d-bae7d19554ae

📥 Commits

Reviewing files that changed from the base of the PR and between 44c00b8 and b30d95c.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/transformers.py

@benchislett
Contributor Author

@h-guo18 I isolated perf using nsys profile with online training of a config for Llama 3.2 1B with K=3 on ISL 2048. Looking at the EAGLE3 FWD+BWD and excluding the target model forward pass, I get 1.9x speed improvement (roughly 280ms per batch of 16 requests, down from 540ms on main)

@benchislett
Contributor Author

@h-guo18 could you advise on this test failure? It seems like the Windows build doesn't have NVTX available. I'm not sure how the modelopt CI works; what do you suggest to fix the CI?

@h-guo18
Contributor

h-guo18 commented Mar 16, 2026

@h-guo18 could you advise on this test failure? It seems like the Windows build doesn't have NVTX available. I'm not sure how the modelopt CI works; what do you suggest to fix the CI?

Seems like only the test on Windows fails (link). Instead of installing it in the testing container, I think it's better to make NVTX an optional dependency for minimal impact.

e.g. wrap the decorator with some check:

```python
from contextlib import contextmanager

try:
    from torch.cuda import nvtx as torch_nvtx
except Exception:
    torch_nvtx = None


def _nvtx_available() -> bool:
    if torch_nvtx is None:
        return False
    try:
        torch_nvtx.range_push("probe")
        torch_nvtx.range_pop()
        return True
    except Exception:
        return False


_NVTX_ENABLED = _nvtx_available()


@contextmanager
def _null_range(msg: str):
    yield


def nvtx_range(msg: str):
    """Can be used as both decorator and context manager fallback target."""
    if _NVTX_ENABLED:
        return torch_nvtx.range(msg)
    return _null_range(msg)


# To use it:
@nvtx_range("eagle_loss")
def compute_loss(x):
    return x.sum()
```

@h-guo18
Contributor

h-guo18 commented Mar 16, 2026

@h-guo18 I isolated perf using nsys profile with online training of a config for Llama 3.2 1B with K=3 on ISL 2048. Looking at the EAGLE3 FWD+BWD and excluding the target model forward pass, I get 1.9x speed improvement (roughly 280ms per batch of 16 requests, down from 540ms on main)

That's a huge speedup, thanks! I don't see any torch.compile in the current PR code, so I assume most of the speedup comes from precomputing the base_model_logits-related values?

@benchislett
Contributor Author

Oh oops I forgot to push the torch.compile annotations

@benchislett
Contributor Author

A massive chunk comes from the torch.compile

coderabbitai bot left a comment:
Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Line 41: Make the NVTX import optional: replace the hard import of nvtx with a guarded import (try/except ImportError) that sets a module-level nvtx = None when unavailable, and add a helper decorator factory named nvtx_range that returns a no-op decorator when nvtx is None, or nvtx.range(...) when present; then replace all uses of the @nvtx.range(...) decorator in this module with @nvtx_range(...) so the code runs on systems without CUDA, without changing behavior when nvtx is available.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 85cd1527-f41a-4ff3-ad63-5908d1cc835b

📥 Commits

Reviewing files that changed from the base of the PR and between b30d95c and 67e4bee.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/transformers.py

```python
import transformers
from packaging.version import Version
from torch import nn
from torch.cuda import nvtx
```
Contributor:

⚠️ Potential issue | 🟠 Major

Make NVTX import optional to avoid CI failures on systems without CUDA.

Per the PR discussion, Windows CI is failing because NVTX is not available in the Windows build. The import should be guarded to make NVTX an optional dependency. As per coding guidelines: "Avoid hard imports of optional dependencies at module level."

🛠️ Proposed fix to make NVTX optional
```diff
-from torch.cuda import nvtx
+try:
+    from torch.cuda import nvtx
+
+    _NVTX_AVAILABLE = True
+except ImportError:
+    _NVTX_AVAILABLE = False
+
+
+@contextlib.contextmanager
+def _nvtx_range_fallback(name):
+    """No-op fallback when NVTX is unavailable."""
+    yield
+
+
+def nvtx_range(name):
+    """Return NVTX range decorator if available, else no-op."""
+    if _NVTX_AVAILABLE:
+        return nvtx.range(name)
+    return _nvtx_range_fallback(name)
```

Then replace all @nvtx.range(...) decorators with @nvtx_range(...).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In modelopt/torch/speculative/plugins/transformers.py at line 41, make the NVTX import optional: replace the hard import of nvtx with a guarded import (try/except ImportError) that sets a module-level nvtx = None when unavailable, and add a helper decorator factory named nvtx_range that returns a no-op decorator when nvtx is None, or nvtx.range(...) when present; then replace all uses of the @nvtx.range(...) decorator in this module with @nvtx_range(...) so the code runs on systems without CUDA, without changing behavior when nvtx is available.


h-guo18 commented Mar 16, 2026

A massive chunk comes from the torch.compile

Makes sense. Besides:

  1. Shall we make torch.compile an optional feature? E.g., add an argument in the training script. I remember torch.compile failing in some cases (FSDP or cp > 1; I forget the exact setting).
  2. Could you run a quick correctness test and check the training acc curves before/after change? A few k steps on llama 1b would be sufficient I think.
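Point 1 could be wired up roughly as follows (a hedged sketch — `maybe_compile` and its signature are hypothetical, not existing modelopt API; `compiler` stands in for `torch.compile` so the example runs without torch installed):

```python
def maybe_compile(fn, enabled, compiler=None):
    """Apply a compiler (e.g. torch.compile) only when explicitly enabled.

    When disabled (say, under FSDP or cp > 1 settings known to misbehave
    with compile), the function is returned untouched.
    """
    if enabled and compiler is not None:
        return compiler(fn)
    return fn


# Disabled: falls back to eager execution.
eager_step = maybe_compile(lambda x: x + 1, enabled=False)

# Enabled: an identity "compiler" stands in for torch.compile here.
compiled_step = maybe_compile(lambda x: x * 2, enabled=True, compiler=lambda f: f)
```

In the training script this would be driven by a CLI flag, with the flag defaulting off for the configurations where compile is known to fail.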

@yeyu-nvidia
Contributor

LGTM. Please run a full test before merging.

3 participants