Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 57 additions & 10 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <torch/custom_class.h>
#include <torch/python.h>
#include <type_traits>
#include <unordered_set>
#include <vector>

using SizeType32 = tensorrt_llm::runtime::SizeType32;
Expand Down Expand Up @@ -204,13 +205,17 @@ class BaseCacheTransceiver
{
public:
virtual ~BaseCacheTransceiver() = default;
virtual void respondAndSendAsync(LlmRequest* llmRequest) = 0;
// These methods take std::shared_ptr<LlmRequest> so the transceiver and
// its async workers can hold a strong reference for the duration of the
// transfer. See the comment on CacheTransceiver::mSenderFutures for the
// lifetime invariant (kept in one place to avoid drift).
virtual void respondAndSendAsync(std::shared_ptr<LlmRequest> llmRequest) = 0;
virtual void respondAndSendLayerWise(
RequestVector const& requests, std::shared_ptr<ContextProgress> const& progress)
= 0;

virtual void requestAndReceiveSync(LlmRequest* llmRequest) = 0;
virtual void requestAndReceiveAsync(LlmRequest* llmRequest) = 0;
virtual void requestAndReceiveSync(std::shared_ptr<LlmRequest> llmRequest) = 0;
virtual void requestAndReceiveAsync(std::shared_ptr<LlmRequest> llmRequest) = 0;

/// Check all requests transferring context, and return the requests that have completed or encountered an error.
virtual RequestStatuses checkContextTransferStatus(
Expand All @@ -221,7 +226,27 @@ class BaseCacheTransceiver

[[nodiscard]] virtual bool checkGenTransferComplete() const = 0;

virtual bool cancelRequest(LlmRequest* llmRequest) = 0;
virtual bool cancelRequest(std::shared_ptr<LlmRequest> llmRequest) = 0;

/// @brief Returns true if any underlying receive buffer pool has been
/// poisoned and can no longer serve allocations until process restart.
/// Higher layers (e.g. dynamo's request handler / readiness probe) use
/// this to escalate the worker to permanently-unhealthy without parsing
/// exception text from the request-error path.
/// Default returns false so non-disagg subclasses do not need to
/// override.
[[nodiscard]] virtual bool isRecvPoolPoisoned() const
{
return false;
}

/// @brief Symmetric counterpart of isRecvPoolPoisoned for the send-side
/// pool. Provided for completeness; the historically-observed wedge is
/// receive-side only.
[[nodiscard]] virtual bool isSendPoolPoisoned() const
{
return false;
}
};

class CacheTransceiver : public BaseCacheTransceiver
Expand Down Expand Up @@ -252,13 +277,13 @@ class CacheTransceiver : public BaseCacheTransceiver

virtual ~CacheTransceiver();

void respondAndSendAsync(LlmRequest* llmRequest) override;
void respondAndSendAsync(std::shared_ptr<LlmRequest> llmRequest) override;

void respondAndSendLayerWise(
RequestVector const& requests, std::shared_ptr<ContextProgress> const& progress) override;

void requestAndReceiveSync(LlmRequest* llmRequest) override;
void requestAndReceiveAsync(LlmRequest* llmRequest) override;
void requestAndReceiveSync(std::shared_ptr<LlmRequest> llmRequest) override;
void requestAndReceiveAsync(std::shared_ptr<LlmRequest> llmRequest) override;

RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override;
Expand All @@ -267,7 +292,10 @@ class CacheTransceiver : public BaseCacheTransceiver

[[nodiscard]] bool checkGenTransferComplete() const override;

virtual bool cancelRequest(LlmRequest* llmRequest) override;
virtual bool cancelRequest(std::shared_ptr<LlmRequest> llmRequest) override;

[[nodiscard]] bool isRecvPoolPoisoned() const override;
[[nodiscard]] bool isSendPoolPoisoned() const override;

private:
void initializeCommState();
Expand All @@ -276,8 +304,27 @@ class CacheTransceiver : public BaseCacheTransceiver

std::unique_ptr<CacheSender> mCacheSender;
std::unique_ptr<CacheReceiver> mCacheReceiver;
std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
// Store shared_ptr rather than raw LlmRequest* so the futures map holds a
// strong reference for the duration of the transfer. Otherwise Python's
// _terminate_request can drop its pybind shared_ptr while the C++ side's
// raw pointer is still dereferenced by checkGenTransferStatus /
// checkContextTransferStatus (the UAF forensically confirmed via
// MALLOC_PERTURB_=85 producing mRequestId=0x5555555555555555).
//
// Eviction policy is asymmetric:
// - mRequesterFutures (gen side): on timeout, keep the entry tracked
// via mTimedOutRequesterIds until the worker future resolves. A
// timeout/cancel is not a quiescence proof on the recv side, so the
// advertised receive buffers may still be written to until the worker
// unwinds. See checkGenTransferStatus.
// - mSenderFutures (ctx side): erased immediately on completion,
// exception, or timeout. Sender zombies empirically unwind on peer
// teardown (decode-pod restart), and CacheSender::cancelRequest is
// only required to clear bookkeeping for telemetry / re-enqueue
// paths. See checkContextTransferStatus.
std::vector<std::pair<std::shared_ptr<LlmRequest>, std::future<void>>> mSenderFutures;
std::vector<std::pair<std::shared_ptr<LlmRequest>, std::future<void>>> mRequesterFutures;
std::unordered_set<LlmRequest::RequestIdType> mTimedOutRequesterIds;
mpi::MpiComm const* mMpiWorldComm{nullptr};

std::shared_ptr<CacheTransceiverComm> mGroupComm;
Expand Down
9 changes: 9 additions & 0 deletions cpp/include/tensorrt_llm/executor/transferAgent.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ class TransferStatus
virtual ~TransferStatus() = default;
[[nodiscard]] virtual bool isCompleted() const = 0;
virtual TransferState wait(int64_t timeout_ms = -1) const = 0;

/// Release the backend transfer request. If the request is still active,
/// backends may attempt to cancel it. A true return only means the backend
/// accepted release of the transfer handle; callers must still treat remote
/// memory quiescence as backend-specific.
[[nodiscard]] virtual bool release()
{
return false;
}
};

struct BaseAgentConfig
Expand Down
189 changes: 181 additions & 8 deletions cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,83 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/opUtils.h"

#include <exception>
#include <mutex>

namespace tensorrt_llm::batch_manager
{

namespace
{

// Human-readable tag for a BufferKind, used in buffer-pool log messages.
// Any value outside the enumerated kinds maps to "unknown".
char const* bufferKindName(BufferKind kind)
{
    if (kind == BufferKind::kKV)
    {
        return "kv";
    }
    if (kind == BufferKind::kKV_INDEXER)
    {
        return "kv_indexer";
    }
    if (kind == BufferKind::kRNN)
    {
        return "rnn";
    }
    return "unknown";
}

} // namespace

void BufferIndexHolder::release() noexcept
{
    // Normal-path teardown: return the slot to its manager's pool and disarm
    // this holder. Doing both inside one noexcept member means no exception
    // can land between the free and the disarm and leave the holder
    // half-released.
    if (mHeld && mMgr != nullptr)
    {
        try
        {
            if (mIsRecv)
            {
                mMgr->freeBufferIndexForRecv(mIndex);
            }
            else
            {
                mMgr->freeBufferIndexForSend(mIndex);
            }
        }
        catch (...)
        {
            // Deliberately swallowed: this runs from noexcept cleanup paths
            // (including the destructor fallback), which rely on release()
            // never throwing.
        }
        mHeld = false;
    }
}

void BufferIndexHolder::poison() noexcept
{
    // Fail-closed teardown: mark the slot poisoned on the manager instead of
    // freeing it, then disarm the holder. A disarmed or manager-less holder
    // is a no-op.
    if (mHeld && mMgr != nullptr)
    {
        try
        {
            mIsRecv ? mMgr->poisonBufferIndexForRecv(mIndex) : mMgr->poisonBufferIndexForSend(mIndex);
        }
        catch (...)
        {
            // The poisonBufferIndexFor* members are declared noexcept; this
            // guard is belt-and-suspenders so exception-path cleanup can
            // never itself throw.
        }
        mHeld = false;
    }
}

BaseTransBufferManager::BaseTransBufferManager(
size_t transferBufferSize, nvinfer1::DataType dataType, std::optional<size_t> maxNumTokens)
: mDataType{dataType}
Expand Down Expand Up @@ -54,26 +126,40 @@ BaseTransBufferManager::BaseTransBufferManager(
allocateBuffer();
}

std::optional<int> BaseTransBufferManager::assignBufferIndexForSend()
std::optional<int> BaseTransBufferManager::assignBufferIndexForSend(
std::atomic<bool> const* perRequestCancel, int64_t waitSliceMs, std::optional<uint64_t> requestIdForLog)
{
return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer, perRequestCancel,
waitSliceMs, requestIdForLog);
}

/// @brief Return a previously assigned send-side buffer slot to the pool.
/// Delegates to freeBufferIndex on the send concurrence resource; a poisoned
/// slot (flag == 2) is refused there and never re-enters circulation.
void BaseTransBufferManager::freeBufferIndexForSend(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer);
}

std::optional<int> BaseTransBufferManager::assignBufferIndexForRecv()
/// @brief Fail-closed retirement of a send-side buffer slot: the slot is
/// marked unusable and never returned to the pool. Delegates to
/// poisonBufferIndex with the "send" direction tag used in log messages.
void BaseTransBufferManager::poisonBufferIndexForSend(std::optional<int> bufferId) noexcept
{
    poisonBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer, "send");
}

/// @brief Acquire a receive-side static buffer slot, waiting until one is
/// free or the pool is poisoned; thin wrapper over assignBufferIndex on the
/// recv concurrence resource. Returns std::nullopt when only dynamic buffers
/// are in use.
/// @param perRequestCancel optional per-request cancel flag polled between
///        wait slices; a set flag aborts the wait (assignBufferIndex throws).
/// @param waitSliceMs bound on each condition-variable wait slice, in ms.
/// @param requestIdForLog request id included in cancellation logs, if known.
std::optional<int> BaseTransBufferManager::assignBufferIndexForRecv(
    std::atomic<bool> const* perRequestCancel, int64_t waitSliceMs, std::optional<uint64_t> requestIdForLog)
{
    return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer, perRequestCancel,
        waitSliceMs, requestIdForLog);
}

/// @brief Return a previously assigned receive-side buffer slot to the pool.
/// Delegates to freeBufferIndex on the recv concurrence resource; a poisoned
/// slot (flag == 2) is refused there and never re-enters circulation.
void BaseTransBufferManager::freeBufferIndexForRecv(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer);
}

/// @brief Fail-closed retirement of a receive-side buffer slot: the slot is
/// marked unusable and never returned to the pool. Delegates to
/// poisonBufferIndex with the "recv" direction tag used in log messages.
void BaseTransBufferManager::poisonBufferIndexForRecv(std::optional<int> bufferId) noexcept
{
    poisonBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer, "recv");
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateSendBuffers(
std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
runtime::BufferManager const& bufferManagerToUse)
Expand Down Expand Up @@ -225,16 +311,46 @@ void BaseTransBufferManager::allocateBuffer()
}
}

std::optional<int> BaseTransBufferManager::assignBufferIndex(
ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer)
std::optional<int> BaseTransBufferManager::assignBufferIndex(ConcurrenceResource& resource, size_t bufferCount,
bool onlyUseDynamicBuffer, std::atomic<bool> const* perRequestCancel, int64_t waitSliceMs,
std::optional<uint64_t> requestIdForLog)
{
if (onlyUseDynamicBuffer)
{
TLLM_CHECK_WITH_INFO(!resource.mPoisoned.load(std::memory_order_relaxed),
"Cannot assign dynamic cache transfer buffer kind=%s because a previous transfer left dynamic transfer "
"memory poisoned. The process must restart before these memory ranges can be safely reused.",
bufferKindName(getBufferKind()));
return std::nullopt;
}
// Bounded wait_for loop so a cancel fired on this request while parked
// here can interrupt the wait via the per-request cancel atomic, and so
// mTerminate (flipped between slices) keeps the drain worker responsive
// to shutdown.
std::unique_lock lk(resource.mBuffersMutex);
resource.mBuffersCV.wait(
lk, [&resource, bufferCount]() { return static_cast<size_t>(resource.mConcurrence) < bufferCount; });
auto const predicate = [&resource, bufferCount]()
{
return resource.mPoisoned.load(std::memory_order_relaxed)
|| static_cast<size_t>(resource.mConcurrence) < bufferCount;
};
if (!predicate())
{
auto const slice = std::chrono::milliseconds{waitSliceMs};
while (!predicate())
{
resource.mBuffersCV.wait_for(lk, slice);
if (perRequestCancel != nullptr && perRequestCancel->load(std::memory_order_relaxed))
{
auto const reqIdStr
= requestIdForLog.has_value() ? std::to_string(requestIdForLog.value()) : std::string{"?"};
TLLM_THROW("assignBufferIndex cancelled via perRequestCancel (reqId=%s)", reqIdStr.c_str());
}
}
}
TLLM_CHECK_WITH_INFO(!resource.mPoisoned.load(std::memory_order_relaxed),
"Cannot assign cache transfer buffer kind=%s because a previous transfer left the buffer pool poisoned. "
"The process must restart before these memory ranges can be safely reused.",
bufferKindName(getBufferKind()));
int bufferId = -1;
for (size_t i = 0; i < bufferCount; i++)
{
Expand Down Expand Up @@ -264,13 +380,70 @@ void BaseTransBufferManager::freeBufferIndex(
TLLM_CHECK(static_cast<size_t>(bufferId.value()) < bufferCount);
{
std::scoped_lock lk(resource.mBuffersMutex);
if (resource.mBufferIndexFlag[bufferId.value()] == 2)
{
TLLM_LOG_ERROR("Refusing to free poisoned cache transfer buffer kind=%s index=%d",
bufferKindName(getBufferKind()), bufferId.value());
return;
}
resource.mBufferIndexFlag[bufferId.value()] = 0;
}
resource.mConcurrence--;
resource.mBuffersCV.notify_one();
}
}

/// @brief Permanently retire a buffer slot (and flag the whole pool as
/// poisoned) after a transfer whose transport quiescence is unknown. The slot
/// is never returned to circulation; freeBufferIndex refuses to free a
/// poisoned slot. noexcept: this runs on exception/cleanup paths and logs
/// instead of throwing.
/// @param resource send- or recv-side concurrence resource to poison.
/// @param bufferId slot index; nullopt poisons the pool without marking a
///        specific slot.
/// @param bufferCount number of static slots in this pool (bounds check).
/// @param onlyUseDynamicBuffer true when the pool has no static slots and
///        dynamic transfer memory was in use instead.
/// @param direction "send" or "recv", used only in log messages.
void BaseTransBufferManager::poisonBufferIndex(ConcurrenceResource& resource, std::optional<int> bufferId,
    size_t bufferCount, bool onlyUseDynamicBuffer, char const* direction) noexcept
{
    // Set the pool-wide poison flag first so any waiter woken below (or any
    // later assignBufferIndex call) observes the poisoned state.
    resource.mPoisoned.store(true, std::memory_order_relaxed);

    if (onlyUseDynamicBuffer)
    {
        // No static slot to mark; the poison flag alone blocks further
        // dynamic assignments (checked in assignBufferIndex).
        TLLM_LOG_ERROR(
            "Poisoned dynamic %s cache transfer buffer kind=%s. Dynamic transfer memory cannot be safely reused; "
            "the process must restart.",
            direction, bufferKindName(getBufferKind()));
        resource.mBuffersCV.notify_all();
        return;
    }

    if (!bufferId.has_value())
    {
        // Caller never learned which slot was involved; the pool-wide flag
        // is the only protection we can apply.
        TLLM_LOG_ERROR("Poisoned unknown %s cache transfer buffer kind=%s. The process must restart.", direction,
            bufferKindName(getBufferKind()));
        resource.mBuffersCV.notify_all();
        return;
    }

    try
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < bufferCount);
        {
            std::scoped_lock lk(resource.mBuffersMutex);
            // Flag transition 1 -> 2 under the pool mutex: 2 marks the slot
            // poisoned (freeBufferIndex refuses to reset a flag of 2 back to
            // 0). Presumably 1 means "assigned/in use" — only such slots are
            // transitioned here; TODO confirm against assignBufferIndex.
            if (resource.mBufferIndexFlag[bufferId.value()] == 1)
            {
                resource.mBufferIndexFlag[bufferId.value()] = 2;
            }
        }
        // Note: mConcurrence is intentionally not decremented — the slot is
        // withdrawn from the pool rather than released.
        TLLM_LOG_ERROR(
            "Poisoned %s cache transfer buffer kind=%s index=%d. The slot will not be returned to the pool because "
            "transport quiescence is unknown; restart the process before serving more KV transfers.",
            direction, bufferKindName(getBufferKind()), bufferId.value());
    }
    catch (std::exception const& e)
    {
        // TLLM_CHECK throws on a bad index; swallow and log so this function
        // keeps its noexcept guarantee on cleanup paths.
        TLLM_LOG_ERROR("Exception while poisoning %s cache transfer buffer kind=%s index=%d: %s", direction,
            bufferKindName(getBufferKind()), bufferId.value_or(-1), e.what());
    }
    catch (...)
    {
        TLLM_LOG_ERROR("Unknown exception while poisoning %s cache transfer buffer kind=%s index=%d", direction,
            bufferKindName(getBufferKind()), bufferId.value_or(-1));
    }
    // Wake every waiter parked in assignBufferIndex so they re-check the
    // predicate and observe the poisoned state.
    resource.mBuffersCV.notify_all();
}

size_t BaseTransBufferManager::getRecvBufferCount()
{
return mRecvBufferCount;
Expand Down
Loading