
[#13909][fix] Reuse hidden_states buffer across CUDA graph captures in Eagle3#13920

Open
ml-inference wants to merge 1 commit into NVIDIA:main from ml-inference:fix/eagle3-cuda-graph-memory

Conversation


@ml-inference ml-inference commented May 8, 2026

Description

During CUDA graph capture in the Eagle3 one-model flow, Eagle3OneModelSpecMetadata allocated a new hidden_states buffer
(max_num_tokens × hidden_size × num_capture_layers) per graph capture. This caused memory to grow linearly with the number
of graphs captured, wasting ~2.75 GiB for 16 captures on Qwen3-235B.

Fix: Always create Eagle3ResourceManager in the one-model flow (previously only created when sa_manager was not None)
and reuse its pre-allocated hidden_states buffer instead of allocating a new one per capture.

Fixes: #13909
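
In isolation, the reuse pattern reads as below. Class and attribute names follow the PR description; the constructor signatures and the bytearray stand-in for the CUDA tensor are illustrative assumptions, not the real TRT-LLM API:

```python
# Sketch of the fix described above: one buffer owned by the resource
# manager, aliased by every per-capture metadata object.
class Eagle3ResourceManager:
    def __init__(self, max_num_tokens, hidden_size, num_capture_layers):
        # Allocated once, up front (a bytearray stands in for a CUDA tensor).
        self.hidden_states = bytearray(
            max_num_tokens * hidden_size * num_capture_layers)

class Eagle3OneModelSpecMetadata:
    def __init__(self, spec_resource_manager):
        # After the fix: reuse the pre-allocated buffer instead of
        # allocating a fresh one on every CUDA graph capture.
        self.hidden_states = spec_resource_manager.hidden_states

mgr = Eagle3ResourceManager(max_num_tokens=8, hidden_size=4, num_capture_layers=3)
captures = [Eagle3OneModelSpecMetadata(mgr) for _ in range(16)]
# All 16 captures alias the same buffer, so memory stays flat.
assert all(m.hidden_states is mgr.hidden_states for m in captures)
```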

Test Coverage

Manually validated on Qwen3-235B-A22B with FP8 KV cache, 16 CUDA graph captures:

  • Without fix: memory grows from 30.28 GiB → 33.03 GiB (+2.75 GiB, ~183 MiB per capture)
  • With fix: memory stays flat at ~30.37 GiB across all 16 captures
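
As a rough cross-check of those numbers, the per-capture cost follows directly from the buffer shape. The dimensions below are assumed for illustration (the real values come from the model config), which is why the result lands near, not exactly at, the measured ~183 MiB:

```python
def buffer_bytes(max_num_tokens, hidden_size, num_capture_layers, dtype_bytes=2):
    # Buffer shape is (max_num_tokens, hidden_size * num_capture_layers);
    # dtype_bytes=2 assumes a 16-bit dtype such as bfloat16.
    return max_num_tokens * hidden_size * num_capture_layers * dtype_bytes

per_capture = buffer_bytes(max_num_tokens=8192, hidden_size=4096,
                           num_capture_layers=3)  # assumed dimensions
print(f"{per_capture / 2**20:.0f} MiB per capture")            # 192 MiB per capture
print(f"{16 * per_capture / 2**30:.2f} GiB for 16 captures")   # 3.00 GiB for 16 captures
```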

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • [x] Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Summary by CodeRabbit

Release Notes

  • Refactor

    • Optimized hidden state memory allocation in speculative decoding through unified resource buffer usage
    • Enhanced Eagle3ResourceManager initialization to consistently construct the resource manager
  • Chores

    • Added debug logging for CUDA graph capture operations with runner identification and memory metrics

…Eagle3

Previously, Eagle3OneModelSpecMetadata allocated a new hidden_states
buffer (max_num_tokens × hidden_size × num_capture_layers) per CUDA
graph capture. This caused memory to grow linearly with the number of
graphs captured, wasting ~2.75 GiB for 16 captures on Qwen3-235B.

Fix: Always create Eagle3ResourceManager in the one-model flow and
reuse its pre-allocated hidden_states buffer instead of allocating a
new one per capture.

Fixes: NVIDIA#13909

Signed-off-by: Spurthi Sandiri <spurthi@amazon.com>
@coderabbitai

coderabbitai Bot commented May 8, 2026


📝 Walkthrough

This PR fixes excessive GPU memory allocation during CUDA graph capture in Eagle3 speculative decoding. The resource manager now unconditionally returns an Eagle3ResourceManager with an optional suffix automaton manager. Metadata instances reuse the manager's pre-allocated hidden-states buffer instead of allocating new tensors per capture. Debug logging tracks memory behavior across graph captures.

Changes

Eagle3 Memory Reuse During Graph Capture

  • Resource Manager Setup — tensorrt_llm/_torch/speculative/utils.py: get_spec_resource_manager now unconditionally constructs and returns Eagle3ResourceManager for Eagle3 one-model mode, accepting SuffixAutomatonManager as an optional nullable argument instead of returning a manager only when it is non-None.
  • Metadata Buffer Reuse — tensorrt_llm/_torch/speculative/eagle3.py: Eagle3OneModelSpecMetadata.__post_init__ assigns self.hidden_states from the pre-allocated self.spec_resource_manager.hidden_states instead of allocating a new CUDA tensor per metadata instance.
  • Graph Capture Observability — tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py: logger import added; a debug log is emitted after pool finalization with the capture key, draft-model status, graph count, and current CUDA memory allocated, to verify flat memory usage across captures.
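
The control-flow change in get_spec_resource_manager can be sketched as a stand-alone mock; the real signature and class internals in tensorrt_llm/_torch/speculative/utils.py differ:

```python
class SuffixAutomatonManager:  # stand-in for the real class
    pass

class Eagle3ResourceManager:   # stand-in; the real one pre-allocates hidden_states
    def __init__(self, sa_manager=None):
        self.sa_manager = sa_manager

def get_spec_resource_manager(sa_manager=None):
    # Before the change (sketch): the manager was returned only when
    # sa_manager was not None, leaving the one-model flow without a
    # shared buffer. After: always construct; sa_manager stays optional.
    return Eagle3ResourceManager(sa_manager)

assert get_spec_resource_manager() is not None  # works without an SA manager
assert get_spec_resource_manager(SuffixAutomatonManager()).sa_manager is not None
```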

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning — Docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check — ✅ Passed — The title clearly and specifically describes the main fix: reusing the hidden_states buffer across CUDA graph captures in Eagle3.
  • Description check — ✅ Passed — The description explains the problem (per-capture buffer allocation causing memory waste), the fix (reusing a pre-allocated buffer), and provides validation results showing memory staying flat after the fix.
  • Linked Issues check — ✅ Passed — The PR fully addresses issue #13909: it always creates Eagle3ResourceManager and reuses its pre-allocated hidden_states buffer across captures, eliminating the per-capture allocation overhead.
  • Out of Scope Changes check — ✅ Passed — All changes are scoped to the memory optimization fix: debug logging in cuda_graph_runner.py, hidden_states buffer reuse in eagle3.py, and Eagle3ResourceManager creation in utils.py.





Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)

1-1: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add/update NVIDIA copyright header in this modified source file

This file is modified but does not show the required NVIDIA copyright header/current modification year.

As per coding guidelines, "All C++, Python, and other source files must contain NVIDIA copyright header with current modification year."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py` at line 1, This file
(starts with the import bisect statement) is missing the required NVIDIA
copyright header; add the standard NVIDIA copyright/header block at the very top
of tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (above the existing
import bisect), update the modification year to the current year, and ensure the
header formatting matches other project source files so linters and legal checks
recognize it.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: d7902bb3-cc50-4f06-a8ca-42f9b27649fc

📥 Commits

Reviewing files that changed from the base of the PR and between f8572ab and 99db30c.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/speculative/utils.py

Comment on lines +394 to +399
```python
logger.debug(
    f"[Memory-Debug][rank={self.config.mapping.rank}][cuda_graph:capture] "
    f"key={key} is_draft_model={self.config.is_draft_model} "
    f"graphs_so_far={len(self.graphs)} "
    f"torch_memory_allocated={torch.cuda.memory_allocated()/1024**3:.2f}GiB"
)
```

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard optional mapping before logging rank

self.config.mapping is optional, but Line 395 dereferences .rank unconditionally. That can crash capture with AttributeError when mapping is absent.

Proposed fix
```diff
-        logger.debug(
-            f"[Memory-Debug][rank={self.config.mapping.rank}][cuda_graph:capture] "
+        rank = self.config.mapping.rank if self.config.mapping is not None else -1
+        logger.debug(
+            f"[Memory-Debug][rank={rank}][cuda_graph:capture] "
             f"key={key} is_draft_model={self.config.is_draft_model} "
             f"graphs_so_far={len(self.graphs)} "
             f"torch_memory_allocated={torch.cuda.memory_allocated()/1024**3:.2f}GiB"
         )
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py` around lines 394 - 399,
The log line in capture uses self.config.mapping.rank unguarded which can raise
AttributeError when mapping is None; update the logger.debug call in the
cuda_graph_runner capture path to safely read rank (e.g., use a
conditional/getattr to produce a safe default like "None" or -1 when
self.config.mapping is absent) while keeping the rest of the fields
(self.config.is_draft_model, len(self.graphs), torch.cuda.memory_allocated())
unchanged; locate the logger.debug invocation in the class/method that performs
cuda graph capture and replace the direct .rank access with the guarded
expression.
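
A stand-alone version of the suggested guard, for illustration; SimpleNamespace stands in for the runner's real config object:

```python
from types import SimpleNamespace

def rank_for_log(config):
    # Avoid AttributeError when config.mapping is None; -1 marks "no mapping".
    return config.mapping.rank if config.mapping is not None else -1

print(rank_for_log(SimpleNamespace(mapping=SimpleNamespace(rank=3))))  # 3
print(rank_for_log(SimpleNamespace(mapping=None)))                     # -1
```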

```python
                     self.hidden_size * len(self.layers_to_capture)),
                dtype=self.dtype,
                device='cuda')
        self.hidden_states = self.spec_resource_manager.hidden_states
```

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Critical: Line 372 breaks dynamic-tree one-model mode (AttributeError).

self.spec_resource_manager.hidden_states is accessed unconditionally, but dynamic-tree mode uses Eagle3OneModelDynamicTreeResourceManager (no hidden_states field). This can fail during metadata init.

Proposed fix
```diff
-        self.hidden_states = self.spec_resource_manager.hidden_states
+        if isinstance(self.spec_resource_manager, Eagle3ResourceManager):
+            self.hidden_states = self.spec_resource_manager.hidden_states
+        else:
+            # Dynamic-tree manager does not provide hidden_states.
+            # Keep per-metadata allocation for that mode.
+            self.hidden_states = torch.empty(
+                (self.max_num_tokens, self.hidden_size * self.num_capture_layers),
+                dtype=self.dtype,
+                device='cuda',
+            )
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/speculative/eagle3.py` at line 372, The assignment
self.hidden_states = self.spec_resource_manager.hidden_states fails in
dynamic-tree one-model mode because Eagle3OneModelDynamicTreeResourceManager
lacks hidden_states; change the initialization to safely handle both managers by
checking for the attribute (e.g., use getattr(self.spec_resource_manager,
"hidden_states", None)) or by conditional logic that only sets
self.hidden_states when spec_resource_manager exposes it so metadata init
doesn't raise AttributeError.
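
The getattr-based variant the prompt suggests, as a minimal stand-alone sketch; class names come from the review thread, and the bodies are stand-ins:

```python
class Eagle3ResourceManager:                     # stand-in: exposes the shared buffer
    hidden_states = object()

class Eagle3OneModelDynamicTreeResourceManager:  # stand-in: no hidden_states field
    pass

def shared_hidden_states(mgr):
    # Return the manager's buffer when present, else None so the caller
    # can fall back to per-metadata allocation without an AttributeError.
    return getattr(mgr, "hidden_states", None)

assert shared_hidden_states(Eagle3ResourceManager()) is not None
assert shared_hidden_states(Eagle3OneModelDynamicTreeResourceManager()) is None
```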

@svc-trtllm-gh-bot added the "Community want to contribute" label (PRs initiated from Community) on May 8, 2026


Development

Successfully merging this pull request may close these issues.

[Bug]: Excessive memory allocation during CUDA graph capture in Eagle3 flow

2 participants