Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ tensorrt_llm/pg_utils_bindings.*.so
tensorrt_llm/flash_mla/
tensorrt_llm/flash_mla_cpp_tllm.*.so
tensorrt_llm/flash_mla_cpp_tllm.pyi
tensorrt_llm/runtime/kv_cache_manager_v2/**/*.so
**/*__mypyc*.so
tensorrt_llm/scripts
*docs/cpp_docs*
*docs/source/_cpp_gen*
Expand Down
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/batch_manager/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ set(SRCS
kvCacheManager.cpp
kvCacheEventManager.cpp
kvCacheTransferManager.cpp
kvCacheManagerV2Utils.cpp
kvCacheManagerV2Utils.cu
llmRequest.cpp
logitsPostProcessor.cpp
loraBuffers.cpp
Expand Down
12 changes: 12 additions & 0 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2915,6 +2915,18 @@ void KVCacheManager::removeToken(RequestIdType requestId)

void KVCacheManager::rewindKVCache(RequestIdType requestId, SizeType32 rewindLengths)
{
// Check if the sequence still exists before rewinding
// In overlap mode with MTP, the request may have been terminated and removed
// from mSequences before rewindKVCache is called
{
std::scoped_lock lck(mSequencesMtx);
if (mSequences.find(requestId) == mSequences.end())
{
TLLM_LOG_DEBUG("Request %lu has already been removed from KV cache manager, skipping rewind", requestId);
return;
}
}

for (SizeType32 si = 0; si < rewindLengths; ++si)
{
removeToken(requestId);
Expand Down
163 changes: 163 additions & 0 deletions cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h"
#include "tensorrt_llm/common/logger.h"
#include <cassert>
#include <cstdio>
#include <cuda.h>
#include <fcntl.h>
#include <memory>
#include <unistd.h>
#include <vector>

namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{

// Repeatedly invokes func(finished) until `size` bytes have been transferred.
// `func` performs one pread/pwrite-style partial transfer starting at byte
// offset `finished` and returns the number of bytes moved, or -1 on error
// (with errno set). Returns true when all `size` bytes were transferred.
//
// Bug fix: a return value of 0 (EOF on read, or a zero-length write) does not
// set errno, so the old `bytes <= 0` path could observe a stale EINTR and spin
// forever. Zero is now treated as a distinct, non-retryable failure.
template <typename Func>
bool loopedReadWrite(Func&& func, ssize_t size) noexcept
{
    ssize_t count = 0;
    while (count < size)
    {
        ssize_t const bytes = func(count);
        if (bytes < 0)
        {
            if (errno == EINTR)
            {
                continue; // Retry on interrupt
            }
            TLLM_LOG_ERROR("Disk read/write failed: %s\n", strerror(errno));
            return false;
        }
        if (bytes == 0)
        {
            // errno is not meaningful here; do not retry, or we may loop forever.
            TLLM_LOG_ERROR("Disk read/write made no progress: unexpected EOF or zero-length transfer\n");
            return false;
        }
        count += bytes;
    }
    assert(count == size);
    return true;
}

// Writes exactly `size` bytes from `data` to `fd` at absolute offset `pos`,
// looping over partial writes. Returns false on unrecoverable I/O error.
bool writeAll(int fd, ssize_t pos, void const* data, ssize_t size) noexcept
{
    auto const* const base = static_cast<std::byte const*>(data);
    auto step = [fd, pos, base, size](ssize_t done) -> ssize_t
    { return pwrite(fd, base + done, size - done, pos + done); };
    return loopedReadWrite(step, size);
}

// Reads exactly `size` bytes from `fd` at absolute offset `pos` into `data`,
// looping over partial reads. Returns false on unrecoverable I/O error.
bool readAll(int fd, ssize_t pos, void* data, ssize_t size) noexcept
{
    auto* const base = static_cast<std::byte*>(data);
    auto step = [fd, pos, base, size](ssize_t done) -> ssize_t
    { return pread(fd, base + done, size - done, pos + done); };
    return loopedReadWrite(step, size);
}

// Per-launch payload handed to a cuLaunchHostFunc callback; the callback takes
// ownership and frees it.
// NOTE: member order is load-bearing — the copy* entry points below
// aggregate-initialize this struct as {tasks, numBytes}.
template <typename DstAddr, typename SrcAddr>
struct UserData
{
    // Copy tasks to execute; each task carries a destination and a source address.
    std::vector<Task<DstAddr, SrcAddr>> tasks;
    // Number of bytes to copy per task (uniform across all tasks in the batch).
    ssize_t numBytes;
};

// Stream callback: disk-to-disk copies, staging each block through a host
// buffer. Takes ownership of userData (a UserData<DiskAddress, DiskAddress>*)
// and frees it on return. Stops at the first failed task and logs once.
CUDA_CB void hostFnDiskToDiskCopy(void* userData) noexcept
{
    // @TODO: enable multi-threading with a thread pool
    using Data = UserData<DiskAddress, DiskAddress>;
    std::unique_ptr<Data> const data{static_cast<Data*>(userData)};
    std::vector<std::byte> staging(data->numBytes);
    bool ok = true;
    for (auto const& task : data->tasks)
    {
        if (!readAll(task.src.fd, task.src.pos, staging.data(), data->numBytes)
            || !writeAll(task.dst.fd, task.dst.pos, staging.data(), data->numBytes))
        {
            ok = false;
            break;
        }
    }
    if (!ok)
    {
        TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnDiskToDiskCopy failed.\n");
    }
}

// Stream callback: reads each task's disk region directly into its host
// destination. Takes ownership of userData (a UserData<MemAddress, DiskAddress>*)
// and frees it on return. Stops at the first failed task and logs once.
CUDA_CB void hostFnDiskToHostCopy(void* userData) noexcept
{
    // @TODO: enable multi-threading with a thread pool
    using Data = UserData<MemAddress, DiskAddress>;
    std::unique_ptr<Data> const data{static_cast<Data*>(userData)};
    bool ok = true;
    for (auto const& task : data->tasks)
    {
        if (!readAll(task.src.fd, task.src.pos, reinterpret_cast<void*>(task.dst), data->numBytes))
        {
            ok = false;
            break;
        }
    }
    if (!ok)
    {
        TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnDiskToHostCopy failed.\n");
    }
}

// Stream callback: writes each task's host source region to its disk
// destination. Takes ownership of userData (a UserData<DiskAddress, MemAddress>*)
// and frees it on return. Stops at the first failed task and logs once.
CUDA_CB void hostFnHostToDiskCopy(void* userData) noexcept
{
    // @TODO: enable multi-threading with a thread pool
    using Data = UserData<DiskAddress, MemAddress>;
    std::unique_ptr<Data> const data{static_cast<Data*>(userData)};
    bool ok = true;
    for (auto const& task : data->tasks)
    {
        if (!writeAll(task.dst.fd, task.dst.pos, reinterpret_cast<void const*>(task.src), data->numBytes))
        {
            ok = false;
            break;
        }
    }
    if (!ok)
    {
        TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnHostToDiskCopy failed.\n");
    }
}

// Stream callback: plain memcpy for each host-to-host task. Takes ownership of
// userData (a UserData<MemAddress, MemAddress>*) and frees it on return.
CUDA_CB void hostFnHostToHostCopy(void* userData) noexcept
{
    // @TODO: enable multi-threading with a thread pool
    using Data = UserData<MemAddress, MemAddress>;
    std::unique_ptr<Data> const data{static_cast<Data*>(userData)};
    ssize_t const numBytes = data->numBytes;
    for (auto const& task : data->tasks)
    {
        memcpy(reinterpret_cast<void*>(task.dst), reinterpret_cast<void const*>(task.src), numBytes);
    }
}

// Enqueues an async disk->disk copy of `numBytes` per task on `stream`; the
// actual I/O runs in a host-function callback on that stream.
// Ownership of the task payload transfers to the callback only if the launch
// succeeds; on failure it is freed here (the old code released before the
// launch, leaking the payload when cuLaunchHostFunc failed).
CUresult copyDiskToDisk(std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
    using Data = UserData<DiskAddress, DiskAddress>;
    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
    CUresult const status = cuLaunchHostFunc(stream, hostFnDiskToDiskCopy, data.get());
    if (status == CUDA_SUCCESS)
    {
        data.release(); // the callback now owns and will free the payload
    }
    return status;
}

// Enqueues an async disk->host copy of `numBytes` per task on `stream`; the
// actual I/O runs in a host-function callback on that stream.
// Ownership of the task payload transfers to the callback only if the launch
// succeeds; on failure it is freed here (the old code released before the
// launch, leaking the payload when cuLaunchHostFunc failed).
CUresult copyDiskToHost(std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
    using Data = UserData<MemAddress, DiskAddress>;
    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
    CUresult const status = cuLaunchHostFunc(stream, hostFnDiskToHostCopy, data.get());
    if (status == CUDA_SUCCESS)
    {
        data.release(); // the callback now owns and will free the payload
    }
    return status;
}

// Enqueues an async host->disk copy of `numBytes` per task on `stream`; the
// actual I/O runs in a host-function callback on that stream.
// Ownership of the task payload transfers to the callback only if the launch
// succeeds; on failure it is freed here (the old code released before the
// launch, leaking the payload when cuLaunchHostFunc failed).
CUresult copyHostToDisk(std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
    using Data = UserData<DiskAddress, MemAddress>;
    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
    CUresult const status = cuLaunchHostFunc(stream, hostFnHostToDiskCopy, data.get());
    if (status == CUDA_SUCCESS)
    {
        data.release(); // the callback now owns and will free the payload
    }
    return status;
}

// Enqueues an async host->host copy of `numBytes` per task on `stream`; the
// memcpy loop runs in a host-function callback on that stream.
// Ownership of the task payload transfers to the callback only if the launch
// succeeds; on failure it is freed here (the old code released before the
// launch, leaking the payload when cuLaunchHostFunc failed).
CUresult copyHostToHost(std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
    using Data = UserData<MemAddress, MemAddress>;
    auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
    CUresult const status = cuLaunchHostFunc(stream, hostFnHostToHostCopy, data.get());
    if (status == CUDA_SUCCESS)
    {
        data.release(); // the callback now owns and will free the payload
    }
    return status;
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
182 changes: 182 additions & 0 deletions cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kvCacheManagerV2Utils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <cuda_runtime.h>

namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
// Vectorized unit of transfer: one 16-byte uint4 per thread per load/store.
using Grain = uint4;
// Threads per CTA for the batched copy kernel.
constexpr uint32_t ctaSize = 128;
// Depth of the shared-memory staging pipeline (number of in-flight cp.async groups).
constexpr uint32_t nbBufs = 4;
// Bytes per grain (16).
constexpr uint32_t grainBytes = sizeof(Grain);

// Memory-to-memory copy task; addresses may be host or device pointers.
using MMTask = Task<MemAddress, MemAddress>;

// Ceiling division for unsigned 32-bit values; callable from host and device.
__device__ __host__ inline uint32_t divUp(uint32_t a, uint32_t b)
{
    uint32_t const roundedUp = a + (b - 1);
    return roundedUp / b;
}

// Batched memory-to-memory copy kernel built on a cp.async software pipeline
// (cp.async requires SM80+; the griddepcontrol instructions compile only for
// SM90+ and are no-ops elsewhere).
// Grid layout: gridDim.y selects the task (must be <= N); gridDim.x splits one
// task across CTAs; each of the ctaSize threads moves 16-byte grains.
// Preconditions: nbBytes % sizeof(Grain) == 0; task.src and task.dst are
// reinterpreted as Grain arrays, so they are assumed 16-byte aligned — TODO
// confirm callers guarantee the alignment.
template <uint32_t N>
__global__ void batchedCopy(std::array<MMTask, N> const __grid_constant__ tasks, uint32_t nbBytes)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // Let dependent kernels of a programmatic dependent launch start early.
    asm volatile("griddepcontrol.launch_dependents;\n");
#endif
    assert(nbBytes % sizeof(Grain) == 0);
    // Circular staging buffers for the nbBufs-deep cp.async pipeline.
    __shared__ Grain data[nbBufs][ctaSize];

    uint32_t const nbTasks = gridDim.y;
    assert(nbTasks <= N);
    auto const& task = tasks[blockIdx.y];
    uint32_t const nbSplits = gridDim.x;
    uint32_t const idxSplit = blockIdx.x;
    uint32_t const tid = threadIdx.x;

    // Bytes moved by a whole CTA in one pipeline iteration.
    constexpr uint32_t bytesPerIter = grainBytes * ctaSize;

    uint32_t const totalIters = divUp(nbBytes, bytesPerIter);
    uint32_t const maxItersPerCta = divUp(totalIters, nbSplits);
    // This thread's first grain index, and its exclusive end clamped to the
    // task's total grain count (the last split may be partial).
    uint32_t const idxGrainBeg = ctaSize * maxItersPerCta * idxSplit + tid;
    uint32_t const idxGrainEnd = std::min(idxGrainBeg + ctaSize * maxItersPerCta, nbBytes / grainBytes);

#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // Block until the upstream kernel signals its output memory is ready.
    asm volatile("griddepcontrol.wait;\n");
#endif
    // Software pipeline: iteration i issues async load i and retires the store
    // for iteration (i - nbBufs); the extra nbBufs tail iterations drain the
    // in-flight loads.
    for (uint32_t i = 0; i < maxItersPerCta + nbBufs; i++)
    {
        uint32_t const idxBuf = i % nbBufs;
        if (i >= nbBufs)
        {
            // Store phase: write out the buffer that was loaded nbBufs
            // iterations ago (same slot, since the buffer ring has nbBufs entries).
            uint32_t const stIter = i - nbBufs;
            assert(idxBuf == (stIter % nbBufs));
            Grain const& src = data[idxBuf][tid];
            uint32_t const idxGrain = idxGrainBeg + ctaSize * stIter;
            Grain& dst = reinterpret_cast<Grain*>(task.dst)[idxGrain];
            // Wait until at most nbBufs-1 cp.async groups are still in flight,
            // i.e. the group issued at stIter has landed in shared memory.
            asm volatile("cp.async.wait_group %0;\n" ::"n"(nbBufs - 1) : "memory");
            if (idxGrain < idxGrainEnd)
            {
                dst = src;
            }
        }
        // Load phase: asynchronously fetch this iteration's grain into the
        // staging slot (guarded, so out-of-range threads issue nothing).
        uint32_t const ldIter = i;
        Grain* const dst = &data[idxBuf][tid];
        uint32_t const idxGrain = idxGrainBeg + ctaSize * ldIter;
        Grain const* const src = &reinterpret_cast<Grain const*>(task.src)[idxGrain];
        if (idxGrain < idxGrainEnd)
        {
            uint32_t const size = grainBytes;
            asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"l"(__cvta_generic_to_shared(dst)),
                "l"(src), "n"(grainBytes), "r"(size)
                : "memory");
        }
        // Commit the (possibly empty) group so wait_group counting stays in sync
        // across all threads and iterations.
        asm volatile("cp.async.commit_group;\n" : : : "memory");
    }
}

// Launches batchedCopy<N> for nbTasks (<= N) tasks of nbBytes each.
// The task array is passed by value as a __grid_constant__ kernel parameter;
// a stack copy is safe because cuLaunchKernel snapshots the parameter buffer
// before returning.
// When lowBandwidth is set, each task uses a single CTA; otherwise the task is
// split so each CTA handles roughly two pipeline iterations of work.
template <uint32_t N>
CUresult launchBatchedCopyImpl(
    bool lowBandwidth, MMTask const* tasks, uint32_t nbTasks, uint32_t nbBytes, cudaStream_t stream)
{
    TLLM_CHECK(nbTasks <= N);
    TLLM_CHECK_WITH_INFO(
        nbBytes % sizeof(Grain) == 0, "Not implemented case: nbBytes = %d must be a multiple of 16.", nbBytes);
    std::array<MMTask, N> const* pTasks;
    std::array<MMTask, N> tmp;
    if (nbTasks < N)
    {
        // Partial batch: copy into a full-size array. Trailing entries stay
        // uninitialized; the kernel only reads tasks[0..gridDim.y).
        std::copy_n(tasks, nbTasks, tmp.begin());
        pTasks = &tmp;
    }
    else
    {
        pTasks = reinterpret_cast<std::array<MMTask, N> const*>(tasks);
    }
    uint32_t const nbSplits = lowBandwidth ? 1 : divUp(nbBytes, grainBytes * ctaSize * 2);
    void* args[] = {(void*) pTasks, (void*) &nbBytes};
    // Fix: `[] -> CUkernel` (lambda with a trailing return type but no
    // parameter list) is C++23-only syntax (P1102); spell out the empty
    // parameter list so this compiles under C++17/20 as well.
    static CUkernel const kernel = []() -> CUkernel
    {
        cudaKernel_t kernel = nullptr;
        TLLM_CUDA_CHECK(cudaGetKernel(&kernel, reinterpret_cast<void const*>(&batchedCopy<N>)));
        return kernel;
    }();
    return common::CUDADriverWrapper::getInstance()->cuLaunchKernel(reinterpret_cast<CUfunction>(kernel), nbSplits,
        nbTasks, 1, // gridDimX, gridDimY, gridDimZ
        ctaSize, 1, 1, // blockDimX, blockDimY, blockDimZ
        0, // sharedMemBytes
        stream, args, nullptr);
}

// When bandwidth is low, e.g. when host memory is involved, we avoid splitting as fewer CTAs should be enough to
// saturate the bandwidth.
CUresult launchBatchedCopy(bool lowBandwidth, std::vector<MMTask> const& tasks, uint32_t nbBytes, cudaStream_t stream)
{
    constexpr uint32_t maxN = 256;
    uint32_t const numTasks = static_cast<uint32_t>(tasks.size());
    uint32_t const numFullBatches = numTasks / maxN;
    // First launch every full batch of maxN tasks.
    for (uint32_t batch = 0; batch < numFullBatches; ++batch)
    {
        CUresult const status
            = launchBatchedCopyImpl<maxN>(lowBandwidth, tasks.data() + maxN * batch, maxN, nbBytes, stream);
        if (status != CUDA_SUCCESS)
        {
            return status;
        }
    }
    // Then launch the remainder with the smallest template capacity that fits,
    // keeping the by-value kernel parameter block small.
    uint32_t const remainder = numTasks % maxN;
    if (remainder == 0)
    {
        return CUDA_SUCCESS;
    }
    auto const* const rest = tasks.data() + maxN * numFullBatches;
    if (remainder <= maxN / 8)
    {
        return launchBatchedCopyImpl<maxN / 8>(lowBandwidth, rest, remainder, nbBytes, stream);
    }
    if (remainder <= maxN / 4)
    {
        return launchBatchedCopyImpl<maxN / 4>(lowBandwidth, rest, remainder, nbBytes, stream);
    }
    if (remainder <= maxN / 2)
    {
        return launchBatchedCopyImpl<maxN / 2>(lowBandwidth, rest, remainder, nbBytes, stream);
    }
    return launchBatchedCopyImpl<maxN>(lowBandwidth, rest, remainder, nbBytes, stream);
}

// Enqueues batched host->device copies of numBytes per task on stream.
// Host memory is involved, so the low-bandwidth (no per-task split) path is used.
CUresult copyHostToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
    // Fix: launchBatchedCopy takes a 32-bit byte count; make the narrowing
    // explicit and fail loudly instead of silently truncating sizes >= 4 GiB.
    auto const nbBytes = static_cast<uint32_t>(numBytes);
    TLLM_CHECK_WITH_INFO(numBytes >= 0 && static_cast<ssize_t>(nbBytes) == numBytes,
        "numBytes (%zd) does not fit in uint32_t", numBytes);
    return launchBatchedCopy(true, tasks, nbBytes, stream);
}

// Enqueues batched device->host copies of numBytes per task on stream.
// Host memory is involved, so the low-bandwidth (no per-task split) path is used.
CUresult copyDeviceToHost(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
    // Fix: launchBatchedCopy takes a 32-bit byte count; make the narrowing
    // explicit and fail loudly instead of silently truncating sizes >= 4 GiB.
    auto const nbBytes = static_cast<uint32_t>(numBytes);
    TLLM_CHECK_WITH_INFO(numBytes >= 0 && static_cast<ssize_t>(nbBytes) == numBytes,
        "numBytes (%zd) does not fit in uint32_t", numBytes);
    return launchBatchedCopy(true, tasks, nbBytes, stream);
}

// Enqueues batched device->device copies of numBytes per task on stream.
// Device-to-device bandwidth is high, so tasks may be split across CTAs.
CUresult copyDeviceToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
    // Fix: launchBatchedCopy takes a 32-bit byte count; make the narrowing
    // explicit and fail loudly instead of silently truncating sizes >= 4 GiB.
    auto const nbBytes = static_cast<uint32_t>(numBytes);
    TLLM_CHECK_WITH_INFO(numBytes >= 0 && static_cast<ssize_t>(nbBytes) == numBytes,
        "numBytes (%zd) does not fit in uint32_t", numBytes);
    return launchBatchedCopy(false, tasks, nbBytes, stream);
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
Loading
Loading