[None][feat] Cache Transceivers for Mamba States #10934
base: main
Conversation
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
/bot run
PR_Github #33217 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This pull request introduces comprehensive RNN/Mamba cache transfer infrastructure for disaggregated state management. New components include RnnCacheFormatter, RnnCacheTransceiver with sender/receiver pairs, RnnStateManager cache block allocation, a buffer management framework, and RnnCacheState serialization support. Python bindings and attention backends are updated to integrate Mamba metadata handling.
Changes
Sequence Diagram(s)
sequenceDiagram
actor Req as Request
participant RnnCacheTx as RnnCacheTransceiver
participant RnnSend as RnnCacheSender
participant RnnRecv as RnnCacheReceiver
participant CommGrp as CacheTransceiverComm<br/>(MPI/UCX)
participant RnnMgr as RnnStateManager
participant Fmt as RnnCacheFormatter
Req->>RnnCacheTx: respondAndSendAsync(llmRequest)
activate RnnCacheTx
RnnCacheTx->>RnnCacheTx: setContextState(llmRequest)
RnnCacheTx->>RnnSend: sendAsync(llmRequest)
activate RnnSend
RnnSend->>Fmt: inquireSupport(selfState, peerState)
RnnSend->>Fmt: getCounterparts(selfState, idx, peerState)
RnnSend->>CommGrp: send RNN state data
RnnSend-->>RnnCacheTx: return future
deactivate RnnSend
deactivate RnnCacheTx
Req->>RnnCacheTx: requestAndReceiveAsync(llmRequest)
activate RnnCacheTx
RnnCacheTx->>RnnRecv: receiveAsync(llmRequest)
activate RnnRecv
RnnRecv->>Fmt: inquireSupport(selfState, peerState)
RnnRecv->>CommGrp: wait for RNN state data
RnnRecv->>RnnMgr: allocateCacheBlocks(requestIds)
RnnRecv-->>RnnCacheTx: return future
deactivate RnnRecv
deactivate RnnCacheTx
Req->>RnnCacheTx: checkContextTransferStatus()
activate RnnCacheTx
RnnCacheTx->>CommGrp: aggregate completion across ranks
RnnCacheTx->>Req: update request state
deactivate RnnCacheTx
Req->>RnnCacheTx: checkGenTransferStatus()
activate RnnCacheTx
RnnCacheTx->>CommGrp: synchronize generation transfer completion
RnnCacheTx->>RnnMgr: query state indices
RnnCacheTx->>Req: update generation progress
deactivate RnnCacheTx
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 17
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (10)
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (1)
1-1: Update header year to include 2026.
This file was modified in 2026; please bump the copyright range.
Proposed fix
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)
1-2: Update SPDX year to 2026.
This file now includes 2026 modifications but the header still ends at 2024. As per coding guidelines, please update the year.
📝 Proposed update
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h (1)
1-2: Update copyright year to 2026.
The file now has 2026 modifications but the header still ends at 2024. As per coding guidelines, please update the year.
📝 Proposed update
- * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved.
tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py (1)
1-2: Update SPDX year to 2026.
This file now includes 2026 modifications but the header still ends at 2025. As per coding guidelines, please update the year.
📝 Proposed update
-# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
tensorrt_llm/_torch/attention_backend/interface.py (1)
1-7: Add the NVIDIA SPDX header.
This source file is missing the required SPDX/license header. As per coding guidelines, please add the SPDX header.
📝 Proposed header
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
tensorrt_llm/_torch/attention_backend/flashinfer.py (1)
1-3: Add the NVIDIA SPDX header.
This source file is missing the required SPDX/license header. As per coding guidelines, please add the SPDX header.
📝 Proposed header
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
tensorrt_llm/_torch/attention_backend/vanilla.py (1)
1-3: Add the NVIDIA SPDX header.
This source file is missing the required SPDX/license header. As per coding guidelines, please add the SPDX header.
📝 Proposed header
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp (1)
131-131: Typo in function name: supportFbaricMemory should be supportFabricMemory.
The function name has a typo - "Fbaric" instead of "Fabric".
🔤 Proposed fix
-bool FabricMemory::supportFbaricMemory()
+bool FabricMemory::supportFabricMemory()
Note: This will also require updating the header file declaration and any call sites.
cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h (1)
53-53: Same typo in header declaration: supportFbaricMemory.
This matches the typo flagged in the .cpp file. Both should be fixed together.
🔤 Proposed fix
-    static bool supportFbaricMemory();
+    static bool supportFabricMemory();
1-15: Update copyright year to include 2025.
The copyright header shows "2023-2024" but this file has been modified with significant new functionality in 2025. Per coding guidelines, the year should reflect the latest meaningful modification.
📅 Proposed fix
 /*
- * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved.
 *
🤖 Fix all issues with AI agents
In `@cpp/include/tensorrt_llm/batch_manager/rnnCacheTransceiver.h`:
- Around line 8-44: Replace the non-compliant header protection and comment
style: remove "#pragma once" and add the explicit include guard
"TRTLLM_RNNCACHETRANSCEIVER_H" (i.e., `#ifndef` TRTLLM_RNNCACHETRANSCEIVER_H /
`#define` ... / `#endif`) at the top/bottom of the header, and convert C++03-style
triple-slash Doxygen comments (e.g., "/// `@brief`" and other "///" comments
associated with RnnCacheSender, its constructor and members) to the project
standard "//!"/"//!<" style so class-level docs and member descriptions (for
RnnCacheSender, RnnCacheFormatter references, etc.) use //! and //!< where
appropriate. Ensure the guard macro name matches the file identifier exactly
(TRTLLM_RNNCACHETRANSCEIVER_H) and update any file-local comment markers
accordingly.
- Around line 1-6: Update the file header comment's copyright year from 2025 to
2026 in the top-of-file comment block of
tensorrt_llm/batch_manager/rnnCacheTransceiver.h; locate the copyright line in
the file's leading comment (the block beginning "/*" at the top) and change
"2025" to "2026" so the header reflects the latest modification year.
In `@cpp/include/tensorrt_llm/executor/dataTransceiverState.h`:
- Around line 610-615: Remove the inline comment " // is this needed?" next to
the mRnnCacheState declaration in the class (friend class Serialization;
std::optional<rnn_cache::RnnCacheState> mRnnCacheState), and if this represents
a tracked TODO open an issue or add a documented TODO elsewhere rather than
leaving an ambiguous inline note; simply delete the comment or replace it with a
proper tracked reference.
In `@cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp`:
- Around line 1-16: Update the copyright year in the file header comment: locate
the SPDX header line that reads "Copyright (c) 2025 NVIDIA CORPORATION &
AFFILIATES" (the SPDX comment block near the top of baseTransBuffer.cpp) and
change the year from 2025 to 2026 so the header reflects the current
modification year; ensure the SPDX-License-Identifier and surrounding comment
formatting remain unchanged.
In `@cpp/tensorrt_llm/batch_manager/baseTransBuffer.h`:
- Around line 1-16: The file header's copyright year is outdated (shows 2025);
update the year to 2026 in the top comment — specifically modify the
SPDX-FileCopyrightText line in baseTransBuffer.h to read 2026 so the header
matches the latest changes.
- Around line 18-44: The header currently uses `#pragma` once and ‘///’ comments;
replace `#pragma` once with a traditional include guard macro named
TRTLLM_BASETRANSBUFFER_H (wrap the entire header between `#ifndef`
TRTLLM_BASETRANSBUFFER_H / `#define` TRTLLM_BASETRANSBUFFER_H and end with `#endif`)
and convert the public API doc comments (e.g., the comment above the
BaseTransBufferManager class and any member docs) from triple-slash “///” to the
required Doxygen style using “//!” or “//!<” as appropriate; keep all existing
includes, namespaces and declarations (class BaseTransBufferManager, forward
declaration FabricMemory, etc.) unchanged other than the guard and comment style
swap.
In `@cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp`:
- Around line 458-630: The response loop condition is wrong and mReadyResponses
access is not fully protected causing UB; change the loop in
BaseCacheSenderImpl::response to continue while (!mTerminate || mAnyReady) so we
still drain ready responses after terminate, and ensure all reads/iterators of
mReadyResponses are done under mSenderMutex: inside response() acquire
mSenderMutex, locate the iterator via
mReadyResponses.find/getCurrentRequestId(), move the Response out (std::move)
and erase the map entry while still holding mSenderMutex, then release the lock
before calling sendResponse/sendAndRemoveResponse/asyncSendAndRemoveResponse;
remove or adjust getCurrentResponse() so it does not return a dangling iterator
(or make it only be used while holding mSenderMutex) and ensure removeResponse()
still updates mAnyReady under mCondMutex as before.
In `@cpp/tensorrt_llm/batch_manager/rnnCacheFormatter.cpp`:
- Around line 123-140: Validate that selfRank yields a valid PP rank before
indexing selfNumLayerPerPP: ensure selfRank is non-negative and that computed
selfPPRank (selfRank / selfTPNum) is less than selfNumLayerPerPP.size() (or
equivalently selfRank < selfTPNum * selfPPNum). If the check fails, fail fast
(e.g., use TLLM_CHECK or return an error) so that the later use of
selfNumLayerPerPP[selfPPRank] in the function (around variables selfTPRank,
selfPPRank, selfStartLayerId, selfEndLayerId) cannot index out of bounds (a sketch of such a check follows this file's items).
- Around line 18-26: Add a direct include for the header that declares std::iota
by adding `#include` <numeric> at the top of rnnCacheFormatter.cpp (near the other
includes) so that uses of std::iota (seen in this file) do not rely on
transitive includes; update the include section in the file containing
rnnCacheFormatter.cpp to reference <numeric>.
- Around line 36-46: RnnCacheFormatter::format and ::unformat currently always
TLLM_THROW and will crash runtime; replace the throws with real implementations
or safe no-ops: implement format(TransferSession& session) to check for RNN
state on the session (e.g., session.hasRnnState()/session.getRnnState()),
serialize the state into the session transfer buffer using the TransferSession
write/append APIs and set any flags/lengths expected by the receiver; implement
unformat(TransferSession& session) to parse the incoming buffer, deserialize RNN
state and restore it into the session (e.g.,
session.setRnnState()/session.applyRnnState()); if you cannot implement
serialization yet, remove the TLLM_THROW and instead perform a guarded
early-return that logs a warning and clears any RNN flags on TransferSession so
callers on sendAsync()/receiveAsync() paths do not hard-fail.
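As referenced in the first item above, here is a sketch of the fail-fast bounds check. It assumes the variable names from the prompt (selfRank, selfTPNum, selfPPNum, selfNumLayerPerPP) and TLLM_CHECK_WITH_INFO from tensorrt_llm/common/assert.h; the actual rnnCacheFormatter.cpp code computes these values inline rather than through a helper like this.

#include <cstddef>
#include <vector>

#include "tensorrt_llm/common/assert.h" // provides TLLM_CHECK_WITH_INFO

// Illustrative helper: validate the rank decomposition before indexing the
// per-pipeline-stage layer counts. Names follow the review prompt, not the
// actual function layout in rnnCacheFormatter.cpp.
inline int numLayersForSelfRank(
    int selfRank, int selfTPNum, int selfPPNum, std::vector<int> const& selfNumLayerPerPP)
{
    TLLM_CHECK_WITH_INFO(selfTPNum > 0 && selfPPNum > 0, "TP and PP sizes must be positive");
    TLLM_CHECK_WITH_INFO(selfRank >= 0 && selfRank < selfTPNum * selfPPNum,
        "selfRank %d out of range for TP*PP=%d", selfRank, selfTPNum * selfPPNum);
    int const selfPPRank = selfRank / selfTPNum;
    TLLM_CHECK_WITH_INFO(static_cast<std::size_t>(selfPPRank) < selfNumLayerPerPP.size(),
        "selfPPRank %d exceeds selfNumLayerPerPP size %zu", selfPPRank, selfNumLayerPerPP.size());
    return selfNumLayerPerPP[selfPPRank]; // safe to index after the checks above
}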
In `@cpp/tensorrt_llm/batch_manager/rnnCacheTransceiver.cpp`:
- Around line 138-142: The code calls mCacheTransceiverConfig.value() before
verifying the optional is set, which can throw; fix by guarding the optional
first (e.g., check mCacheTransceiverConfig.has_value() before calling value() or
use mCacheTransceiverConfig->getBackendType()) and update the
TLLM_CHECK_WITH_INFO to validate that mCacheTransceiverConfig has a value and
that its getBackendType() is present and not DEFAULT; reference
mCacheTransceiverConfig, CacheTransceiverConfig::BackendType, getBackendType,
and TLLM_CHECK_WITH_INFO to locate where to change the check.
- Around line 61-95: The RNN-specific request/response hooks are unimplemented:
implement RnnCacheSender::recvRequestInfo to deserialize a RequestInfo into the
sender's state and return a valid RequestInfo (mirroring other cache types), and
implement RnnCacheReceiver::sendRequestInfo to build and return a
TransferSession populated with RNN-specific request metadata (including
RnnCacheState and CommState via mManager->getCommState()) so
BaseCacheSenderImpl::response() and BaseCacheReceiverImpl::requestSync() can
operate; if you cannot implement transfer logic right now, explicitly gate the
feature by having these functions return a clear no-op/default (and not throw)
or return an error TransferSession/RequestInfo that upstream handles,
referencing the methods RnnCacheSender::recvRequestInfo and
RnnCacheReceiver::sendRequestInfo and the member mSelfState and mManager when
making changes.
- Around line 152-156: The call to dlerror() in the TLLM_CHECK_WITH_INFO around
mWrapperLibHandle (after dllOpen) is not portable and breaks Windows builds;
change the error text construction to use platform-specific error retrieval: on
POSIX keep dlerror(), on Windows use GetLastError() + FormatMessage to produce a
printable message, then pass that message into TLLM_CHECK_WITH_INFO. Update the
block around mWrapperLibHandle (and the same pattern wherever dlerror() is used)
so the error string is created conditionally (e.g., using `#ifdef` _WIN32) before
calling TLLM_CHECK_WITH_INFO.
In `@cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp`:
- Around line 225-288: The shared containers mFreeBlocks, mCacheIndex, and
mTempBlocks are accessed concurrently; add a mutex member (e.g., mMutex) and use
a lock (std::lock_guard or std::scoped_lock) at the start of
allocateCacheBlocks, freeCacheBlock, getCacheIndex, and getStateIndices to
protect all reads/writes to those members (follow the same locking pattern used
in kvCacheManager/peftCacheManager); ensure locks cover early-return checks and
modifications (e.g., the TLLM_CHECK_WITH_INFO checks and
pop_back/push_back/erase operations) so no unsynchronized access occurs while
the GIL is released.
In `@cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp`:
- Around line 398-432: The exposed RnnStateManager methods allocateCacheBlocks,
freeCacheBlock, getStateIndices and the reader getCacheIndex access and mutate
shared members (mCacheIndex, mFreeBlocks, mTempBlocks) without synchronization;
remove nb::call_guard<nb::gil_scoped_release>() from those .def(...) bindings so
the GIL is held during calls, or alternatively add a std::mutex member to
tb::rnn_state_manager::RnnStateManager and lock it (e.g., std::scoped_lock)
inside the implementations of allocateCacheBlocks, freeCacheBlock,
getStateIndices and getCacheIndex to make container access thread-safe, then
keep or remove the gil_scoped_release accordingly.
In `@tensorrt_llm/_torch/attention_backend/interface.py`:
- Around line 295-312: The _prepare_mamba_metadata method can leave an existing
mamba_metadata instance sized for a previous max_num_requests, causing
buffer/share-size mismatches; update it to detect when an existing
self.mamba_metadata does not match current sizing (e.g., compare its configured
max_num_requests or mamba_chunk_size) and reinitialize it by creating a fresh
Mamba2Metadata(self.max_num_requests, self.mamba_chunk_size) (or set
self.mamba_metadata = False and recreate) before calling prepare; ensure this
logic lives in _prepare_mamba_metadata and references self.mamba_metadata,
self.max_num_requests, self.mamba_chunk_size and Mamba2Metadata so the metadata
is recreated whenever sizing changes.
In `@tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py`:
- Around line 94-111: The add_dummy_requests method currently has an unused
kwargs parameter causing lint ARG002; update its signature to rename kwargs to
_kwargs (def add_dummy_requests(self, request_ids: List[int], **_kwargs):) or
explicitly del kwargs at the start of the method, leaving the body calling
self.mamba_impl.allocate_cache_blocks(request_ids) unchanged; this silences the
linter while preserving behavior in add_dummy_requests and keeps references to
free_resources, get_cache_index, get_state_indices, get_conv_states, and
get_ssm_states intact.
🧹 Nitpick comments (12)
tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)
21-23: Keep CUDA graph runner namespace in imports.
Prefer module-qualified access rather than pulling the constant directly into the local namespace. As per coding guidelines, keep module namespaces for imports.
♻️ Proposed refactor
-from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
-    CUDA_GRAPH_DUMMY_REQUEST_ID
+import tensorrt_llm._torch.pyexecutor.cuda_graph_runner as cuda_graph_runner
@@
-            req_id == CUDA_GRAPH_DUMMY_REQUEST_ID
+            req_id == cuda_graph_runner.CUDA_GRAPH_DUMMY_REQUEST_ID
cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h (1)
280-285: Align Doxygen style and parameter naming.
Use //! comments in headers and avoid the m-prefix for non-member parameters. Also update the definition in cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp to match. As per coding guidelines, please align comment style and naming.
♻️ Proposed refactor
-/// @brief Gather request IDs across all ranks in the communicator.
-/// @param comm The communicator to use for gathering.
-/// @param requestIds The local request IDs to gather.
-/// @return All request IDs from all ranks.
+//! @brief Gather request IDs across all ranks in the communicator.
+//! @param comm The communicator to use for gathering.
+//! @param requestIds The local request IDs to gather.
+//! @return All request IDs from all ranks.
 std::vector<LlmRequest::RequestIdType> gatherRequestIds(
-    std::shared_ptr<CacheTransceiverComm> const& mComm, std::vector<LlmRequest::RequestIdType> const& requestIds);
+    std::shared_ptr<CacheTransceiverComm> const& comm, std::vector<LlmRequest::RequestIdType> const& requestIds);
tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py (1)
20-26: Preserve namespace for torch_dtype_to_binding import.
Keep module namespace to align with the import guidelines. As per coding guidelines, please keep module namespaces for imports.
♻️ Proposed refactor
-import tensorrt_llm.bindings
+import tensorrt_llm.bindings
+import tensorrt_llm._utils as tllm_utils
@@
-from tensorrt_llm._utils import torch_dtype_to_binding
@@
-        dtype_binding = torch_dtype_to_binding(dtype)
-        ssm_cache_dtype_binding = torch_dtype_to_binding(
+        dtype_binding = tllm_utils.torch_dtype_to_binding(dtype)
+        ssm_cache_dtype_binding = tllm_utils.torch_dtype_to_binding(
             ssm_cache_dtype if ssm_cache_dtype is not None else dtype)
tensorrt_llm/_torch/attention_backend/interface.py (1)
24-25: Keep mamba_cache_manager namespace in imports.
Prefer module-qualified access instead of importing the class directly. As per coding guidelines, please keep module namespaces for imports.
♻️ Proposed refactor
-from ..pyexecutor.mamba_cache_manager import MambaCacheManager
+from ..pyexecutor import mamba_cache_manager
@@
-        if (self.kv_cache_manager is not None
-                and isinstance(self.kv_cache_manager, MambaCacheManager)):
+        if (self.kv_cache_manager is not None
+                and isinstance(self.kv_cache_manager,
+                               mamba_cache_manager.MambaCacheManager)):
tensorrt_llm/_torch/models/modeling_nemotron_h.py (1)
416-416: Consider removing unused attribute.
The self.mamba_metadata attribute is declared but no longer used in the forward() method since mamba_metadata is now derived from attn_metadata.mamba_metadata. This attribute appears to be dead code after this refactoring.
tensorrt_llm/_torch/models/modeling_qwen3_next.py (1)
1200-1200: Consider removing unused attribute.
Similar to modeling_nemotron_h.py, the self.mamba_metadata attribute is declared but no longer used in forward() after this refactoring.
cpp/include/tensorrt_llm/executor/serialization.h (1)
114-117: Add Doxygen comments for the new RnnCacheState serialization APIs.
These are new public header declarations; please add //! docs for deserialize/serialize/serializedSize to keep the header compliant. As per coding guidelines, please document new public APIs.
cpp/tests/unit_tests/batch_manager/rnnCacheFormatterTest.cpp (1)
19-35: Make the fixed model-config values const/constexpr.
These locals never change; marking them immutable reduces accidental edits and aligns with style. As per coding guidelines, please prefer `const` for immutable locals.
♻️ Proposed tweak
-    SizeType32 dState = 16;
-    SizeType32 dConv = 4;
-    SizeType32 hiddenSize = 256;
-    SizeType32 headDim = 64;
-    SizeType32 convDimSize = 128;
-    SizeType32 nGroups = 1;
-    SizeType32 numHeads = 4;
-    auto convDtype = nvinfer1::DataType::kFLOAT;
-    auto ssmDtype = nvinfer1::DataType::kFLOAT;
+    SizeType32 const dState = 16;
+    SizeType32 const dConv = 4;
+    SizeType32 const hiddenSize = 256;
+    SizeType32 const headDim = 64;
+    SizeType32 const convDimSize = 128;
+    SizeType32 const nGroups = 1;
+    SizeType32 const numHeads = 4;
+    auto const convDtype = nvinfer1::DataType::kFLOAT;
+    auto const ssmDtype = nvinfer1::DataType::kFLOAT;
419-584: Add Doxygen//!docs for the new RnnCacheState and accessors.
These are new public header APIs; please document them with//!(and//!<for members where applicable).As per coding guidelines, new public interfaces should be Doxygen-documented.
cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp (1)
86-133: Consider simplifying getPpLayers - the ppSize==1 branch duplicates logic.
The two branches (ppSize==1 and ppSize>1) have similar logic that could potentially be unified. When ppSize==1, the formula startLayer = ppRank * layersPerRank + std::min(ppRank, remainder) would yield startLayer = 0 and endLayer = totalLayers, making the multi-rank branch work correctly for the single-rank case.
♻️ Suggested simplification
 std::vector<RnnStateManager::SizeType32> RnnStateManager::getPpLayers(
     SizeType32 numLayers, WorldConfig const& worldConfig, std::optional<std::vector<bool>> const& layerMask)
 {
     auto const ppSize = worldConfig.getPipelineParallelism();
     auto const ppRank = worldConfig.getPipelineParallelRank();
     std::vector<SizeType32> ppLayers;
-    if (ppSize == 1)
-    {
-        SizeType32 totalLayers = numLayers;
-        if (layerMask.has_value())
-        {
-            totalLayers = static_cast<SizeType32>(layerMask->size());
-        }
-        for (SizeType32 i = 0; i < totalLayers; ++i)
-        {
-            if (!layerMask.has_value() || (*layerMask)[i])
-            {
-                ppLayers.push_back(i);
-            }
-        }
-    }
-    else
-    {
-        SizeType32 totalLayers = numLayers;
-        if (layerMask.has_value())
-        {
-            totalLayers = static_cast<SizeType32>(layerMask->size());
-        }
+    SizeType32 totalLayers = numLayers;
+    if (layerMask.has_value())
+    {
+        totalLayers = static_cast<SizeType32>(layerMask->size());
+    }
-        SizeType32 layersPerRank = totalLayers / ppSize;
-        SizeType32 remainder = totalLayers % ppSize;
+    SizeType32 layersPerRank = totalLayers / ppSize;
+    SizeType32 remainder = totalLayers % ppSize;
-        SizeType32 startLayer = ppRank * layersPerRank + std::min(ppRank, remainder);
-        SizeType32 endLayer = startLayer + layersPerRank + (ppRank < remainder ? 1 : 0);
+    SizeType32 startLayer = ppRank * layersPerRank + std::min(ppRank, remainder);
+    SizeType32 endLayer = startLayer + layersPerRank + (ppRank < remainder ? 1 : 0);
-        for (SizeType32 i = startLayer; i < endLayer; ++i)
+    for (SizeType32 i = startLayer; i < endLayer; ++i)
+    {
+        if (!layerMask.has_value() || (*layerMask)[i])
         {
-            if (!layerMask.has_value() || (*layerMask)[i])
-            {
-                ppLayers.push_back(i);
-            }
+            ppLayers.push_back(i);
         }
     }
     return ppLayers;
 }
162-164: Avoid C-style cast for dlsym result.
Line 163 uses a C-style cast, which violates the C++ guidelines and is non-idiomatic for function pointers.
♻️ Suggested approach
-    std::unique_ptr<executor::kv_cache::ConnectionManager> (*makeUcxConnectionManager)();
-    *(void**) (&makeUcxConnectionManager) = load_sym(mWrapperLibHandle, "makeUcxConnectionManager");
+    using MakeUcxFn = std::unique_ptr<executor::kv_cache::ConnectionManager> (*)();
+    auto makeUcxConnectionManager =
+        reinterpret_cast<MakeUcxFn>(load_sym(mWrapperLibHandle, "makeUcxConnectionManager"));
235-239: Layer-wise RNN transfer currently throws.
Line 238 throws unconditionally. If callers can reach this override, it will abort the request flow. Consider implementing it or explicitly documenting/guarding this path.
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * ...
 */
Update the copyright year to 2026.
Line 2 still lists 2025. Please update to the year of latest meaningful modification.
As per coding guidelines, this should be 2026.
🤖 Prompt for AI Agents
In `@cpp/include/tensorrt_llm/batch_manager/rnnCacheTransceiver.h` around lines 1
- 6, Update the file header comment's copyright year from 2025 to 2026 in the
top-of-file comment block of tensorrt_llm/batch_manager/rnnCacheTransceiver.h;
locate the copyright line in the file's leading comment (the block beginning
"/*" at the top) and change "2025" to "2026" so the header reflects the latest
modification year.
#pragma once

#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/dataTransceiver.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/rnnCacheFormatter.h"
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
#include "tensorrt_llm/executor/cacheCommunicator.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/runtime/worldConfig.h"

#include <future>
#include <memory>
#include <mutex>
#include <optional>
#include <vector>

using SizeType32 = tensorrt_llm::runtime::SizeType32;

namespace tensorrt_llm::batch_manager
{

namespace rnn_state_manager
{
class RnnStateManager;
}

/// @brief RNN cache sender - inherits from BaseCacheSenderImpl with RNN-specific logic.
class RnnCacheSender : public BaseCacheSenderImpl
{
public:
    /// @brief Constructor.
    RnnCacheSender(executor::kv_cache::ConnectionManager* manager, executor::rnn_cache::RnnCacheState selfCacheState,
        SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);

    /// @brief Receive request information - RNN-specific implementation.
Use TRTLLM_<FILENAME>_H include guards and //! Doxygen comments.
Line 8 uses #pragma once, and the public API docs use ///. The guidelines require an explicit TRTLLM_<FILENAME>_H guard and //! / //!< Doxygen style.
As per coding guidelines, please switch to the mandated guard and comment format.
🧰 Tools
🪛 Clang (14.0.6)
[error] 10-10: 'tensorrt_llm/batch_manager/cacheTransceiver.h' file not found
(clang-diagnostic-error)
🤖 Prompt for AI Agents
In `@cpp/include/tensorrt_llm/batch_manager/rnnCacheTransceiver.h` around lines 8
- 44, Replace the non-compliant header protection and comment style: remove
"#pragma once" and add the explicit include guard "TRTLLM_RNNCACHETRANSCEIVER_H"
(i.e., `#ifndef` TRTLLM_RNNCACHETRANSCEIVER_H / `#define` ... / `#endif`) at the
top/bottom of the header, and convert C++03-style triple-slash Doxygen comments
(e.g., "/// `@brief`" and other "///" comments associated with RnnCacheSender, its
constructor and members) to the project standard "//!"/"//!<" style so
class-level docs and member descriptions (for RnnCacheSender, RnnCacheFormatter
references, etc.) use //! and //!< where appropriate. Ensure the guard macro
name matches the file identifier exactly (TRTLLM_RNNCACHETRANSCEIVER_H) and
update any file-local comment markers accordingly.
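To make the requested style concrete, below is an abbreviated sketch of how this header could look with the explicit guard and //! comments. The declarations are trimmed to the RnnCacheSender excerpt quoted above; the includes are elided and the mFormatter member shown is illustrative rather than taken from the actual file.

#ifndef TRTLLM_RNNCACHETRANSCEIVER_H
#define TRTLLM_RNNCACHETRANSCEIVER_H

// ... includes from the original header go here ...

namespace tensorrt_llm::batch_manager
{

//! \brief RNN cache sender - inherits from BaseCacheSenderImpl with RNN-specific logic.
class RnnCacheSender : public BaseCacheSenderImpl
{
public:
    //! \brief Constructor.
    RnnCacheSender(executor::kv_cache::ConnectionManager* manager, executor::rnn_cache::RnnCacheState selfCacheState,
        SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);

private:
    std::unique_ptr<BaseCacheFormatter> mFormatter; //!< Formatter used to pack RNN state (illustrative member).
};

} // namespace tensorrt_llm::batch_manager

#endif // TRTLLM_RNNCACHETRANSCEIVER_H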
private:
    friend class Serialization;
    std::optional<kv_cache::CacheState> mCacheState;
    std::optional<kv_cache::CommState> mCommState;
    std::optional<rnn_cache::RnnCacheState> mRnnCacheState; // is this needed?
};
Remove the inline “is this needed?” comment.
If this is a TODO, please track it explicitly; otherwise remove the ambiguity.
🧹 Suggested cleanup
- std::optional<rnn_cache::RnnCacheState> mRnnCacheState; // is this needed?
+ std::optional<rnn_cache::RnnCacheState> mRnnCacheState;
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
private:
    friend class Serialization;
    std::optional<kv_cache::CacheState> mCacheState;
    std::optional<kv_cache::CommState> mCommState;
    std::optional<rnn_cache::RnnCacheState> mRnnCacheState;
};
🤖 Prompt for AI Agents
In `@cpp/include/tensorrt_llm/executor/dataTransceiverState.h` around lines 610 -
615, Remove the inline comment " // is this needed?" next to the mRnnCacheState
declaration in the class (friend class Serialization;
std::optional<rnn_cache::RnnCacheState> mRnnCacheState), and if this represents
a tracked TODO open an issue or add a documented TODO elsewhere rather than
leaving an ambiguous inline note; simply delete the comment or replace it with a
proper tracked reference.
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
Update the copyright year to 2026.
Line 2 still lists 2025; this file is new in a 2026 PR.
As per coding guidelines, the header should reflect the latest meaningful modification year.
🤖 Prompt for AI Agents
In `@cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp` around lines 1 - 16,
Update the copyright year in the file header comment: locate the SPDX header
line that reads "Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES" (the SPDX
comment block near the top of baseTransBuffer.cpp) and change the year from 2025
to 2026 so the header reflects the current modification year; ensure the
SPDX-License-Identifier and surrounding comment formatting remain unchanged.
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
Update the copyright year to 2026.
Line 2 still lists 2025; please align with the 2026 modification.
As per coding guidelines, the header year should match the latest meaningful change.
🤖 Prompt for AI Agents
In `@cpp/tensorrt_llm/batch_manager/baseTransBuffer.h` around lines 1 - 16, The
file header's copyright year is outdated (shows 2025); update the year to 2026
in the top comment — specifically modify the SPDX-FileCopyrightText line in
baseTransBuffer.h to read 2026 so the header matches the latest changes.
    std::lock_guard<std::mutex> lock(mDllMutex);
    mWrapperLibHandle = dllOpen(UCX_WRAPPER_LIB_NAME);
    TLLM_CHECK_WITH_INFO(
        mWrapperLibHandle != nullptr, "UCX wrapper library is not open correctly. error : %s", dlerror());
    auto load_sym = [](void* handle, char const* name)
🧩 Analysis chain: verification scripts read cpp/tensorrt_llm/batch_manager/rnnCacheTransceiver.cpp (the first 80 lines and lines 130-180) to confirm the dlerror() usage.
Windows build breaks: dlerror() is undefined on Windows.
Line 155 calls dlerror() without platform guards. On Windows, dlerror() doesn't exist in the C standard library. Use GetLastError() / FormatMessage() for Windows and dlerror() for non-Windows platforms.
Example fix
- TLLM_CHECK_WITH_INFO(
- mWrapperLibHandle != nullptr, "UCX wrapper library is not open correctly. error : %s", dlerror());
+#if defined(_WIN32)
+ TLLM_CHECK_WITH_INFO(mWrapperLibHandle != nullptr,
+ "UCX wrapper library is not open correctly. error: %lu", GetLastError());
+#else
+ TLLM_CHECK_WITH_INFO(
+ mWrapperLibHandle != nullptr, "UCX wrapper library is not open correctly. error : %s", dlerror());
+#endif🤖 Prompt for AI Agents
In `@cpp/tensorrt_llm/batch_manager/rnnCacheTransceiver.cpp` around lines 152 -
156, The call to dlerror() in the TLLM_CHECK_WITH_INFO around mWrapperLibHandle
(after dllOpen) is not portable and breaks Windows builds; change the error text
construction to use platform-specific error retrieval: on POSIX keep dlerror(),
on Windows use GetLastError() + FormatMessage to produce a printable message,
then pass that message into TLLM_CHECK_WITH_INFO. Update the block around
mWrapperLibHandle (and the same pattern wherever dlerror() is used) so the error
string is created conditionally (e.g., using `#ifdef` _WIN32) before calling
TLLM_CHECK_WITH_INFO.
void RnnStateManager::allocateCacheBlocks(std::vector<RequestIdType> const& requestIds)
{
    for (auto const& requestId : requestIds)
    {
        auto it = mCacheIndex.find(requestId);
        if (it == mCacheIndex.end())
        {
            TLLM_CHECK_WITH_INFO(!mFreeBlocks.empty(), "Run out of RNN state cache blocks");
            SizeType32 const block = mFreeBlocks.back();
            mFreeBlocks.pop_back();
            mCacheIndex[requestId] = block;
        }
    }
}

void RnnStateManager::freeCacheBlock(RequestIdType requestId)
{
    auto it = mCacheIndex.find(requestId);
    if (it != mCacheIndex.end())
    {
        mFreeBlocks.push_back(it->second);
        mCacheIndex.erase(it);
    }
}

RnnStateManager::SizeType32 RnnStateManager::getCacheIndex(RequestIdType requestId) const
{
    auto it = mCacheIndex.find(requestId);
    TLLM_CHECK_WITH_INFO(it != mCacheIndex.end(), "Request ID not found in cache index");
    return it->second;
}

std::vector<RnnStateManager::SizeType32> RnnStateManager::getStateIndices(
    std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding)
{
    TLLM_CHECK_WITH_INFO(requestIds.size() == isPadding.size(), "requestIds and isPadding must have the same size");

    for (auto const& block : mTempBlocks)
    {
        mFreeBlocks.push_back(block);
    }
    mTempBlocks.clear();

    std::vector<SizeType32> result;
    result.reserve(requestIds.size());

    for (size_t i = 0; i < requestIds.size(); ++i)
    {
        if (isPadding[i])
        {
            TLLM_CHECK_WITH_INFO(!mFreeBlocks.empty(), "Run out of RNN state cache blocks for padding");
            SizeType32 const block = mFreeBlocks.back();
            mFreeBlocks.pop_back();
            mTempBlocks.push_back(block);
            result.push_back(block);
        }
        else
        {
            result.push_back(getCacheIndex(requestIds[i]));
        }
    }

    return result;
}
🧩 Analysis chain: verification scripts located the RnnStateManager declaration (cpp/include/tensorrt_llm/batch_manager/rnnStateManager.h) and implementation (cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp), searched them for mutex/lock usage, traced call sites of allocateCacheBlocks/freeCacheBlock/getStateIndices, checked where mRnnStateManager is held (trtGptModelInflightBatching), and inspected the pybind/nanobind bindings for gil_scoped_release guards; no internal synchronization or thread-safety documentation was found.
Add thread synchronization to protect concurrent access to cache block state.
The Python bindings explicitly release the GIL with py::call_guard<py::gil_scoped_release>(), allowing concurrent calls from multiple Python threads. However, allocateCacheBlocks, freeCacheBlock, and getStateIndices modify shared state (mFreeBlocks, mCacheIndex, mTempBlocks) without any synchronization. Standard library containers like std::vector and std::unordered_map are not thread-safe for concurrent access.
Add a mutex (e.g., mMutex) to protect these operations, following the pattern used in kvCacheManager and peftCacheManager. Alternatively, document that single-threaded access is guaranteed by design and remove the GIL release guards.
🤖 Prompt for AI Agents
In `@cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp` around lines 225 - 288,
The shared containers mFreeBlocks, mCacheIndex, and mTempBlocks are accessed
concurrently; add a mutex member (e.g., mMutex) and use a lock (std::lock_guard
or std::scoped_lock) at the start of allocateCacheBlocks, freeCacheBlock,
getCacheIndex, and getStateIndices to protect all reads/writes to those members
(follow the same locking pattern used in kvCacheManager/peftCacheManager);
ensure locks cover early-return checks and modifications (e.g., the
TLLM_CHECK_WITH_INFO checks and pop_back/push_back/erase operations) so no
unsynchronized access occurs while the GIL is released.
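To make the suggested fix concrete, here is a minimal, self-contained sketch of the locking pattern, assuming a new mMutex member. Container and method names mirror the review prompt; the real RnnStateManager carries more state and uses TLLM_CHECK_WITH_INFO instead of the plain early-outs shown here.

#include <cstdint>
#include <mutex>
#include <unordered_map>
#include <vector>

// Sketch: serialize access to the block-tracking containers with one mutex.
class RnnStateManagerSketch
{
public:
    using RequestIdType = std::uint64_t;
    using SizeType32 = std::int32_t;

    void allocateCacheBlocks(std::vector<RequestIdType> const& requestIds)
    {
        std::scoped_lock lock(mMutex); // guards mCacheIndex and mFreeBlocks together
        for (auto const requestId : requestIds)
        {
            if (mCacheIndex.find(requestId) == mCacheIndex.end() && !mFreeBlocks.empty())
            {
                mCacheIndex[requestId] = mFreeBlocks.back();
                mFreeBlocks.pop_back();
            }
        }
    }

    void freeCacheBlock(RequestIdType requestId)
    {
        std::scoped_lock lock(mMutex);
        auto it = mCacheIndex.find(requestId);
        if (it != mCacheIndex.end())
        {
            mFreeBlocks.push_back(it->second);
            mCacheIndex.erase(it);
        }
    }

private:
    mutable std::mutex mMutex; // protects the two containers below
    std::unordered_map<RequestIdType, SizeType32> mCacheIndex;
    std::vector<SizeType32> mFreeBlocks;
};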
nb::class_<tb::rnn_state_manager::RnnStateManager>(m, "RnnStateManager")
    .def(nb::init<tr::SizeType32, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
        nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"));
        nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"),
        nb::call_guard<nb::gil_scoped_release>())
    .def(nb::init<tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32, tr::SizeType32,
            tr::SizeType32, tr::WorldConfig const&, int64_t, nvinfer1::DataType, nvinfer1::DataType,
            std::optional<std::vector<bool>> const&>(),
        nb::arg("d_state"), nb::arg("d_conv"), nb::arg("num_heads"), nb::arg("n_groups"), nb::arg("head_dim"),
        nb::arg("num_layers"), nb::arg("max_batch_size"), nb::arg("world_config"), nb::arg("stream"),
        nb::arg("dtype"), nb::arg("ssm_cache_dtype"), nb::arg("layer_mask") = std::nullopt,
        nb::call_guard<nb::gil_scoped_release>())
    .def("get_cache_index", &tb::rnn_state_manager::RnnStateManager::getCacheIndex, nb::arg("request_id"),
        nb::call_guard<nb::gil_scoped_release>())
    .def(
        "get_conv_states",
        [](tb::rnn_state_manager::RnnStateManager& self, tr::SizeType32 layerIdx) -> at::Tensor
        {
            auto tensor = self.getConvStates(layerIdx);
            return tr::Torch::tensor(tensor);
        },
        nb::arg("layer_idx"), nb::call_guard<nb::gil_scoped_release>())
    .def(
        "get_ssm_states",
        [](tb::rnn_state_manager::RnnStateManager& self, tr::SizeType32 layerIdx) -> at::Tensor
        {
            auto tensor = self.getSsmStates(layerIdx);
            return tr::Torch::tensor(tensor);
        },
        nb::arg("layer_idx"), nb::call_guard<nb::gil_scoped_release>())
    .def("allocate_cache_blocks", &tb::rnn_state_manager::RnnStateManager::allocateCacheBlocks,
        nb::arg("request_ids"), nb::call_guard<nb::gil_scoped_release>())
    .def("free_cache_block", &tb::rnn_state_manager::RnnStateManager::freeCacheBlock, nb::arg("request_id"),
        nb::call_guard<nb::gil_scoped_release>())
    .def("get_state_indices", &tb::rnn_state_manager::RnnStateManager::getStateIndices, nb::arg("request_ids"),
        nb::arg("is_padding"), nb::call_guard<nb::gil_scoped_release>());
🧩 Analysis chain: verification scripts located and read cpp/include/tensorrt_llm/batch_manager/rnnStateManager.h and cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp to confirm the container members are accessed without locking.
Keep GIL or add synchronization for RnnStateManager cache operations.
allocateCacheBlocks, freeCacheBlock, and getStateIndices mutate shared containers (mCacheIndex, mFreeBlocks, mTempBlocks) without internal synchronization. Releasing the GIL with nb::call_guard<nb::gil_scoped_release>() allows concurrent Python threads to trigger data races on these containers. Similarly, getCacheIndex reads mCacheIndex which can be concurrently modified.
Either retain the GIL for these methods or add a mutex to RnnStateManager for thread-safe container access.
🤖 Prompt for AI Agents
In `@cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp` around lines 398 - 432,
The exposed RnnStateManager methods allocateCacheBlocks, freeCacheBlock,
getStateIndices and the reader getCacheIndex access and mutate shared members
(mCacheIndex, mFreeBlocks, mTempBlocks) without synchronization; remove
nb::call_guard<nb::gil_scoped_release>() from those .def(...) bindings so the
GIL is held during calls, or alternatively add a std::mutex member to
tb::rnn_state_manager::RnnStateManager and lock it (e.g., std::scoped_lock)
inside the implementations of allocateCacheBlocks, freeCacheBlock,
getStateIndices and getCacheIndex to make container access thread-safe, then
keep or remove the gil_scoped_release accordingly.
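For the first alternative (keep the GIL on the calls that touch shared containers), the change is simply dropping the call guards from the affected bindings. A trimmed sketch based on the hunk quoted above, not a complete compilable binding unit:

nb::class_<tb::rnn_state_manager::RnnStateManager>(m, "RnnStateManager")
    // ... constructors and the tensor getters stay as in the PR ...
    .def("get_cache_index", &tb::rnn_state_manager::RnnStateManager::getCacheIndex, nb::arg("request_id"))
    .def("allocate_cache_blocks", &tb::rnn_state_manager::RnnStateManager::allocateCacheBlocks,
        nb::arg("request_ids")) // GIL held: Python serializes access to mCacheIndex / mFreeBlocks
    .def("free_cache_block", &tb::rnn_state_manager::RnnStateManager::freeCacheBlock, nb::arg("request_id"))
    .def("get_state_indices", &tb::rnn_state_manager::RnnStateManager::getStateIndices, nb::arg("request_ids"),
        nb::arg("is_padding"));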
        self._prepare_mamba_metadata()

    def _prepare_mamba_metadata(self):
        if self.mamba_metadata is False:
            return

        if self.mamba_metadata is None:
            if (self.kv_cache_manager is not None
                    and isinstance(self.kv_cache_manager, MambaCacheManager)):
                from ..modules.mamba.mamba2_metadata import Mamba2Metadata
                self.mamba_metadata = Mamba2Metadata(self.max_num_requests,
                                                     self.mamba_chunk_size)
            else:
                self.mamba_metadata = False
                return

        self.mamba_metadata.prepare(self)
Reinitialize Mamba metadata when sizing changes.
create_cuda_graph_metadata() uses a shallow copy; if mamba_metadata was already initialized, the new metadata instance can share buffers sized for the old max_num_requests, risking out-of-bounds or stale state. Reinitialize when the requested max changes.
🛠️ Suggested fix
def _prepare_mamba_metadata(self):
if self.mamba_metadata is False:
return
- if self.mamba_metadata is None:
- if (self.kv_cache_manager is not None
- and isinstance(self.kv_cache_manager, MambaCacheManager)):
- from ..modules.mamba.mamba2_metadata import Mamba2Metadata
- self.mamba_metadata = Mamba2Metadata(self.max_num_requests,
- self.mamba_chunk_size)
- else:
- self.mamba_metadata = False
- return
+ if self.kv_cache_manager is None or not isinstance(
+ self.kv_cache_manager, MambaCacheManager):
+ self.mamba_metadata = False
+ return
+
+ from ..modules.mamba.mamba2_metadata import Mamba2Metadata
+ if (self.mamba_metadata is None
+ or self.mamba_metadata.max_batch_size != self.max_num_requests):
+ self.mamba_metadata = Mamba2Metadata(self.max_num_requests,
+ self.mamba_chunk_size)
self.mamba_metadata.prepare(self)🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/attention_backend/interface.py` around lines 295 - 312,
The _prepare_mamba_metadata method can leave an existing mamba_metadata instance
sized for a previous max_num_requests, causing buffer/share-size mismatches;
update it to detect when an existing self.mamba_metadata does not match current
sizing (e.g., compare its configured max_num_requests or mamba_chunk_size) and
reinitialize it by creating a fresh Mamba2Metadata(self.max_num_requests,
self.mamba_chunk_size) (or set self.mamba_metadata = False and recreate) before
calling prepare; ensure this logic lives in _prepare_mamba_metadata and
references self.mamba_metadata, self.max_num_requests, self.mamba_chunk_size and
Mamba2Metadata so the metadata is recreated whenever sizing changes.
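To make the shallow-copy hazard concrete, here is a toy reproduction. The classes are hypothetical stand-ins, not the real AttentionMetadata or Mamba2Metadata; they only show why a shallow copy keeps pointing at metadata sized for the original request count.

import copy


class ToyMambaMetadata:
    def __init__(self, max_batch_size: int):
        self.max_batch_size = max_batch_size
        self.index_buffer = [0] * max_batch_size  # sized once, at construction


class ToyAttentionMetadata:
    def __init__(self, max_num_requests: int):
        self.max_num_requests = max_num_requests
        self.mamba_metadata = ToyMambaMetadata(max_num_requests)

    def create_cuda_graph_metadata(self, max_batch_size: int):
        graph_metadata = copy.copy(self)  # shallow copy: attributes are shared
        graph_metadata.max_num_requests = max_batch_size
        return graph_metadata


base = ToyAttentionMetadata(max_num_requests=64)
graph = base.create_cuda_graph_metadata(max_batch_size=8)

# The copy still shares the metadata object sized for 64 requests.
assert graph.mamba_metadata is base.mamba_metadata
assert graph.mamba_metadata.max_batch_size == 64

Re-instantiating Mamba2Metadata whenever its max_batch_size no longer matches max_num_requests, as in the suggested fix above, breaks this sharing.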
    def free_resources(self, request: LlmRequest):
        request_id = request.py_request_id
        if request_id in self.mamba_cache_index:
            block = self.mamba_cache_index.pop(request_id)
            self.mamba_cache_free_blocks.append(block)
        self.mamba_impl.free_cache_block(request.py_request_id)

    def add_dummy_requests(self, request_ids: List[int], **kwargs):
        self.mamba_impl.allocate_cache_blocks(request_ids)

    def get_state_indices(self) -> torch.Tensor:
        return self.state_indices
    def get_cache_index(self, request_id: int) -> int:
        return self.mamba_impl.get_cache_index(request_id)

    def get_state_indices(self, request_ids: List[int],
                          is_padding: List[bool]) -> List[int]:
        return self.mamba_impl.get_state_indices(request_ids, is_padding)

    def get_conv_states(self, layer_idx: int) -> torch.Tensor:
        layer_offset = self.mamba_layer_offsets[layer_idx]
        return self.conv_states[layer_offset]
        return self.mamba_impl.get_conv_states(layer_idx)

    def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
        layer_offset = self.mamba_layer_offsets[layer_idx]
        return self.ssm_states[layer_offset]
        return self.mamba_impl.get_ssm_states(layer_idx)
Silence unused kwargs in add_dummy_requests.
Ruff flags the unused kwargs (ARG002). If the signature is intentional, rename to _kwargs (or del kwargs) to keep lint clean.
🧹 Proposed fix
-    def add_dummy_requests(self, request_ids: List[int], **kwargs):
+    def add_dummy_requests(self, request_ids: List[int], **_kwargs):
         self.mamba_impl.allocate_cache_blocks(request_ids)

🧰 Tools
🪛 Ruff (0.14.13)
97-97: Unused method argument: kwargs
(ARG002)
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py` around lines 94 - 111,
The add_dummy_requests method currently has an unused kwargs parameter causing
lint ARG002; update its signature to rename kwargs to _kwargs (def
add_dummy_requests(self, request_ids: List[int], **_kwargs):) or explicitly del
kwargs at the start of the method, leaving the body calling
self.mamba_impl.allocate_cache_blocks(request_ids) unchanged; this silences the
linter while preserving behavior in add_dummy_requests and keeps references to
free_resources, get_cache_index, get_state_indices, get_conv_states, and
get_ssm_states intact.
PR_Github #33217 [ run ] completed with state
Will break this up into smaller PRs after cherry-picking the refactor-only commits.
First PR: #10957
Summary by CodeRabbit
New Features
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

kill : Kill all running builds associated with pull request.

skip

skip --comment COMMENT : Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline : Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.