17 changes: 16 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -132,6 +132,7 @@ void initBindings(nb::module_& m)
.def_rw("max_new_tokens", &GenLlmReq::mMaxNewTokens)
.def_rw("sampling_config", &GenLlmReq::mSamplingConfig)
.def_prop_rw("state", &GenLlmReq::getState, &GenLlmReq::setState)
.def_prop_ro("state_value", [](GenLlmReq const& self) { return static_cast<int>(self.getState()); })
.def_prop_rw("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming)
.def_rw("end_id", &GenLlmReq::mEndId)
.def_rw("pad_id", &GenLlmReq::mPadId)
@@ -175,6 +176,7 @@ void initBindings(nb::module_& m)
.def_prop_ro("is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
.def_prop_ro(
"is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
.def_prop_ro("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
.def_prop_ro("is_context_init_state", &GenLlmReq::isContextInitState)
.def_prop_ro("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
.def_prop_ro("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
@@ -253,7 +255,20 @@ void initBindings(nb::module_& m)
})
.def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
.def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
.def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
.def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
.def("get_unique_tokens", nb::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, nb::const_),
nb::arg("beam"))
.def("get_unique_tokens", nb::overload_cast<>(&GenLlmReq::getUniqueTokens, nb::const_))
.def("get_encoder_unique_tokens",
[](GenLlmReq& self)
{
auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
{
return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
}
return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
});

nb::class_<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", nb::dynamic_attr())
.def(
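Usage note (not part of the diff): the new LlmRequest bindings above are mirrored in the pybind11 layer below, so the same attributes become available on the Python side. A minimal sketch of how they might be used follows; the import path and the way the request object is obtained are assumptions for illustration, while the attribute and method names come from the bindings in this change.

# Hypothetical usage sketch; the module path is an assumption, not part of this PR.
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest

def inspect_request(req: LlmRequest) -> None:
    # state_value exposes the request state enum as a plain int, convenient for
    # logging or cheap comparisons without touching the enum type itself.
    print("state:", req.state_value, "encoder init:", req.is_encoder_init_state)

    # get_unique_tokens now has two const overloads: per-beam and all beams.
    beam0_tokens = req.get_unique_tokens(beam=0)
    all_beam_tokens = req.get_unique_tokens()

    # get_encoder_unique_tokens returns None when the request carries no encoder tokens.
    encoder_tokens = req.get_encoder_unique_tokens()
    if encoder_tokens is not None:
        print("encoder unique tokens:", len(encoder_tokens))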
11 changes: 10 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -481,6 +481,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, nb::call_guard<nb::gil_scoped_release>())
.def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
nb::call_guard<nb::gil_scoped_release>())
.def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, nb::arg("unique_tokens"),
nb::arg("llm_request"), nb::call_guard<nb::gil_scoped_release>())
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard<nb::gil_scoped_release>())
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
nb::call_guard<nb::gil_scoped_release>())
@@ -524,7 +526,14 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true,
nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr,
nb::arg("enable_indexer_k_cache") = false, nb::arg("indexer_k_cache_quant_block_size") = 128,
nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard<nb::gil_scoped_release>());
nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard<nb::gil_scoped_release>())
.def(
"scheduling_has_free_blocks",
[](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
{ return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
nb::arg("num_required"), nb::arg("window_size"), nb::call_guard<nb::gil_scoped_release>())
.def_prop_ro(
"is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
}

void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m)
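Usage note (not part of the diff): a sketch of how the new KVCacheManager entry points might be driven from Python scheduling code. The kv_cache_manager and request objects, and the treatment of find_new_context_block's return value, are assumptions; the method names, argument names, and the is_variable_window property come from the bindings above.

# Hypothetical scheduling helper; only the method and argument names are taken from the bindings.
def try_reserve_context_block(kv_cache_manager, request, blocks_needed: int, window_size: int):
    # is_variable_window reports whether the underlying BlockManager uses
    # variable-size attention windows.
    if kv_cache_manager.is_variable_window:
        print("KV cache uses variable-size attention windows")

    # scheduling_has_free_blocks asks the BlockManager whether enough free blocks
    # exist for this window size (the GIL is released while the C++ call runs).
    if not kv_cache_manager.scheduling_has_free_blocks(
            num_required=blocks_needed, window_size=window_size):
        return None

    # find_new_context_block searches for a reusable context block matching the
    # request's unique tokens; returning None on "not found" is an assumed convention.
    return kv_cache_manager.find_new_context_block(
        unique_tokens=request.get_unique_tokens(beam=0), llm_request=request)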
17 changes: 16 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -136,6 +136,7 @@ void initBindings(pybind11::module_& m)
.def_readwrite("max_new_tokens", &GenLlmReq::mMaxNewTokens)
.def_readwrite("sampling_config", &GenLlmReq::mSamplingConfig)
.def_property("state", &GenLlmReq::getState, &GenLlmReq::setState)
.def_property_readonly("state_value", [](GenLlmReq const& self) { return static_cast<int>(self.getState()); })
.def_property("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming)
.def_readwrite("end_id", &GenLlmReq::mEndId)
.def_readwrite("pad_id", &GenLlmReq::mPadId)
@@ -181,6 +182,7 @@ void initBindings(pybind11::module_& m)
"is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
.def_property_readonly(
"is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
.def_property_readonly("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
.def_property_readonly("is_context_init_state", &GenLlmReq::isContextInitState)
.def_property_readonly("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
.def_property_readonly("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
@@ -259,7 +261,20 @@ void initBindings(pybind11::module_& m)
})
.def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
.def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
.def("get_unique_tokens", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, py::const_),
py::arg("beam"))
.def("get_unique_tokens", py::overload_cast<>(&GenLlmReq::getUniqueTokens, py::const_))
.def("get_encoder_unique_tokens",
[](GenLlmReq& self)
{
auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
{
return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
}
return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
});

py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
.def(py::init<>(
11 changes: 10 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
@@ -485,6 +485,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, py::call_guard<py::gil_scoped_release>())
.def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
py::call_guard<py::gil_scoped_release>())
.def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, py::arg("unique_tokens"),
py::arg("llm_request"), py::call_guard<py::gil_scoped_release>())
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, py::call_guard<py::gil_scoped_release>())
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
py::call_guard<py::gil_scoped_release>())
@@ -519,7 +521,14 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true,
py::arg("kv_connector_manager") = nullptr, py::arg("enable_indexer_k_cache") = false,
py::arg("indexer_k_cache_quant_block_size") = 128, py::arg("indexer_k_cache_index_head_dim") = 0,
py::call_guard<py::gil_scoped_release>());
py::call_guard<py::gil_scoped_release>())
.def(
"scheduling_has_free_blocks",
[](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
{ return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
py::arg("num_required"), py::arg("window_size"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly(
"is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
}

void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)
59 changes: 36 additions & 23 deletions cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
@@ -33,6 +33,8 @@ namespace torch_ext
namespace moe_comm
{

static constexpr size_t CACHELINE_ALIGNMENT = 128;

// TODO: Is Alignment necessary?
// Helper function to align offset to specified byte boundary
inline size_t alignOffset(size_t offset, size_t alignment)
@@ -46,7 +48,6 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
// TODO: Use lambdas to encapsulate offset and alignment for each entry, which is less error prone and easier to
// read.
constexpr size_t SIZEOF_INT32 = 4;
constexpr size_t CACHELINE_ALIGNMENT = 128;

MoeA2ADataOffsets offsets;
size_t offset = 0;
@@ -203,29 +204,43 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
TORCH_CHECK(payload.is_contiguous(), "All payloads must be contiguous");
}

// Calculate buffer sizes for all payloads
// Each payload buffer needs space for data from ALL ranks: epSize * maxTokensPerRank * elementsPerToken
int64_t totalBytesNeeded = 0;
std::vector<int64_t> payloadByteSizes;
// Record the cacheline-aligned start offset for each payload's recv buffer.
// 1. We assume the base workspace ptr of each rank is aligned (checked in this op).
// 2. offsets[PAYLOAD_DATA_OFFSET_INDEX] is aligned (ensured in calculateOffsets).
// 3. We align currentOffset after each payload.
// Together these guarantee that every recv buffer is (over-)aligned, which is sufficient for 128-bit vectorized ld/st.

std::vector<int> payloadElementSizes;
std::vector<int> payloadElementsPerToken;
std::vector<size_t> payloadRecvBufferOffsets;

// Start offset for the first payload
size_t currentOffset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
for (auto const& payload : inputPayloads)
{
CHECK_CONTIGUOUS(payload);
CHECK_TH_CUDA(payload);
TORCH_CHECK(payload.dim() == 2, "payload must be a 2D tensor");
TORCH_CHECK(
payload.size(0) == localNumTokens, "payload must have the same first dimension as tokenSelectedExperts");
// Unlike the recv buffers for payloads, the payload itself is not allocated by us, so we cannot control its alignment.
// We only require that the payload start address is 16-byte aligned; the actual vectorized ld/st width is
// determined dynamically from the bytes per token of this payload.
TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");

int elementsPerToken = static_cast<int>(payload.size(1));
int elementSize = static_cast<int>(payload.dtype().itemsize());
// Each payload buffer stores data from ALL ranks
int64_t bytesPerPayload = epSize * runtimeMaxTokensPerRank * elementsPerToken * elementSize;

payloadByteSizes.push_back(bytesPerPayload);
payloadElementSizes.push_back(elementSize);
payloadElementsPerToken.push_back(elementsPerToken);
totalBytesNeeded += bytesPerPayload;

payloadRecvBufferOffsets.push_back(currentOffset);

// Update offset and align to cacheline boundary for the next payload recv buffer.
currentOffset += bytesPerPayload;
currentOffset = alignOffset(currentOffset, CACHELINE_ALIGNMENT);
}

CHECK_TH_CUDA(workspace);
@@ -236,16 +251,18 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c

// Validate workspace size - must include space for auxiliary data + payloads
int64_t sizePerRank = workspace.size(1);
int64_t requiredSize = offsets[PAYLOAD_DATA_OFFSET_INDEX] + totalBytesNeeded;
int64_t requiredSize = static_cast<int64_t>(currentOffset);
TORCH_CHECK(sizePerRank >= requiredSize,
"Workspace size per rank insufficient for dispatch. "
"Need at least ",
requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + ", totalBytesNeeded,
" for payloads), but got ", sizePerRank);
requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + payloads), but got ",
sizePerRank);

// Get base workspace pointer
uint8_t* workspacePtr = workspace.data_ptr<uint8_t>();
uint8_t* rankWorkSpacePtr = workspacePtr + epRank * workspace.stride(0);
TORCH_CHECK(reinterpret_cast<uintptr_t>(rankWorkSpacePtr) % CACHELINE_ALIGNMENT == 0,
"rankWorkSpacePtr must be ", CACHELINE_ALIGNMENT, "-byte aligned");

// Setup payload descriptors for source data
int num_payloads = static_cast<int>(inputPayloads.size());
@@ -288,13 +305,10 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
params.completion_flags[target_rank]
= reinterpret_cast<uint32_t*>(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]);

size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
// Store pointer for current payload
params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + offset;
// Update offset for next payload
offset += payloadByteSizes[payload_idx];
// Store pointer for current payload using pre-calculated aligned offset
params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + payloadRecvBufferOffsets[payload_idx];
}
}

@@ -310,22 +324,17 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c

// Create tensor views for the current rank's receive buffers only
std::vector<torch::Tensor> recvTensors;
size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
auto const& payload = inputPayloads[payload_idx];
// Create tensor view for this payload
auto recvTensor = torch::from_blob(rankWorkSpacePtr + offset,
// Create tensor view for this payload using pre-calculated aligned offset
auto recvTensor = torch::from_blob(rankWorkSpacePtr + payloadRecvBufferOffsets[payload_idx],
{epSize, runtimeMaxTokensPerRank, payloadElementsPerToken[payload_idx]}, payload.options());
recvTensors.push_back(recvTensor);

// Update offset for next payload
offset += payloadByteSizes[payload_idx];
}

// Compute aligned offset after dispatch payloads for combine payload region
constexpr size_t CACHELINE_ALIGNMENT = 128;
int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(static_cast<size_t>(offset), CACHELINE_ALIGNMENT));
int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(currentOffset, CACHELINE_ALIGNMENT));

return std::make_tuple(std::move(recvTensors), combinePayloadOffset);
}
@@ -356,6 +365,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
TORCH_CHECK(payload.size(0) == epSize, "payload first dimension must equal epSize");
TORCH_CHECK(
payload.size(1) == runtimeMaxTokensPerRank, "payload second dimension must equal runtimeMaxTokensPerRank");
// We only require that the payload start address is 16-byte aligned; the actual vectorized ld/st width is
// determined dynamically from the bytes per token of this payload.
TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
int64_t elementsPerToken = payload.size(2);
TORCH_CHECK(elementsPerToken > 0, "elementsPerToken must be positive");
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
@@ -411,6 +423,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
" for payload), but got ", sizePerRank);

// Create output tensor (local on current rank), no need for initialization
// Typically, newly allocated GPU torch tensors are at least 16-byte aligned.
torch::Tensor output = torch::empty({localNumTokens, elementsPerToken}, payload.options());

// Setup combine parameters
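Layout note (not part of the diff): moeA2ADispatchOp now pre-computes one cacheline-aligned recv-buffer offset per payload and reuses it for the kernel parameters, the returned tensor views, and the combine-payload offset, rather than re-accumulating raw byte sizes in three places. A minimal Python sketch of that layout rule, with made-up payload shapes, is:

# Mirror of the C++ alignOffset helper and the per-payload offset loop; the
# concrete sizes below are illustrative only.
CACHELINE_ALIGNMENT = 128

def align_offset(offset: int, alignment: int) -> int:
    # Round offset up to the next multiple of alignment.
    return (offset + alignment - 1) // alignment * alignment

def recv_buffer_offsets(payload_data_offset, ep_size, max_tokens_per_rank, payloads):
    """payloads: list of (elements_per_token, element_size) tuples."""
    offsets = []
    current = payload_data_offset  # assumed cacheline-aligned, as calculateOffsets ensures
    for elements_per_token, element_size in payloads:
        offsets.append(current)
        current += ep_size * max_tokens_per_rank * elements_per_token * element_size
        current = align_offset(current, CACHELINE_ALIGNMENT)  # keep the next payload aligned
    return offsets, current  # current doubles as the combine-payload start offset

# Example: fp16 hidden states (4096 elems), int32 expert ids (top_k=8), fp32 scales (8).
offsets, combine_offset = recv_buffer_offsets(
    512, ep_size=4, max_tokens_per_rank=256, payloads=[(4096, 2), (8, 4), (8, 4)])
print(offsets, combine_offset)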
4 changes: 2 additions & 2 deletions security_scanning/metadata.json
@@ -1,4 +1,4 @@
{
"commit_hash": "df845a028b84d2d1ff02ffcaf2631d0007f4f79a",
"timestamp": "2026-01-19T05:51:38Z"
"commit_hash": "dbb858ae0cadf616d58defc1e2c4b35c97a20b63",
"timestamp": "2026-01-20T02:46:13Z"
}
32 changes: 19 additions & 13 deletions tensorrt_llm/_torch/distributed/moe_alltoall.py
@@ -54,25 +54,31 @@ def calculate_required_workspace_size(
dtype: torch.dtype,
extra_payload_bytes_per_token: int = 0) -> int:
element_size = dtype.itemsize

# Auxiliary data size
aux_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)
workspace_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)

# Dispatch needs workspace for [ep_size, max_tokens] tokens,
# but due to the variety of quantization recipes, we cannot know the exact size,
# so we conservatively estimate assuming no quantization.
payload_size_dispatch = ep_size * max_num_tokens * (
hidden_size * element_size # (Unquantized) token hidden states
+ top_k * 4 # token_selected_experts
+ top_k * 4 # token_final_scales
+ extra_payload_bytes_per_token # extra payload bytes per token
)
# but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
# We also account for the per-payload alignment requirements applied in moeA2ADispatchOp and moeA2ACombineOp.
# (Unquantized) token hidden states
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
workspace_size = pad_up(workspace_size, 128)
# token_selected_experts
workspace_size += ep_size * max_num_tokens * top_k * 4
workspace_size = pad_up(workspace_size, 128)
# token_final_scales
workspace_size += ep_size * max_num_tokens * top_k * 4
workspace_size = pad_up(workspace_size, 128)
# extra payload bytes per token
workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
workspace_size = pad_up(workspace_size, 128)

# Required workspace for combine [ep_size, max_tokens] tokens
payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
workspace_size = pad_up(workspace_size, 128)

# Pad to 128 bytes to ensure alignment. This matches the implementation of C++ torch OP code.
return pad_up(aux_size, 128) + pad_up(
payload_size_dispatch, 128) + pad_up(payload_size_combine, 128)
return workspace_size

@classmethod
def _init_constants(cls):
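Worked example (not part of the diff): the revised calculate_required_workspace_size pads after every payload region rather than padding three lump sums. A stand-alone instance of that pattern, with made-up parameter values and a stand-in pad_up and aux size, is:

# Only the padding pattern mirrors the code above; the numbers are illustrative.
def pad_up(value: int, alignment: int) -> int:
    return (value + alignment - 1) // alignment * alignment

ep_size, max_num_tokens, hidden_size, top_k = 4, 256, 4096, 8
element_size = 2   # e.g. bfloat16 hidden states
aux_size = 4096    # stand-in for MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)

size = aux_size
size = pad_up(size + ep_size * max_num_tokens * hidden_size * element_size, 128)  # dispatch hidden states
size = pad_up(size + ep_size * max_num_tokens * top_k * 4, 128)                   # token_selected_experts
size = pad_up(size + ep_size * max_num_tokens * top_k * 4, 128)                   # token_final_scales
size = pad_up(size + ep_size * max_num_tokens * 0, 128)                           # extra payload bytes (none here)
size = pad_up(size + ep_size * max_num_tokens * hidden_size * element_size, 128)  # combine hidden states
print(size)  # per-rank workspace bytes under these assumptions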