[None][feat] Nemotron H more overlap #10938
Conversation
Signed-off-by: qgai <qgai@nvidia.com>
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
Builds on #10754
📝 Walkthrough
This pull request adds Multi-Token Prediction (MTP) support to the Nemotron H model within TensorRT LLM, integrating speculative decoding capabilities with Mamba2 SSM enhancements. Changes include protective kernel checks, MTP-aware weight mapping, new MTP decoder layers with speculative metadata threading, Triton-optimized causal convolution kernels, enhanced state management for speculative execution, and corresponding test configurations.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (7)
tensorrt_llm/_torch/models/modeling_nemotron_h.py (1)

1-2: Update the SPDX copyright year to 2026. The file is modified in 2026 but the header still ends at 2024; please align it with the latest meaningful modification. As per coding guidelines, please update the header year.

🔧 Suggested update
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (1)

1-2: Update the SPDX copyright year to 2026. The file is modified in 2026 but the header still ends at 2024; please align it with the latest meaningful modification. As per coding guidelines, please update the header year.

🔧 Suggested update
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

tensorrt_llm/_torch/modules/mamba/ssd_combined.py (1)

4-5: Update the SPDX copyright year to 2026. The file is modified in 2026 but the header still ends at 2024; please align it with the latest meaningful modification. As per coding guidelines, please update the header year.

🔧 Suggested update
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py (1)

4-5: Update the SPDX copyright year to 2026. The file is modified in 2026 but the header still ends at 2024; please align it with the latest meaningful modification. As per coding guidelines, please update the header year.

🔧 Suggested update
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

tensorrt_llm/_torch/pyexecutor/py_executor.py (1)

1-3: Add NVIDIA copyright SPDX header. This file was modified but lacks the required NVIDIA SPDX header with the latest year. Please add the standard header (e.g., 2026).

📄 Suggested header addition
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
As per coding guidelines, please add the NVIDIA header.

tensorrt_llm/_torch/modules/mamba/selective_state_update.py (1)

293-359: Guard optional `out` before dereferencing. `out` defaults to `None`, but `out.dim()` is used unconditionally, which will raise at runtime if callers don't pass `out`. Either make `out` mandatory or allocate a default.

🐛 Suggested fix
-    if out.dim() == 2:
+    if out is None:
+        out = torch.empty_like(x)
+    if out.dim() == 2:
         out = out.unsqueeze(1)
     if out.dim() == 3:
         out = out.unsqueeze(1)

tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py (1)

1-1: Update NVIDIA copyright year to 2026. The file was modified in 2026; update the header year accordingly.

📄 Suggested header update
-# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
As per coding guidelines, please update the header year.
🤖 Fix all issues with AI agents
In `@tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py`:
- Around line 59-67: The MTP key matching uses re.match which only matches at
the string start, so keys with prefixes like "model.mtp.layers.0..." will not be
transformed; update the matching in the block that checks "if 'mtp.layers.' in
key:" to use a prefix-tolerant pattern (e.g., use re.search or change the regex
to allow leading text) to reliably capture sublayer_idx and rest from keys like
"...mtp.layers.<idx>.<rest>", and on failure raise an exception (or raise a
ValueError) instead of merely logging via logger.error so mapping failures
fail-fast; ensure you update the code that assigns key =
f"model.layers.{config.num_hidden_layers}.layers.{sublayer_idx}.{rest}" to use
the captured groups from the new match.
In `@tensorrt_llm/_torch/models/modeling_nemotron_h.py`:
- Around line 594-601: The MTP helper functions (starting with the forward
method in modeling_nemotron_h.py) are triggering Ruff ARG002 for unused
parameters; to silence these, rename the unused params by prefixing them with an
underscore (e.g., positions -> _positions, is_separate_draft_engine ->
_is_separate_draft_engine, prefix -> _prefix, layer_idx -> _layer_idx,
all_rank_num_tokens -> _all_rank_num_tokens) or add an inline noqa comment (#
noqa: ARG002) where renaming is undesirable for API compatibility; ensure you do
this consistently in the forward signature and the other MTP helper functions
referenced around the other ranges so spec_metadata/**kwargs remain intact if
you later propagate metadata.
- Around line 627-630: The mixer call in NemotronHMTPDecoderLayer.forward and
the calls originating from NemotronHMTP.forward do not pass mamba_metadata and
spec_metadata, which Mamba2Mixer (and speculative-cache logic) require when
mtp_hybrid_override_pattern can include "M"; update both
NemotronHMTPDecoderLayer.forward and the code paths in NemotronHMTP.forward that
invoke the layer/mixer to thread the mamba_metadata and spec_metadata arguments
through to the mixer invocation (the hidden_states = self.mixer(...) call),
ensuring the same metadata variables are accepted by the layer signatures and
forwarded unchanged to avoid runtime errors or speculative-state corruption.
In `@tensorrt_llm/_torch/models/modeling_speculative.py`:
- Around line 743-748: MTPDraftModel now constructs NemotronHMTP for model_type
"nemotron_h" but MTPDraftModelForCausalLM.load_weights lacks a corresponding
branch, causing a ValueError in two-engine mode; update load_weights to handle
"nemotron_h" by adding a dedicated weight-loading branch that mirrors how
NemotronHMTP expects weights (or call the existing NemotronHMTP weight helper),
or explicitly guard against two-engine usage for nemotron_h by raising a clear,
early error. Locate MTPDraftModelForCausalLM.load_weights and either add a case
for "nemotron_h" that delegates to NemotronHMTP weight loading logic or add an
explicit check that prevents two-engine mode when model_type == "nemotron_h".
In `@tensorrt_llm/_torch/modules/mamba/causal_conv1d_triton.py`:
- Around line 1-4: Add the NVIDIA SPDX copyright header (with the current year,
e.g., 2026) to the top of the module
tensorrt_llm._torch.modules.mamba.causal_conv1d_triton (i.e., prepend the
standard NVIDIA header block including SPDX-License-Identifier and copyright
notice) so the file begins with the required NVIDIA header instead of or in
addition to the existing Tri Dao comment; ensure the header follows project SPDX
format and is placed before any other code or comments.
In `@tensorrt_llm/_torch/modules/mamba/selective_state_update.py`:
- Around line 1-7: This file is missing the required NVIDIA SPDX copyright
header; at the top of
tensorrt_llm/_torch/modules/mamba/selective_state_update.py add the standard
NVIDIA SPDX header (including the current year, e.g., 2026) above or alongside
the existing SPDX and copyright notices so the file includes NVIDIA's
SPDX-License-Identifier and SPDX-FileCopyrightText with the NVIDIA copyright
owner and year; ensure the new header appears before any code or existing
comments to comply with licensing guidelines.
In `@tensorrt_llm/_torch/modules/mamba/ssd_combined.py`:
- Around line 235-269: The unpacked outputs from _mamba_chunk_scan_combined_fwd
(out_x, dt_out, dA_cumsum, states) are unused and should be renamed with leading
underscores to silence lint warnings (e.g., _out_x, _dt_out, _dA_cumsum,
_states) in the assignment where _mamba_chunk_scan_combined_fwd is called; also
fix the parameter name mismatch where triton_backend_mamba.py calls
mamba_chunk_scan_combined by replacing the incorrect keyword
mamba_ssm_cache_dtype with the expected state_dtype so the call uses
state_dtype=ssm_state_cache.dtype.
In `@tests/integration/defs/accuracy/references/gsm8k.yaml`:
- Around line 325-329: The YAML entry currently uses mtp_enabled and
num_nextn_predict_layers which the accuracy matching logic in accuracy_core.py
ignores; remove those two fields and instead use the standard
speculative-decoding pattern by adding spec_dec_algo: MTP (matching other MTP
entries like deepseek-ai/DeepSeek-V3-Lite) so the framework can correctly find
this reference entry.
In `@tests/integration/defs/accuracy/references/mmlu.yaml`:
- Around line 362-366: The accuracy matcher never sees mtp_enabled and
num_nextn_predict_layers so MTP-specific reference entries (mtp_enabled,
num_nextn_predict_layers) are never selected; update the callers of
get_hypothesis_testing_params() to include mtp_enabled and
num_nextn_predict_layers in the parameter dict passed to the matcher (or
alternatively merge those flags into extra_acc_spec before calling
get_hypothesis_testing_params()), ensuring the matcher receives the new keys so
MTP runs match the MTP-specific YAML entry.
In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py`:
- Around line 5554-5589: The test constructs an LLM with the MTP configuration
assigned to the wrong parameter: replace the use of decoding_config=mtp_config
with speculative_config=mtp_config in the LLM(...) call so the MTPDecodingConfig
(mtp_config) is passed into the speculative decoding pipeline; update the LLM
constructor invocation that currently references decoding_config to use
speculative_config instead so the test actually exercises MTP (refer to
mtp_config and the LLM(...) call where decoding_config is set).
🧹 Nitpick comments (3)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)

1471-1487: Prefer `isinstance` over string-based type checks for MambaHybridCacheManager. `type(...).__name__` is brittle (renames/subclasses) and non-idiomatic. Consider a local import and `isinstance`, or a duck-typing check. Please also confirm no circular-import issues.

♻️ Suggested refactor
-        if type(resource_manager).__name__ == 'MambaHybridCacheManager':
+        from .mamba_cache_manager import MambaHybridCacheManager
+        if isinstance(resource_manager, MambaHybridCacheManager):
             resource_manager.update_resources(
                 scheduled_batch, attn_metadata, kv_cache_dtype_byte_size,
                 update_mamba_cache_manager)

tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py (1)

501-545: Add explicit guard to clarify that the `out` parameter must be provided. The current code will raise an opaque `NoneType` error if a caller forgets to pass the `out` parameter, since it defaults to `None`. Add an explicit assertion before accessing its attributes.

Suggested fix
-    assert out.shape == x.shape
+    assert out is not None, "out must be provided (preallocated)"
+    assert out.shape == x.shape

tensorrt_llm/_torch/pyexecutor/py_executor.py (1)

22-23: Prefer module-namespace import for MambaHybridCacheManager. Repo guidelines require preserving module namespaces for imports. Consider importing the module and referencing the class via its namespace.

♻️ Suggested refactor
-from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
-    MambaHybridCacheManager
+import tensorrt_llm._torch.pyexecutor.mamba_cache_manager as mamba_cache_manager
@@
-        self.is_mamba_hybrid_cache = isinstance(self.kv_cache_manager,
-                                                MambaHybridCacheManager)
+        self.is_mamba_hybrid_cache = isinstance(
+            self.kv_cache_manager,
+            mamba_cache_manager.MambaHybridCacheManager,
+        )
As per coding guidelines, keep module namespaces on imports.
Also applies to: 351-352
    # MTP layers are stored as mtp.layers.0.xxx (sublayer 0, Attention) and mtp.layers.1.xxx (sublayer 1, MoE)
    if "mtp.layers." in key:
        import re
        match = re.match(r'mtp\.layers\.(\d+)\.(.*)', key)
        if match:
            sublayer_idx, rest = match.groups()
            key = f"model.layers.{config.num_hidden_layers}.layers.{sublayer_idx}.{rest}"
        else:
            logger.error(f"Failed to match MTP pattern for: {name}")
Fail-fast and match prefixed MTP keys.
re.match anchors at the string start, so keys like model.mtp.layers.0.* won’t match and will silently fall through with only a log line, leaving weights unmapped. Use re.search (or a prefix-tolerant pattern) and raise on failure to avoid corrupt mappings.
🔧 Proposed fix
- match = re.match(r'mtp\.layers\.(\d+)\.(.*)', key)
+ match = re.search(r'mtp\.layers\.(\d+)\.(.*)', key)
if match:
sublayer_idx, rest = match.groups()
key = f"model.layers.{config.num_hidden_layers}.layers.{sublayer_idx}.{rest}"
else:
- logger.error(f"Failed to match MTP pattern for: {name}")
+ raise ValueError(f"Failed to match MTP pattern for: {name}")🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py` around
lines 59 - 67, The MTP key matching uses re.match which only matches at the
string start, so keys with prefixes like "model.mtp.layers.0..." will not be
transformed; update the matching in the block that checks "if 'mtp.layers.' in
key:" to use a prefix-tolerant pattern (e.g., use re.search or change the regex
to allow leading text) to reliably capture sublayer_idx and rest from keys like
"...mtp.layers.<idx>.<rest>", and on failure raise an exception (or raise a
ValueError) instead of merely logging via logger.error so mapping failures
fail-fast; ensure you update the code that assigns key =
f"model.layers.{config.num_hidden_layers}.layers.{sublayer_idx}.{rest}" to use
the captured groups from the new match.
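For reference, a minimal standalone sketch (using a hypothetical checkpoint key, not one taken from the model) showing why re.match misses prefixed keys while re.search finds them:

import re

pattern = r'mtp\.layers\.(\d+)\.(.*)'
key = "model.mtp.layers.0.mixer.in_proj.weight"  # hypothetical prefixed key

print(re.match(pattern, key))   # None: re.match only matches at the start of the string
print(re.search(pattern, key))  # <re.Match ...>: re.search scans the whole string

match = re.search(pattern, key)
if match is None:
    raise ValueError(f"Failed to match MTP pattern for: {key}")  # fail fast instead of logging
sublayer_idx, rest = match.groups()
print(sublayer_idx, rest)       # '0' 'mixer.in_proj.weight'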
    def forward(
        self,
        inputs_embeds: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None = None,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
Silence unused-parameter warnings in MTP helpers.
Ruff flags unused args (positions, is_separate_draft_engine, prefix, layer_idx, all_rank_num_tokens). If they’re kept for API consistency, consider prefixing with _ or adding # noqa: ARG002 to avoid lint failures. (If you adopt the metadata propagation change above, spec_metadata/**kwargs will no longer be unused.)
✅ Minimal lint-silencing example
- positions: torch.Tensor,
+ _positions: torch.Tensor,
@@
- is_separate_draft_engine: bool = False,
- prefix: str = ""):
+ _is_separate_draft_engine: bool = False,
+ _prefix: str = ""):
@@
- def _get_mtp_sublayer_quant_config(
- self, model_config: ModelConfig[NemotronHConfig], layer_idx: int):
+ def _get_mtp_sublayer_quant_config(
+ self, model_config: ModelConfig[NemotronHConfig], _layer_idx: int):
@@
- all_rank_num_tokens: Optional[List[int]] = None,
+ _all_rank_num_tokens: Optional[List[int]] = None,Also applies to: 645-650, 697-713, 715-724
🧰 Tools
🪛 Ruff (0.14.13)
597-597: Unused method argument: positions
(ARG002)
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/models/modeling_nemotron_h.py` around lines 594 - 601,
The MTP helper functions (starting with the forward method in
modeling_nemotron_h.py) are triggering Ruff ARG002 for unused parameters; to
silence these, rename the unused params by prefixing them with an underscore
(e.g., positions -> _positions, is_separate_draft_engine ->
_is_separate_draft_engine, prefix -> _prefix, layer_idx -> _layer_idx,
all_rank_num_tokens -> _all_rank_num_tokens) or add an inline noqa comment (#
noqa: ARG002) where renaming is undesirable for API compatibility; ensure you do
this consistently in the forward signature and the other MTP helper functions
referenced around the other ranges so spec_metadata/**kwargs remain intact if
you later propagate metadata.
        hidden_states = self.mixer(
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
        )
Propagate mamba_metadata/spec_metadata into MTP mixer to avoid M-layer failures.
NemotronHMTPDecoderLayer.forward calls the mixer without mamba_metadata (required by Mamba2Mixer) and without spec_metadata (needed to avoid speculative-state corruption). If mtp_hybrid_override_pattern includes "M", this can raise a runtime error or update real cache state during verify. Please thread these through from NemotronHMTP.forward.
🐛 Proposed fix
@@
- def forward(
+ def forward(
self,
inputs_embeds: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None = None,
attn_metadata: Optional[AttentionMetadata] = None,
+ mamba_metadata: Optional[Mamba2Metadata] = None,
+ spec_metadata: Optional[SpecMetadata] = None,
+ **kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
@@
- hidden_states = self.mixer(
- hidden_states=hidden_states,
- attn_metadata=attn_metadata,
- )
+ hidden_states = self.mixer(
+ hidden_states=hidden_states,
+ attn_metadata=attn_metadata,
+ mamba_metadata=mamba_metadata,
+ spec_metadata=spec_metadata,
+ **kwargs,
+ )
@@
- def forward(
+ def forward(
self,
input_ids: torch.IntTensor,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
embed_tokens: Embedding,
attn_metadata: AttentionMetadata,
all_rank_num_tokens: Optional[List[int]] = None,
spec_metadata: Optional[SpecMetadata] = None,
+ mamba_metadata: Optional[Mamba2Metadata] = None,
**kwargs,
) -> torch.Tensor:
@@
- hidden_states, residual = layer(
+ hidden_states, residual = layer(
inputs_embeds=inputs_embeds,
positions=position_ids,
hidden_states=hidden_states,
residual=residual,
attn_metadata=attn_metadata,
+ mamba_metadata=mamba_metadata,
+ spec_metadata=spec_metadata,
+ **kwargs,
            )
Also applies to: 730-736
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/models/modeling_nemotron_h.py` around lines 627 - 630,
The mixer call in NemotronHMTPDecoderLayer.forward and the calls originating
from NemotronHMTP.forward do not pass mamba_metadata and spec_metadata, which
Mamba2Mixer (and speculative-cache logic) require when
mtp_hybrid_override_pattern can include "M"; update both
NemotronHMTPDecoderLayer.forward and the code paths in NemotronHMTP.forward that
invoke the layer/mixer to thread the mamba_metadata and spec_metadata arguments
through to the mixer invocation (the hidden_states = self.mixer(...) call),
ensuring the same metadata variables are accepted by the layer signatures and
forwarded unchanged to avoid runtime errors or speculative-state corruption.
        elif model_type == "nemotron_h":
            from .modeling_nemotron_h import NemotronHMTP
            mtp_layer = NemotronHMTP(model_config,
                                     layer_idx,
                                     aux_stream_dict,
                                     is_separate_draft_engine=False)
Avoid partial two‑engine support for nemotron_h.
MTPDraftModel now constructs NemotronHMTP, but MTPDraftModelForCausalLM.load_weights still lacks a nemotron_h case and will raise ValueError in two‑engine mode. Either add the corresponding weight loader or explicitly block/guard this path.
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/models/modeling_speculative.py` around lines 743 - 748,
MTPDraftModel now constructs NemotronHMTP for model_type "nemotron_h" but
MTPDraftModelForCausalLM.load_weights lacks a corresponding branch, causing a
ValueError in two-engine mode; update load_weights to handle "nemotron_h" by
adding a dedicated weight-loading branch that mirrors how NemotronHMTP expects
weights (or call the existing NemotronHMTP weight helper), or explicitly guard
against two-engine usage for nemotron_h by raising a clear, early error. Locate
MTPDraftModelForCausalLM.load_weights and either add a case for "nemotron_h"
that delegates to NemotronHMTP weight loading logic or add an explicit check
that prevents two-engine mode when model_type == "nemotron_h".
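A minimal self-contained sketch of the fail-fast guard idea; the helper name and the supported-model set below are assumptions for illustration, not the actual MTPDraftModelForCausalLM code:

def check_two_engine_mtp_support(model_type: str) -> None:
    """Raise early when two-engine MTP weight loading has no branch for this model type."""
    supported = {"deepseek_v3"}  # assumed set; replace with whatever load_weights actually handles
    if model_type not in supported:
        raise NotImplementedError(
            f"Two-engine MTP weight loading is not implemented for '{model_type}'; "
            "add a dedicated load_weights branch or use the one-model MTP path.")

check_two_engine_mtp_support("nemotron_h")  # raises NotImplementedError with a clear message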
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
# and https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py
# -*- coding: utf-8 -*-
Add NVIDIA copyright SPDX header.
As a TensorRT‑LLM source file, this needs the NVIDIA SPDX header with the latest year (e.g., 2026).
📄 Suggested header addition
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
 # Copyright (c) 2024, Tri Dao.
As per coding guidelines, please add the NVIDIA header.
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/modules/mamba/causal_conv1d_triton.py` around lines 1 -
4, Add the NVIDIA SPDX copyright header (with the current year, e.g., 2026) to
the top of the module tensorrt_llm._torch.modules.mamba.causal_conv1d_triton
(i.e., prepend the standard NVIDIA header block including
SPDX-License-Identifier and copyright notice) so the file begins with the
required NVIDIA header instead of or in addition to the existing Tri Dao
comment; ensure the header follows project SPDX format and is placed before any
other code or comments.
# Adapted from: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py

# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: Copyright contributors to the sglang project

# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
Add NVIDIA copyright SPDX header.
This TensorRT‑LLM source file lacks the required NVIDIA header with the latest year. Please add the standard SPDX header (e.g., 2026) alongside the existing attribution.
📄 Suggested header update
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the sglang project
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the sglang project
As per coding guidelines, please add the NVIDIA header.
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/modules/mamba/selective_state_update.py` around lines 1 -
7, This file is missing the required NVIDIA SPDX copyright header; at the top of
tensorrt_llm/_torch/modules/mamba/selective_state_update.py add the standard
NVIDIA SPDX header (including the current year, e.g., 2026) above or alongside
the existing SPDX and copyright notices so the file includes NVIDIA's
SPDX-License-Identifier and SPDX-FileCopyrightText with the NVIDIA copyright
owner and year; ensure the new header appears before any code or existing
comments to comply with licensing guidelines.
     if not return_varlen_states:
         cu_seqlens = None
     else:
-        assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
-    out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
-        x,
-        dt,
-        A,
-        B,
-        C,
-        chunk_size,
-        D=D,
-        z=z,
-        dt_bias=dt_bias,
-        initial_states=initial_states,
-        seq_idx=seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
-        cu_seqlens=cu_seqlens,
-        dt_softplus=dt_softplus,
-        dt_limit=dt_limit,
-        mamba_ssm_cache_dtype=mamba_ssm_cache_dtype)
+        assert (cu_seqlens is not None
+                ), "cu_seqlens must be provided if return_varlen_states is True"
+    out_x, dt_out, dA_cumsum, states, final_states, *rest = (
+        _mamba_chunk_scan_combined_fwd(
+            x,
+            dt,
+            A,
+            B,
+            C,
+            chunk_size,
+            D=D,
+            z=z,
+            dt_bias=dt_bias,
+            initial_states=initial_states,
+            seq_idx=seq_idx,
+            chunk_indices=chunk_indices,
+            chunk_offsets=chunk_offsets,
+            cu_seqlens=cu_seqlens,
+            dt_softplus=dt_softplus,
+            dt_limit=dt_limit,
+            out=out,
+            state_dtype=state_dtype,
+        ))
     if not return_varlen_states:
-        return out if not return_final_states else (out, final_states)
+        if not return_final_states:
+            return
+        else:
+            return final_states
     else:
         varlen_states = rest[0]
-        return (out,
-                varlen_states) if not return_final_states else (out,
-                                                                 final_states,
-                                                                 varlen_states)
+        return ((varlen_states) if not return_final_states else
+                (final_states, varlen_states))
Prefix unused unpacked variables with _ and fix parameter mismatch in triton_backend_mamba.py.
The variables out_x, dt_out, dA_cumsum, and states are unpacked but never used after line 240, triggering Ruff warnings. Prefix them with _ to silence the warnings. Additionally, triton_backend_mamba.py:168 calls mamba_chunk_scan_combined() with the parameter mamba_ssm_cache_dtype=ssm_state_cache.dtype, but the function signature expects state_dtype=None, which will cause a TypeError at runtime. Correct the parameter name in that call site.
Unused variable prefixes
- out_x, dt_out, dA_cumsum, states, final_states, *rest = (
+ _out_x, _dt_out, _dA_cumsum, _states, final_states, *rest = (
_mamba_chunk_scan_combined_fwd(
x,
            dt,
🧰 Tools
🪛 Ruff (0.14.13)
240-240: Unpacked variable out_x is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
240-240: Unpacked variable dt_out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
240-240: Unpacked variable dA_cumsum is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
240-240: Unpacked variable states is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/modules/mamba/ssd_combined.py` around lines 235 - 269,
The unpacked outputs from _mamba_chunk_scan_combined_fwd (out_x, dt_out,
dA_cumsum, states) are unused and should be renamed with leading underscores to
silence lint warnings (e.g., _out_x, _dt_out, _dA_cumsum, _states) in the
assignment where _mamba_chunk_scan_combined_fwd is called; also fix the
parameter name mismatch where triton_backend_mamba.py calls
mamba_chunk_scan_combined by replacing the incorrect keyword
mamba_ssm_cache_dtype with the expected state_dtype so the call uses
state_dtype=ssm_state_cache.dtype.
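To make the failure mode concrete, here is a self-contained stub (not the real kernel wrapper) showing that passing the stale keyword raises a TypeError at call time, which is what the mismatched call site would hit:

import torch

def mamba_chunk_scan_combined_stub(x, *, state_dtype=None):
    # Stand-in that mirrors only the relevant part of the signature: a state_dtype keyword.
    return x.to(state_dtype) if state_dtype is not None else x

x = torch.zeros(4)
print(mamba_chunk_scan_combined_stub(x, state_dtype=torch.float32).dtype)  # torch.float32
try:
    mamba_chunk_scan_combined_stub(x, mamba_ssm_cache_dtype=torch.float32)
except TypeError as err:
    print(err)  # got an unexpected keyword argument 'mamba_ssm_cache_dtype'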
  - quant_algo: NVFP4
    kv_cache_quant_algo: FP8
    mtp_enabled: true
    num_nextn_predict_layers: 3
    accuracy: 80.85
Remove mtp_enabled and num_nextn_predict_layers fields or use standard spec_dec_algo: MTP pattern.
The test framework's accuracy reference matching logic (in accuracy_core.py) only recognizes spec_dec_algo for speculative decoding configurations, not mtp_enabled or num_nextn_predict_layers. The fields you've added will be silently ignored during accuracy specification matching, causing the framework to fail to locate the correct reference entry. Use the standard pattern with spec_dec_algo: MTP consistent with other MTP entries in the file (e.g., deepseek-ai/DeepSeek-V3-Lite at lines 50-51, 61-62).
🤖 Prompt for AI Agents
In `@tests/integration/defs/accuracy/references/gsm8k.yaml` around lines 325 -
329, The YAML entry currently uses mtp_enabled and num_nextn_predict_layers
which the accuracy matching logic in accuracy_core.py ignores; remove those two
fields and instead use the standard speculative-decoding pattern by adding
spec_dec_algo: MTP (matching other MTP entries like
deepseek-ai/DeepSeek-V3-Lite) so the framework can correctly find this reference
entry.
  - quant_algo: NVFP4
    kv_cache_quant_algo: FP8
    mtp_enabled: true
    num_nextn_predict_layers: 3
    accuracy: 77.56
The MTP reference entry will never be selected because mtp_enabled and num_nextn_predict_layers are not passed to the accuracy matcher.
The matching logic only checks parameters that are explicitly provided to get_hypothesis_testing_params(). Current call sites pass dtype, quant_algo, kv_cache_quant_algo, spec_dec_algo, and extra_acc_spec—but never mtp_enabled or num_nextn_predict_layers. This means MTP runs will always match against the non-MTP baseline (lines 362-361) instead of the MTP-specific entry. Either pass these keys to the matcher or consolidate into extra_acc_spec.
🤖 Prompt for AI Agents
In `@tests/integration/defs/accuracy/references/mmlu.yaml` around lines 362 - 366,
The accuracy matcher never sees mtp_enabled and num_nextn_predict_layers so
MTP-specific reference entries (mtp_enabled, num_nextn_predict_layers) are never
selected; update the callers of get_hypothesis_testing_params() to include
mtp_enabled and num_nextn_predict_layers in the parameter dict passed to the
matcher (or alternatively merge those flags into extra_acc_spec before calling
get_hypothesis_testing_params()), ensuring the matcher receives the new keys so
MTP runs match the MTP-specific YAML entry.
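A toy, self-contained illustration of why unmatched keys fall back to the baseline entry; this mirrors the matching behavior described above, not the actual accuracy_core.py code, and the baseline accuracy value here is made up:

references = [
    {"quant_algo": "NVFP4", "kv_cache_quant_algo": "FP8", "accuracy": 74.00},
    {"quant_algo": "NVFP4", "kv_cache_quant_algo": "FP8", "spec_dec_algo": "MTP", "accuracy": 77.56},
]

def match_reference(refs, **given):
    # Keep entries whose declared keys (other than accuracy) are all satisfied by `given`,
    # then prefer the most specific match (largest number of declared keys).
    candidates = [r for r in refs
                  if all(given.get(k) == v for k, v in r.items() if k != "accuracy")]
    return max(candidates, key=len, default=None)

# Without spec_dec_algo the MTP-specific entry can never be selected:
print(match_reference(references, quant_algo="NVFP4", kv_cache_quant_algo="FP8"))
# Passing the standard key (or folding it into extra_acc_spec) picks the MTP entry:
print(match_reference(references, quant_algo="NVFP4", kv_cache_quant_algo="FP8",
                      spec_dec_algo="MTP"))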
    @pytest.mark.skip(reason="Skip MTP test due to no model path file in CI")
    @skip_pre_blackwell
    @pytest.mark.skip_less_mpi_world_size(8)
    def test_nvfp4_8gpus_mtp(self):
        # Test MTP (Multi-Token Prediction) accuracy with nvfp4-fp8kv model.
        # This test uses MTP with max_draft_len=3 and one_model mode.
        mtp_config = MTPDecodingConfig(
            num_nextn_predict_layers=3,
            mtp_eagle_one_model=True,
        )
        model_path = f"{llm_models_root()}/nemotron-super-sft-repeated-mtp-iter-0010600-nvfp4-fp8kv"
        with LLM(
                model_path,
                kv_cache_config=KvCacheConfig(
                    enable_block_reuse=False,
                    mamba_ssm_cache_dtype="float16",
                    free_gpu_memory_fraction=0.5,
                ),
                max_batch_size=128,
                tensor_parallel_size=8,
                moe_expert_parallel_size=8,
                pipeline_parallel_size=1,
                enable_attention_dp=False,
                cuda_graph_config=CudaGraphConfig(max_batch_size=32,
                                                  enable_padding=True),
                disable_overlap_scheduler=False,
                moe_config=MoeConfig(backend="CUTLASS"),
                decoding_config=mtp_config,
        ) as llm:
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm,
                          extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm,
                          extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
Change decoding_config to speculative_config for MTP configuration.
All MTP tests in this file use speculative_config=mtp_config. The parameter decoding_config is not a valid LLM constructor parameter for speculative decoding configuration—it's only used internally for executor setup. This test will not exercise MTP without the fix.
🔧 Suggested fix
- decoding_config=mtp_config,
+ speculative_config=mtp_config,
🤖 Prompt for AI Agents
In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 5554 -
5589, The test constructs an LLM with the MTP configuration assigned to the
wrong parameter: replace the use of decoding_config=mtp_config with
speculative_config=mtp_config in the LLM(...) call so the MTPDecodingConfig
(mtp_config) is passed into the speculative decoding pipeline; update the LLM
constructor invocation that currently references decoding_config to use
speculative_config instead so the test actually exercises MTP (refer to
mtp_config and the LLM(...) call where decoding_config is set).
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

- --reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
- --post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.