Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 124 additions & 14 deletions cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -18,6 +18,7 @@
#include "connection.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
#include <limits>
#include <random>
#include <string>
#include <unistd.h>
Expand Down Expand Up @@ -54,11 +55,8 @@ std::string genUniqueAgentName()
// layer num, since the buffer size is ratio is equal to the layer num ratio
// except the VSWA case.

template <typename CacheStateT>
auto computeSendOffsetRatio(
CacheStateT const& peerCacheState, int peerIdx, CacheStateT const& selfCacheState, int connectionIdx)
auto computeSendOffsetRatio(TargetRanksInfo const& peerTargetInfo, int connectionIdx)
{
auto peerTargetInfo = targetIRanks(selfCacheState, peerCacheState, peerIdx);
size_t offsetLayer = 0;
for (int i = 0; i < connectionIdx; i++)
{
Expand All @@ -69,6 +67,46 @@ auto computeSendOffsetRatio(
return std::make_pair(offsetLayer, selfSendLayer);
}

namespace
{

bool isValidBufferKind(uint8_t kind)
{
switch (static_cast<batch_manager::BufferKind>(kind))
{
case batch_manager::BufferKind::kKV:
case batch_manager::BufferKind::kKV_INDEXER:
case batch_manager::BufferKind::kRNN: return true;
}
return false;
}

AgentState const* findAgentState(CommState const& commState, std::string const& agentName)
{
if (!commState.isAgentState())
{
return nullptr;
}
for (auto const& agentState : commState.getAgentState())
{
if (agentState.mAgentName == agentName)
{
return &agentState;
}
}
return nullptr;
}

void validateConnectionIdx(
int connectionIdx, TargetRanksInfo const& targetInfo, char const* bufferKind, std::string const& remoteAgentName)
{
TLLM_CHECK_WITH_INFO(static_cast<size_t>(connectionIdx) < targetInfo.mIRanks.size(),
"AgentConnectionManager received %s connection index %d outside target rank count %zu from agent '%s'",
bufferKind, connectionIdx, targetInfo.mIRanks.size(), remoteAgentName.c_str());
}

} // namespace

AgentConnection::AgentConnection(
std::string mAgentName, std::string mRemoteAgentName, AgentConnectionManager* mAgentConnectionManager)
: mAgentName(mAgentName)
Expand Down Expand Up @@ -132,7 +170,15 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
MemoryDescs srcDescs{MemoryType::kVRAM, {srcDesc}};
auto const& dstBaseDesc = mSenderState.activeBufferDesc();
auto const& offsetRatio = mSenderState.activeOffsetRatio();
auto offset = size / offsetRatio.second * offsetRatio.first;
TLLM_CHECK_WITH_INFO(offsetRatio.second != 0, "AgentConnection::send offset ratio denominator cannot be 0");
TLLM_CHECK_WITH_INFO(size <= dstBaseDesc.getLen(), "AgentConnection::send size exceeds destination buffer");
auto const chunkSize = size / offsetRatio.second;
TLLM_CHECK_WITH_INFO(offsetRatio.first == 0 || chunkSize <= std::numeric_limits<size_t>::max() / offsetRatio.first,
"AgentConnection::send offset calculation overflow");
auto const offset = chunkSize * offsetRatio.first;
TLLM_CHECK_WITH_INFO(offset <= dstBaseDesc.getLen() - size, "AgentConnection::send destination out of bounds");
TLLM_CHECK_WITH_INFO(dstBaseDesc.getAddr() <= std::numeric_limits<uintptr_t>::max() - offset,
"AgentConnection::send destination address overflow");
MemoryDesc dstDesc{dstBaseDesc.getAddr() + offset, size, dstBaseDesc.getDeviceId()};
TLLM_LOG_DEBUG(
"send dstDesc: %p, size: %ld ,validSegmentIdx: %ld", dstDesc.getAddr(), size, mSenderState.validSegmentIdx);
Expand Down Expand Up @@ -386,13 +432,75 @@ AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(

erase = true;
requestInfo = requestAndBufferInfo.mRequestInfo;
auto address = requestAndBufferInfo.mAddress;
auto bufferDescs = std::move(requestAndBufferInfo.mBufferDescs);
auto const& address = requestAndBufferInfo.mAddress;
auto const& bufferDescsRef = requestAndBufferInfo.mBufferDescs;
auto metadataOpt = requestAndBufferInfo.mMetadata;
auto connectionIdx = requestAndBufferInfo.mValidConnectionIdx;
auto remoteAgentName = requestAndBufferInfo.mAgentName;
auto const connectionIdx = requestAndBufferInfo.mValidConnectionIdx;
auto const& remoteAgentName = requestAndBufferInfo.mAgentName;
auto const& bufferKindsRef = requestAndBufferInfo.mBufferKinds;

TLLM_CHECK_WITH_INFO(agent == remoteAgentName,
"AgentConnectionManager received RequestAndBufferInfo from '%s' with embedded agent '%s'",
agent.c_str(), remoteAgentName.c_str());
auto const* expectedAgentState = findAgentState(mCommState, remoteAgentName);
if (expectedAgentState != nullptr)
{
TLLM_CHECK_WITH_INFO(address == expectedAgentState->mConnectionInfo,
"AgentConnectionManager received mismatched connection info for agent '%s'",
remoteAgentName.c_str());
}
else
{
std::scoped_lock remoteInfoLock(mRemoteConnectionInfoMutex);
auto [remoteInfoIt, inserted] = mRemoteConnectionInfo.emplace(remoteAgentName, address);
TLLM_CHECK_WITH_INFO(inserted || remoteInfoIt->second == address,
"AgentConnectionManager received mismatched connection info for dynamic agent '%s'",
remoteAgentName.c_str());
}
TLLM_CHECK_WITH_INFO(connectionIdx >= 0,
"AgentConnectionManager received negative connection index for agent '%s'",
remoteAgentName.c_str());
TLLM_CHECK_WITH_INFO(!bufferDescsRef.empty(),
"AgentConnectionManager received empty destination descriptors from agent '%s'",
remoteAgentName.c_str());
TLLM_CHECK_WITH_INFO(bufferDescsRef.size() == bufferKindsRef.size(),
"AgentConnectionManager received %zu descriptors but %zu buffer kinds from agent '%s'",
bufferDescsRef.size(), bufferKindsRef.size(), remoteAgentName.c_str());
TLLM_CHECK_WITH_INFO(bufferDescsRef.size() <= mCacheTransBufferManagers.size(),
"AgentConnectionManager received too many destination descriptors from agent '%s'",
remoteAgentName.c_str());
for (auto const kind : bufferKindsRef)
{
TLLM_CHECK_WITH_INFO(isValidBufferKind(kind),
"AgentConnectionManager received invalid buffer kind %u from agent '%s'",
static_cast<unsigned>(kind), remoteAgentName.c_str());
bool isConfiguredKind = false;
for (auto const configuredKind : mBufferKinds)
{
if (configuredKind == kind)
{
isConfiguredKind = true;
break;
}
}
TLLM_CHECK_WITH_INFO(isConfiguredKind,
"AgentConnectionManager received unconfigured buffer kind %u from agent '%s'",
static_cast<unsigned>(kind), remoteAgentName.c_str());
}
TLLM_CHECK_WITH_INFO(requestInfo.getTransState().getCacheState().has_value(),
"AgentConnectionManager received request without cache state from agent '%s'",
remoteAgentName.c_str());
TLLM_CHECK_WITH_INFO(requestInfo.getTransState().getCommState().has_value(),
"AgentConnectionManager received request without comm state from agent '%s'",
remoteAgentName.c_str());

TLLM_LOG_DEBUG(" recv Address:%s", address.c_str());
auto connection = connect(remoteAgentName, address, metadataOpt, true);
TLLM_CHECK_WITH_INFO(
m_Agent->checkRemoteDescs(remoteAgentName, MemoryDescs{MemoryType::kVRAM, bufferDescsRef}),
"AgentConnectionManager received unregistered destination descriptors from agent '%s'",
remoteAgentName.c_str());
auto bufferDescs = std::move(requestAndBufferInfo.mBufferDescs);
auto bufferKinds = std::move(requestAndBufferInfo.mBufferKinds);

std::optional<std::pair<size_t, size_t>> kvOffsetRatio;
Expand All @@ -410,10 +518,11 @@ AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(
{
if (!kvOffsetRatio)
{
kvOffsetRatio
= computeSendOffsetRatio(requestInfo.getTransState().getCacheState().value(),
requestInfo.getTransState().getCommState()->getSelfIdx(), mCacheState,
connectionIdx);
auto kvTargetInfo
= targetIRanks(mCacheState, requestInfo.getTransState().getCacheState().value(),
requestInfo.getTransState().getCommState()->getSelfIdx());
validateConnectionIdx(connectionIdx, kvTargetInfo, "KV", remoteAgentName);
kvOffsetRatio = computeSendOffsetRatio(kvTargetInfo, connectionIdx);
}
offsetRatios.push_back(*kvOffsetRatio);
break;
Expand All @@ -425,6 +534,7 @@ AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(
auto rnnTargetInfo = targetIRanksForRnn(mCacheState,
requestInfo.getTransState().getCacheState().value(),
requestInfo.getTransState().getCommState()->getSelfIdx());
validateConnectionIdx(connectionIdx, rnnTargetInfo, "RNN", remoteAgentName);
size_t rnnOffsetLayer = 0;
for (int ri = 0; ri < connectionIdx; ri++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ class AgentConnectionManager : public ConnectionManager
private:
std::map<std::string, std::shared_ptr<AgentConnection>> mConnections;
std::mutex mConnectionsMutex;
/// Connection info for dynamically discovered agents that are not listed in mCommState.
std::map<std::string, std::string> mRemoteConnectionInfo;
std::mutex mRemoteConnectionInfoMutex;
CommState mCommState;
CacheState mCacheState;
std::optional<CacheState::RnnCacheState> mRnnCacheState;
Expand Down
29 changes: 18 additions & 11 deletions cpp/tensorrt_llm/nanobind/thop/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ nb::tuple trtllmGenContextPreprocessBinding(torch::Tensor qkv_input, torch::Tens
}();

return nb::make_tuple(std::get<0>(result), optionalTensorToObject(std::get<1>(result)),
optionalTensorToObject(std::get<2>(result)), std::get<3>(result), std::get<4>(result), std::get<5>(result),
std::get<6>(result), std::get<7>(result), std::get<8>(result));
optionalTensorToObject(std::get<2>(result)), optionalTensorToObject(std::get<3>(result)),
optionalTensorToObject(std::get<4>(result)), optionalTensorToObject(std::get<5>(result)), std::get<6>(result),
std::get<7>(result), std::get<8>(result), std::get<9>(result), std::get<10>(result), std::get<11>(result));
}

nb::tuple trtllmGenGenerationPreprocessBinding(torch::Tensor qkv_input, torch::Tensor workspace,
Expand Down Expand Up @@ -108,8 +109,9 @@ nb::tuple trtllmGenGenerationPreprocessBinding(torch::Tensor qkv_input, torch::T
}();

return nb::make_tuple(std::get<0>(result), optionalTensorToObject(std::get<1>(result)),
optionalTensorToObject(std::get<2>(result)), std::get<3>(result), optionalTensorToObject(std::get<4>(result)),
std::get<5>(result), std::get<6>(result), std::get<7>(result), std::get<8>(result));
optionalTensorToObject(std::get<2>(result)), optionalTensorToObject(std::get<3>(result)), std::get<4>(result),
std::get<5>(result), std::get<6>(result), optionalTensorToObject(std::get<7>(result)), std::get<8>(result),
std::get<9>(result), std::get<10>(result), std::get<11>(result));
}

} // namespace
Expand Down Expand Up @@ -292,13 +294,18 @@ void initBindings(nb::module_& m)
int64_t head_dim, int64_t kv_factor, int64_t total_num_blocks, int64_t kv_cache_quant_mode,
int64_t batch_start, int64_t batch_size, at::ScalarType dtype) -> nb::tuple
{
nb::gil_scoped_release release;
auto kvPool = torch_ext::buildFlashinferTrtllmGenPagedKvCacheBuffers(host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping, layer_idx, num_kv_heads, tokens_per_block, head_dim, kv_factor,
total_num_blocks, kv_cache_quant_mode, dtype);
auto const mapping = torch_ext::readKvCachePoolMapping(host_kv_cache_pool_mapping, layer_idx);
auto blockTables = kv_cache_block_offsets.select(0, mapping.poolIndex).narrow(0, batch_start, batch_size);
return nb::make_tuple(nb::cast(kvPool), nb::cast(blockTables));
at::Tensor kvPool;
std::optional<at::Tensor> kvScalePool;
at::Tensor blockTables;
{
nb::gil_scoped_release release;
std::tie(kvPool, kvScalePool) = torch_ext::buildFlashinferTrtllmGenPagedKvCacheBuffers(
host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, layer_idx, num_kv_heads, tokens_per_block,
head_dim, kv_factor, total_num_blocks, kv_cache_quant_mode, dtype);
auto const mapping = torch_ext::readKvCachePoolMapping(host_kv_cache_pool_mapping, layer_idx);
blockTables = kv_cache_block_offsets.select(0, mapping.poolIndex).narrow(0, batch_start, batch_size);
}
return nb::make_tuple(nb::cast(kvPool), nb::cast(blockTables), optionalTensorToObject(kvScalePool));
},
nb::arg("host_kv_cache_pool_pointers"), nb::arg("host_kv_cache_pool_mapping"),
nb::arg("kv_cache_block_offsets"), nb::arg("layer_idx"), nb::arg("num_kv_heads"), nb::arg("tokens_per_block"),
Expand Down
20 changes: 16 additions & 4 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1383,9 +1383,10 @@ common::op::KvCacheBuffers<kernels::KVBlockArray> buildPagedKvCacheBuffers(
quantMode.hasFp4KvCache());
}

at::Tensor buildFlashinferTrtllmGenPagedKvCacheBuffers(at::Tensor host_kv_cache_pool_pointers,
at::Tensor host_kv_cache_pool_mapping, int64_t layer_idx, int64_t num_kv_heads, int64_t tokens_per_block,
int64_t head_dim, int64_t kv_factor, int64_t total_num_blocks, int64_t kv_cache_quant_mode, at::ScalarType dtype)
std::tuple<at::Tensor, std::optional<at::Tensor>> buildFlashinferTrtllmGenPagedKvCacheBuffers(
at::Tensor host_kv_cache_pool_pointers, at::Tensor host_kv_cache_pool_mapping, int64_t layer_idx,
int64_t num_kv_heads, int64_t tokens_per_block, int64_t head_dim, int64_t kv_factor, int64_t total_num_blocks,
int64_t kv_cache_quant_mode, at::ScalarType dtype)
{
auto const mapping = readKvCachePoolMapping(host_kv_cache_pool_mapping, layer_idx);
int32_t const poolIndex = mapping.poolIndex;
Expand Down Expand Up @@ -1422,7 +1423,18 @@ at::Tensor buildFlashinferTrtllmGenPagedKvCacheBuffers(at::Tensor host_kv_cache_
auto kv_pool = torch::from_blob(
poolPointers.primaryPoolPtr, {total_num_blocks, num_kv_heads, tokens_per_block, containerDim}, options);

return kv_pool;
std::optional<at::Tensor> kvScalePool = std::nullopt;
if (isFp4 && poolPointers.primaryBlockScalePoolPtr != nullptr)
{
auto scaleOptions
= at::TensorOptions()
.dtype(at::kFloat8_e4m3fn)
.device(c10::Device(at::kCUDA, static_cast<c10::DeviceIndex>(at::cuda::current_device())));
kvScalePool = torch::from_blob(poolPointers.primaryBlockScalePoolPtr,
{total_num_blocks, num_kv_heads, tokens_per_block, head_dim / 16}, scaleOptions);
}

return {kv_pool, kvScalePool};
}

} // namespace torch_ext
Expand Down
8 changes: 5 additions & 3 deletions cpp/tensorrt_llm/thop/attentionOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <climits>
#include <optional>
#include <torch/extension.h>
#include <tuple>

#include "tensorrt_llm/common/attentionOp.h"
#include "tensorrt_llm/common/config.h"
Expand Down Expand Up @@ -118,9 +119,10 @@ common::op::KvCacheBuffers<kernels::KVBlockArray> buildPagedKvCacheBuffers(
int64_t cyclic_attention_window_size, int64_t max_attention_window_size, int64_t beam_width, int64_t seq_offset,
bool is_mla_enable, size_t elem_size);

at::Tensor buildFlashinferTrtllmGenPagedKvCacheBuffers(at::Tensor host_kv_cache_pool_pointers,
at::Tensor host_kv_cache_pool_mapping, int64_t layer_idx, int64_t num_kv_heads, int64_t tokens_per_block,
int64_t head_dim, int64_t kv_factor, int64_t total_num_blocks, int64_t kv_cache_quant_mode, at::ScalarType dtype);
std::tuple<at::Tensor, std::optional<at::Tensor>> buildFlashinferTrtllmGenPagedKvCacheBuffers(
at::Tensor host_kv_cache_pool_pointers, at::Tensor host_kv_cache_pool_mapping, int64_t layer_idx,
int64_t num_kv_heads, int64_t tokens_per_block, int64_t head_dim, int64_t kv_factor, int64_t total_num_blocks,
int64_t kv_cache_quant_mode, at::ScalarType dtype);

// Layout manager for the thop attention workspace slices used by trtllm-gen.
// Context follows AttentionOp::getWorkspaceSizeForContext() ordering. Generation
Expand Down
8 changes: 4 additions & 4 deletions cpp/tensorrt_llm/thop/trtllmGenFusedOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ TRTLLM_NAMESPACE_BEGIN
namespace torch_ext
{

std::tuple<at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Tensor,
int64_t, int64_t, int64_t>
std::tuple<at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>,
std::optional<at::Tensor>, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Tensor, int64_t, int64_t, int64_t>
trtllmGenContextPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, torch::Tensor sequence_lengths,
torch::Tensor context_lengths, std::optional<torch::Tensor> kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
Expand Down Expand Up @@ -57,8 +57,8 @@ void trtllmGenContextPostprocess(torch::Tensor qkv_input, torch::Tensor workspac
int64_t position_embedding_type, double bmm1_scale, bool fp8_context_fmha, bool paged_context_fmha,
bool is_mla_enable, int64_t attention_chunk_size, int64_t multi_processor_count);

std::tuple<at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, at::Tensor, std::optional<at::Tensor>,
int64_t, int64_t, int64_t, bool>
std::tuple<at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, at::Tensor,
at::Tensor, at::Tensor, std::optional<at::Tensor>, int64_t, int64_t, int64_t, bool>
trtllmGenGenerationPreprocess(torch::Tensor qkv_input, torch::Tensor workspace, torch::Tensor sequence_lengths,
std::optional<torch::Tensor> spec_decoding_generation_lengths,
std::optional<torch::Tensor> spec_decoding_position_offsets, std::optional<torch::Tensor> kv_cache_block_offsets,
Expand Down
Loading
Loading