Conversation
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior; use the following commands or the checkboxes below to manage reviews.
📝 Walkthrough

Parallelism config creation in the speculative decoding example is now conditional. The Transformers plugin adds NVTX profiling decorators and lazily initializes Llama rotary embeddings in …

Changes: (details table not reproduced here)

Sequence Diagram(s): mermaid diagram (not reproduced here)

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Important: Pre-merge checks failed. Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error, 1 warning)
✅ Passed checks (2 passed)
```python
if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
    training_args.parallelism_config = ParallelismConfig(
        cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
    )
```
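The single-GPU workaround can be sketched end to end as below. `ParallelismConfig` and `TrainingArgs` here are minimal stand-in stubs, not the real classes from the example, so only the guard's wiring is faithful:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ParallelismConfig:
    """Stand-in stub for the real parallelism config class."""
    cp_size: int = 1
    dp_shard_size: int = 1


@dataclass
class TrainingArgs:
    """Stand-in stub for the example's training arguments."""
    cp_size: int = 1
    dp_shard_size: int = 1
    parallelism_config: Optional[ParallelismConfig] = None


def maybe_set_parallelism(training_args: TrainingArgs) -> TrainingArgs:
    # Only build a ParallelismConfig when parallelism is actually requested;
    # on a single GPU the config stays None and the framework default applies.
    if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
        training_args.parallelism_config = ParallelismConfig(
            cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
        )
    return training_args
```

The single-GPU case (both sizes equal to 1) leaves `parallelism_config` unset, which is the workaround described in the comment below.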
Note, this is an unrelated bugfix related to #1045 (does not fully solve the issue, just a single-gpu workaround)
As discussed in Slack, this issue is due to a transformers version mismatch. It should be fixed after updating transformers.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 908-910: The code can raise a NameError when self.eagle_ttt_steps
== 0 because the loop that defines ttt_step never runs; update the logic in
modify() (or the surrounding block) to handle the zero-case explicitly: either
assert self.eagle_ttt_steps >= 1 at the start of modify() to make the invariant
explicit, or initialize ttt_step to a safe default (or skip code that uses
ttt_step) when eagle_ttt_steps == 0 and ensure train_accs =
torch.zeros(num_parallel, num_ttt, device=input_ids.device) is still valid;
reference symbols: self.eagle_ttt_steps, train_accs, ttt_step, and modify() when
applying the fix.
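The hazard flagged above can be reproduced and guarded in a few lines. This is an illustrative sketch: `eagle_ttt_steps`, `train_accs`, and `ttt_step` mirror the review's symbols, but the surrounding model code is replaced with plain-Python stand-ins:

```python
def run_ttt_steps(eagle_ttt_steps: int, num_parallel: int):
    """Sketch of the loop structure flagged above (stand-in for modify())."""
    # train_accs mirrors torch.zeros(num_parallel, num_ttt); plain lists here.
    train_accs = [[0.0] * eagle_ttt_steps for _ in range(num_parallel)]

    # Track how many steps actually executed instead of reusing the loop
    # variable after the loop: if eagle_ttt_steps == 0 the loop body never
    # runs, and a bare `ttt_step` reference afterwards would raise NameError.
    executed_ttt_steps = 0
    for ttt_step in range(eagle_ttt_steps):
        for i in range(num_parallel):
            train_accs[i][ttt_step] = 1.0  # placeholder for the real accuracy
        executed_ttt_steps = ttt_step + 1

    # Safe truncation even in the zero-step case.
    return [row[:executed_ttt_steps] for row in train_accs]
```

With `eagle_ttt_steps == 0` this returns empty accuracy rows instead of raising, which is the behavior the proposed fix below restores in the real code.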
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: de2f61b0-6ad0-43ef-b333-c5cd195b6a21
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/transformers.py
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)
909-912: ⚠️ Potential issue | 🟠 Major — Guard zero-step TTT to avoid undefined `ttt_step`.

If `self.eagle_ttt_steps == 0`, the loop at line 931 never runs, and line 989 references `ttt_step` before assignment.

🔧 Proposed fix

```diff
-train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)
+train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)
+executed_ttt_steps = 0
@@
 for ttt_step in range(self.eagle_ttt_steps):
@@
-train_accs[i, ttt_step] = acc
+train_accs[i, ttt_step] = acc
+executed_ttt_steps = ttt_step + 1
 if not self.training:
     break
@@
-train_accs = train_accs[:, : ttt_step + 1].tolist()
+train_accs = train_accs[:, :executed_ttt_steps].tolist()
```

Also applies to: 988-990
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/transformers.py` around lines 909 - 912, Guard against zero TTT steps by checking self.eagle_ttt_steps before using ttt_step or running the TTT loop: if self.eagle_ttt_steps == 0 skip the entire TTT block (including the loop that populates train_accs and any later use of ttt_step) or initialize a safe default for ttt_step and related tensors so they are defined when eagle_ttt_steps is 0; update references around train_accs, the loop that iterates over range(self.eagle_ttt_steps) and the later code that uses ttt_step (the code near variables num_ttt, train_accs and where ttt_step is referenced) to either early-return/skip or handle the zero-case explicitly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 928-929: The RoPE initializer is only called in
HFEagleModel.forward, but other entry points like pseudo_speculative_generate()
call _eagle_forward() and can trigger EagleModule.forward before rotary_emb
exists; update the code so every EAGLE entry path invokes the initializer: call
self.eagle_module._maybe_init_rope() at the start of EagleModule.forward and
also ensure _eagle_forward() (and/or pseudo_speculative_generate()) invokes
_maybe_init_rope() before any use of rotary_emb so RoPE is always initialized
regardless of which forward path is exercised.
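The "initialize on every entry path" pattern recommended above can be sketched as follows. `EagleModule`, `HFEagleModel`, `_maybe_init_rope`, and `_eagle_forward` are the review's symbols, but the class bodies here are minimal stubs; the real rotary-embedding construction is replaced with a placeholder:

```python
class EagleModule:
    """Stub of the EAGLE draft module; only the lazy-init wiring is real."""

    def __init__(self):
        self.rotary_emb = None  # built lazily, once config/device are known

    def _maybe_init_rope(self):
        # Idempotent, so it is safe to call from every entry point.
        if self.rotary_emb is None:
            self.rotary_emb = object()  # stand-in for LlamaRotaryEmbedding(...)

    def forward(self, x):
        self._maybe_init_rope()  # entry path 1: direct forward
        assert self.rotary_emb is not None
        return x


class HFEagleModel:
    """Stub wrapper model with a second entry path into the draft module."""

    def __init__(self):
        self.eagle_module = EagleModule()

    def _eagle_forward(self, x):
        # Entry path 2 (e.g. pseudo_speculative_generate): init before use.
        self.eagle_module._maybe_init_rope()
        return self.eagle_module.forward(x)
```

Because the initializer is idempotent and invoked at the top of both paths, `rotary_emb` exists no matter which forward path is exercised first.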
---
Duplicate comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 909-912: Guard against zero TTT steps by checking
self.eagle_ttt_steps before using ttt_step or running the TTT loop: if
self.eagle_ttt_steps == 0 skip the entire TTT block (including the loop that
populates train_accs and any later use of ttt_step) or initialize a safe default
for ttt_step and related tensors so they are defined when eagle_ttt_steps is 0;
update references around train_accs, the loop that iterates over
range(self.eagle_ttt_steps) and the later code that uses ttt_step (the code near
variables num_ttt, train_accs and where ttt_step is referenced) to either
early-return/skip or handle the zero-case explicitly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d8dcdd21-580f-406c-9a0f-ff3ab7c80865
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/transformers.py
Codecov Report — ✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1044      +/-   ##
==========================================
- Coverage   70.10%   70.02%   -0.09%
==========================================
  Files         221      221
  Lines       25541    25541
==========================================
- Hits        17905    17884      -21
- Misses       7636     7657      +21
```

☔ View full report in Codecov by Sentry.
Makes sense to me. Is there any perf comparison before/after the optimizations?
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)
720-725: Consider a size guard before caching the full teacher softmax.

Caching `base_output_softmax_logits` as a full `[B, S, V]` tensor can materially increase peak memory (especially in large-vocab runs). A guarded fallback to per-slice softmax would keep this optimization safer across wider configs.

💡 Example guard pattern

```diff
-base_output_predict_tok = base_model_logits.argmax(dim=-1).detach()
-base_output_softmax_logits = torch.softmax(base_model_logits, dim=2).detach()
+base_output_predict_tok = base_model_logits.argmax(dim=-1).detach()
+cache_softmax = base_model_logits.numel() <= getattr(
+    self.eagle_config, "max_cached_teacher_prob_elems", 0
+)
+base_output_softmax_logits = (
+    torch.softmax(base_model_logits, dim=2).detach() if cache_softmax else None
+)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/transformers.py` around lines 720 - 725, The code currently caches base_output_softmax_logits as a full [B,S,V] tensor which can blow up memory for large vocabularies; add a size guard using eagle_config.draft_vocab_size and eagle_config.vocab_size (or an explicit max_vocab_for_full_softmax threshold) and only compute/cache full softmax when vocab_size*B*S is below the threshold, otherwise avoid storing base_output_softmax_logits and compute softmax per-slice on demand (or keep only argmax via base_output_predict_tok); update the block around base_model_logits, base_output_predict_tok and base_output_softmax_logits to branch on this guard and ensure downstream users handle the per-slice-compute path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 720-725: The code currently caches base_output_softmax_logits as a
full [B,S,V] tensor which can blow up memory for large vocabularies; add a size
guard using eagle_config.draft_vocab_size and eagle_config.vocab_size (or an
explicit max_vocab_for_full_softmax threshold) and only compute/cache full
softmax when vocab_size*B*S is below the threshold, otherwise avoid storing
base_output_softmax_logits and compute softmax per-slice on demand (or keep only
argmax via base_output_predict_tok); update the block around base_model_logits,
base_output_predict_tok and base_output_softmax_logits to branch on this guard
and ensure downstream users handle the per-slice-compute path.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: f3113c0a-eb31-4a7d-b61d-bae7d19554ae
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/transformers.py
|
@h-guo18 I isolated perf using nsys profile with online training of a config for Llama 3.2 1B with K=3 on ISL 2048. Looking at the EAGLE3 FWD+BWD and excluding the target model forward pass, I get a 1.9x speed improvement (roughly 280 ms per batch of 16 requests, down from 540 ms on main).
@h-guo18 could you advise on this test failure? It seems like the Windows build doesn't have NVTX available? I'm not sure how modelopt CI works; what do you suggest to fix the CI?
Seems like only the test on Windows fails (link). Instead of installing it in the testing container, I think it's better to make it an optional dependency for minimal impact, e.g. wrap the decorator with some check:

```python
from contextlib import contextmanager

try:
    from torch.cuda import nvtx as torch_nvtx
except Exception:
    torch_nvtx = None


def _nvtx_available() -> bool:
    if torch_nvtx is None:
        return False
    try:
        torch_nvtx.range_push("probe")
        torch_nvtx.range_pop()
        return True
    except Exception:
        return False


_NVTX_ENABLED = _nvtx_available()


def nvtx_range(msg: str):
    """Can be used as both decorator and context manager; falls back to a no-op."""
    if _NVTX_ENABLED:
        return torch_nvtx.range(msg)
    return _null_range(msg)


@contextmanager
def _null_range(msg: str):
    yield


# To use it:
@nvtx_range("eagle_loss")
def compute_loss(x):
    return x.sum()
```
That's a huge speedup, thanks! I don't see any torch.compile in the current PR code, so I assume most of the speedup comes from precomputing the base_model_logits-related stuff?
Oh oops I forgot to push the torch.compile annotations |
A massive chunk comes from the torch.compile |
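For context, a torch.compile annotation on a loss-style function looks like the sketch below. The function body and names are illustrative stand-ins, not the PR's actual code, and `backend="eager"` is used here only so the snippet runs without a C++ toolchain:

```python
import torch


@torch.compile(backend="eager", dynamic=True)
def fused_soft_loss(student_logits: torch.Tensor, teacher_probs: torch.Tensor) -> torch.Tensor:
    # Cross-entropy against soft teacher targets; with the default inductor
    # backend, compilation can fuse the log-softmax, multiply, and reductions
    # into fewer kernels.
    log_probs = torch.log_softmax(student_logits, dim=-1)
    return -(teacher_probs * log_probs).sum(dim=-1).mean()
```

In real training you would drop `backend="eager"` so the default inductor backend can actually fuse kernels; the decorated function behaves identically to the eager version either way.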
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Line 41: Make the NVTX import optional: replace the hard import of nvtx with a
guarded import (try/except ImportError) that sets a module-level nvtx = None
when unavailable, and add a helper function/ decorator factory named nvtx_range
that returns a no-op decorator when nvtx is None or returns nvtx.range(...) when
present; then replace all uses of the `@nvtx.range`(...) decorator in this module
(e.g., on any functions decorated with nvtx.range) with `@nvtx_range`(...) so the
code runs on systems without CUDA without changing behavior when nvtx is
available.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 85cd1527-f41a-4ff3-ad63-5908d1cc835b
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/transformers.py
```python
import transformers
from packaging.version import Version
from torch import nn
from torch.cuda import nvtx
```
Make NVTX import optional to avoid CI failures on systems without CUDA.
Per the PR discussion, Windows CI is failing because NVTX is not available in the Windows build. The import should be guarded to make NVTX an optional dependency. As per coding guidelines: "Avoid hard imports of optional dependencies at module level."
🛠️ Proposed fix to make NVTX optional

```diff
-from torch.cuda import nvtx
+try:
+    from torch.cuda import nvtx
+
+    _NVTX_AVAILABLE = True
+except ImportError:
+    _NVTX_AVAILABLE = False
+
+
+@contextlib.contextmanager
+def _nvtx_range_fallback(name):
+    """No-op fallback when NVTX is unavailable."""
+    yield
+
+
+def nvtx_range(name):
+    """Return NVTX range decorator if available, else no-op."""
+    if _NVTX_AVAILABLE:
+        return nvtx.range(name)
+    return _nvtx_range_fallback(name)
```

Then replace all `@nvtx.range(...)` decorators with `@nvtx_range(...)`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/transformers.py` at line 41, Make the NVTX
import optional: replace the hard import of nvtx with a guarded import
(try/except ImportError) that sets a module-level nvtx = None when unavailable,
and add a helper function/ decorator factory named nvtx_range that returns a
no-op decorator when nvtx is None or returns nvtx.range(...) when present; then
replace all uses of the `@nvtx.range`(...) decorator in this module (e.g., on any
functions decorated with nvtx.range) with `@nvtx_range`(...) so the code runs on
systems without CUDA without changing behavior when nvtx is available.
Makes sense. Besides, …
LGTM. Please run a full test before merging.
What does this PR do?
Type of change: Optimization
Changes:
Usage
No changes to external interfaces
Testing
Ran training commands for benchmarking. Did not do a full training run, did not validate correctness.
Before your PR is "Ready for review"

- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: N/A

Additional Information
Summary by CodeRabbit
New Features
Improvements