5 changes: 4 additions & 1 deletion cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -368,6 +368,9 @@ class KVCacheBlock
std::optional<std::chrono::steady_clock::time_point::duration> mExpirationTime;
// Hash for the event manager
size_t mHash;

// Mutex for the next blocks
mutable std::mutex mNextBlocksMutex;
};

class GenerationRequest
@@ -1021,7 +1024,7 @@ class WindowBlockManager
std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;

// Mutex for the cached blocks root
std::mutex mCachedBlocksRootMutex;
mutable std::mutex mCachedBlocksRootMutex;

// Record which sequence is using the block
std::map<KVCacheBlock::IdType, LlmRequest::RequestIdType> mBlockToSequence;
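The header changes above add a per-block mNextBlocksMutex and make mCachedBlocksRootMutex mutable, so that const lookup paths such as KVCacheBlock::findMatchingBlock() can still take the lock. A minimal sketch of that pattern, with hypothetical Node and key types standing in for KVCacheBlock and BlockKey:

```cpp
#include <map>
#include <memory>
#include <mutex>

// Sketch only: Node stands in for KVCacheBlock, int for BlockKey.
struct Node
{
    // mutable so that const readers such as find() may still lock it.
    mutable std::mutex mChildrenMutex;
    std::map<int, std::shared_ptr<Node>> mChildren;

    void addChild(int key, std::shared_ptr<Node> child)
    {
        std::lock_guard<std::mutex> lock(mChildrenMutex);
        mChildren.emplace(key, std::move(child));
    }

    std::shared_ptr<Node> find(int key) const
    {
        std::lock_guard<std::mutex> lock(mChildrenMutex); // legal because the mutex is mutable
        auto it = mChildren.find(key);
        return it == mChildren.end() ? nullptr : it->second;
    }
};
```

Marking the mutex mutable is the standard way to keep read-only accessors const while still serializing access to the container they read.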
6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -416,6 +416,7 @@ void KVCacheBlock::setPrevBlockInSeq(BlockPtr prevBlock)

void KVCacheBlock::addNextBlock(BlockKey const& blockKey, BlockPtr block)
{
std::lock_guard<std::mutex> lock(mNextBlocksMutex);
if (mNextBlocks.find(blockKey) == mNextBlocks.end())
{
mNextBlocks[blockKey] = std::move(block);
@@ -425,6 +426,8 @@ void KVCacheBlock::addNextBlock(BlockKey const& blockKey, BlockPtr block)
std::tuple<bool, SizeType32, BlockPtr> KVCacheBlock::findMatchingBlock(
BlockKey const& blockKey, bool enablePartialReuse, bool copyOnPartialReuse) const
{
std::lock_guard<std::mutex> lock(mNextBlocksMutex);

if (blockKey.uniqueTokens.size() == 0 || mNextBlocks.size() == 0)
{
return {false, 0, nullptr};
@@ -474,11 +477,13 @@ void KVCacheBlock::freeLeafBlock()

void KVCacheBlock::removeNextBlock(BlockKey const& blockKey)
{
std::lock_guard<std::mutex> lock(mNextBlocksMutex);
mNextBlocks.erase(blockKey);
}

void KVCacheBlock::freeDescendantsRecursively()
{
std::lock_guard<std::mutex> lock(mNextBlocksMutex);
bool hasChildren = !mNextBlocks.empty();
if (hasChildren)
{
@@ -1176,6 +1181,7 @@ std::optional<BlockKey> WindowBlockManager::findNewContextBlock(
auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest);
BlockKey ret;
ret.loraTaskId = llmRequest.getLoraTaskId();
std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);
auto searchRoot = mCachedBlocksRoot;
for (auto const& blockKey : blockKeys)
{
142 changes: 71 additions & 71 deletions cpp/tensorrt_llm/common/customAllReduceUtils.h

Large diffs are not rendered by default.

@@ -98,13 +98,14 @@ TransferState MooncakeTransferStatus::wait(int64_t timeout_ms) const
mBatchFreed = true;
TLLM_LOG_DEBUG("Batch ID %lu freed in wait()", mBatchId);
syncSegmentCache(mEngine);
std::this_thread::sleep_for(std::chrono::milliseconds(1));
return TransferState::kSUCCESS;
}

// If timeout_ms < 0, wait indefinitely
if (timeout_ms < 0)
{
std::this_thread::yield();
std::this_thread::sleep_for(std::chrono::milliseconds(1));
continue;
}

@@ -117,7 +118,7 @@ TransferState MooncakeTransferStatus::wait(int64_t timeout_ms) const
return TransferState::kIN_PROGRESS;
}

std::this_thread::yield();
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}

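In MooncakeTransferStatus::wait() the busy std::this_thread::yield() calls are replaced with 1 ms sleeps, and a short sleep is also added before returning success, trading a little latency for far less CPU spinning while a transfer completes. A hedged sketch of the resulting polling shape; the isDone predicate is a stand-in, and the batch-free and segment-cache-sync work done on success is omitted:

```cpp
#include <chrono>
#include <cstdint>
#include <functional>
#include <thread>

enum class TransferState { kSUCCESS, kIN_PROGRESS };

// Poll isDone() until it reports completion or timeout_ms elapses; timeout_ms < 0 waits forever.
TransferState pollWithSleep(std::function<bool()> const& isDone, int64_t timeout_ms)
{
    auto const start = std::chrono::steady_clock::now();
    while (true)
    {
        if (isDone())
        {
            return TransferState::kSUCCESS;
        }
        if (timeout_ms >= 0)
        {
            auto const elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
                std::chrono::steady_clock::now() - start);
            if (elapsed.count() >= timeout_ms)
            {
                return TransferState::kIN_PROGRESS;
            }
        }
        // Sleep 1 ms instead of yield(): keeps polling latency bounded without pinning a core.
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
}
```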
@@ -137,17 +137,8 @@ public:
// corresponding CTA has not been launched.
for (int flag_idx = blockIdx.x; flag_idx < kBarrierFlagCount; flag_idx += gridDim.x)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile(
"st.global.relaxed.sys.b32 [%1], %0;" ::"r"(m_flag_value), "l"(m_target_flag + flag_idx * NRanks));
#else
st_flag(m_target_flag + flag_idx * NRanks, m_flag_value);
#endif
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// Single release fence
asm volatile("fence.release.sys;");
#endif

while (ld_flag(m_current_flag) == prev_flag(m_flag_value))
{
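The hunk above drops the sm_90-only specialization (relaxed per-flag stores followed by a single release fence) and keeps the portable st_flag/ld_flag path on every architecture. The barrier shape that remains, publishing one flag per peer and then spinning until the peer's flag advances, can be sketched with libcu++ atomics; this is an assumption-level illustration, not the kernel's actual st_flag/ld_flag helpers:

```cpp
#include <cuda/atomic>
#include <cstdint>

// Sketch: one flag slot per (barrier index, rank); flag_value advances each barrier round.
__device__ void flagBarrier(uint32_t* target_flags, uint32_t* current_flag, uint32_t flag_value,
                            int flag_count, int n_ranks)
{
    // Publish this rank's flag for every slot handled by this CTA, with release semantics.
    for (int flag_idx = blockIdx.x; flag_idx < flag_count; flag_idx += gridDim.x)
    {
        cuda::atomic_ref<uint32_t, cuda::thread_scope_system> slot(target_flags[flag_idx * n_ranks]);
        slot.store(flag_value, cuda::memory_order_release);
    }
    // Wait until the peer has published this round's value (the real kernel compares
    // against the previous round's value instead).
    cuda::atomic_ref<uint32_t, cuda::thread_scope_system> peer(*current_flag);
    while (peer.load(cuda::memory_order_acquire) != flag_value)
    {
    }
}
```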
@@ -32,6 +32,10 @@ namespace kernels::moe_comm
#define ENABLE_DEBUG_PRINT 0
#define DISABLE_SYNC_FOR_PROFILING 0

#ifndef DISABLE_TIMEOUT
#define DISABLE_TIMEOUT 0
#endif

// Macros for concise launch-time specialization
#define SWITCH_BOOL(flag, NAME, ...) \
if (flag) \
@@ -141,6 +145,13 @@ namespace kernels::moe_comm
__VA_ARGS__ \
}

#if DISABLE_TIMEOUT
#define check_timeout(s) false
#else
// 300 * 2000 MHz - should be high enough on any GPU but will prevent a hang
#define check_timeout(s) ((clock64() - (s)) > (300ll * 2000ll * 1000ll * 1000ll))
#endif

// ============================================================================
// Helper Functions for Expert-to-Rank Mapping
// ============================================================================
@@ -515,6 +526,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
{
bool flag_set = false;
auto s = clock64();
do
{
uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank];
@@ -528,7 +540,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
rank_id, peer_rank, flag_value, expected_value, flag_ptr);
#endif
flag_set = flag_value == expected_value;
} while (!flag_set);
} while (!flag_set && !check_timeout(s));

if (__builtin_expect(!flag_set, 0))
{
printf("dispatch: ---Rank %d timed out waiting for completion flag from rank %d\n", rank_id,
peer_rank);
asm volatile("trap;");
return;
}
}
#endif
}
@@ -1038,6 +1058,7 @@ __global__ void moeA2ACombineKernel(
for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
{
bool flag_set = false;
auto s = clock64();
do
{
uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank];
@@ -1046,12 +1067,20 @@ __global__ void moeA2ACombineKernel(
asm volatile("ld.relaxed.sys.u32 %0, [%1];" : "=r"(flag_value) : "l"(flag_ptr));
#if ENABLE_DEBUG_PRINT
printf(
"combine: ---Rank %d received completion flag from rank %d, flag_value: %d, expected_value: %d, "
"combine: ---Rank %d received completion flag from rank %d, flag_value: %d, expected_value: "
"%d, "
"address: %p\n",
rank_id, peer_rank, flag_value, expected_value, flag_ptr);
#endif
flag_set = flag_value == expected_value;
} while (!flag_set);
} while (!flag_set && !check_timeout(s));

if (__builtin_expect(!flag_set, 0))
{
printf("combine: ---Rank %d timed out waiting for completion flag from rank %d\n", rank_id, peer_rank);
asm volatile("trap;");
return;
}
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// .acquire and .release qualifiers for fence instruction require sm_90 or higher.
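The new DISABLE_TIMEOUT / check_timeout macros bound the completion-flag spin loops in both moeA2ADispatchKernel and moeA2ACombineKernel: each wait snapshots clock64() and gives up after roughly 300 seconds at a 2 GHz SM clock, printing a message and trapping instead of hanging the GPU. A hedged CUDA sketch of that wait-or-trap pattern; the flag layout and function name are simplified stand-ins:

```cpp
#include <cstdint>
#include <cstdio>

// ~300 s at a 2000 MHz SM clock, mirroring the check_timeout budget above.
constexpr long long kTimeoutTicks = 300ll * 2000ll * 1000ll * 1000ll;

// Spin until *flag == expected or the budget expires; trap on timeout so the failure is
// visible rather than an indefinite hang.
__device__ void waitForFlagOrTrap(uint32_t* flag, uint32_t expected, int rank, int peer_rank)
{
    bool flag_set = false;
    auto const start = clock64();
    do
    {
        uint32_t value;
        asm volatile("ld.relaxed.sys.u32 %0, [%1];" : "=r"(value) : "l"(flag));
        flag_set = (value == expected);
    } while (!flag_set && (clock64() - start) <= kTimeoutTicks);

    if (!flag_set)
    {
        printf("rank %d timed out waiting for completion flag from rank %d\n", rank, peer_rank);
        asm volatile("trap;");
    }
}
```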
5 changes: 5 additions & 0 deletions cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
@@ -1304,6 +1304,11 @@ TEST_P(AsymmetricalCacheTest, TestCase)
{
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL and MOONCAKE backend for CP.";
}
if (isIndexerKCache && tensorrt_llm::common::getEnvUseMooncakeKvCache())
{
// https://nvbugs/5760737
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with Mooncake backend for Indexer KCache.";
}
std::vector<int> lenList = {30, 10, 60, 80};
if (genCp > 1)
{
15 changes: 12 additions & 3 deletions examples/disaggregated/slurm/benchmark/run_benchmark.sh
@@ -99,11 +99,20 @@ do_process_all_logs(){
fi
fi
done
if [ "${mode}" = "clean" ]; then
if [ -d "${tmp_start_logs}" ]; then
mkdir -p ${log_path}/start_logs
cp ${tmp_start_logs}/3_output_CTX_*.log ${log_path}/start_logs/ 2>/dev/null || true
cp ${tmp_start_logs}/3_output_GEN_*.log ${log_path}/start_logs/ 2>/dev/null || true
rm -rf ${tmp_start_logs}
fi
fi
}

mkdir -p ${log_path}/start_logs
cp ${log_path}/3_output_CTX_*.log ${log_path}/start_logs/ 2>/dev/null || true
cp ${log_path}/3_output_GEN_*.log ${log_path}/start_logs/ 2>/dev/null || true
tmp_start_logs=/tmp/${SLURM_JOB_ID}/start_logs
mkdir -p ${tmp_start_logs}
cp ${log_path}/3_output_CTX_*.log ${tmp_start_logs}/ 2>/dev/null || true
cp ${log_path}/3_output_GEN_*.log ${tmp_start_logs}/ 2>/dev/null || true

# warmup requests for ucx connections
if [ "${ucx_warmup_requests}" -gt 0 ]; then
25 changes: 17 additions & 8 deletions examples/disaggregated/slurm/benchmark/run_benchmark_nv_sa.sh
@@ -91,9 +91,9 @@ do_process_all_logs(){
local gen_num
local line_count
local start_line
for ctx_log in ${input_folder}/output_ctx_*.log; do
for ctx_log in ${input_folder}/3_output_CTX_*.log; do
if [ -f "${ctx_log}" ]; then
ctx_num=$(basename "${ctx_log}" | sed 's/output_ctx_\([0-9]*\)\.log/\1/')
ctx_num=$(basename "${ctx_log}" | sed 's/3_output_CTX_\([0-9]*\)\.log/\1/')
if [ "${mode}" = "line" ]; then
line_count=$(wc -l < ${ctx_log})
echo ${line_count} > ${output_folder}/ctx_only_line_${ctx_num}.txt
@@ -111,9 +111,9 @@ do_process_all_logs(){
fi
done
# process all the gen log files in the input folder
for gen_log in ${input_folder}/output_gen_*.log; do
for gen_log in ${input_folder}/3_output_GEN_*.log; do
if [ -f "${gen_log}" ]; then
gen_num=$(basename "${gen_log}" | sed 's/output_gen_\([0-9]*\)\.log/\1/')
gen_num=$(basename "${gen_log}" | sed 's/3_output_GEN_\([0-9]*\)\.log/\1/')
if [ "${mode}" = "line" ]; then
line_count=$(wc -l < ${gen_log})
echo ${line_count} > ${output_folder}/gen_only_line_${gen_num}.txt
@@ -130,11 +130,20 @@ do_process_all_logs(){
fi
fi
done
if [ "${mode}" = "clean" ]; then
if [ -d "${tmp_start_logs}" ]; then
mkdir -p ${log_path}/start_logs
cp ${tmp_start_logs}/3_output_CTX_*.log ${log_path}/start_logs/ 2>/dev/null || true
cp ${tmp_start_logs}/3_output_GEN_*.log ${log_path}/start_logs/ 2>/dev/null || true
rm -rf ${tmp_start_logs}
fi
fi
}

mkdir -p ${log_path}/start_logs
cp ${log_path}/output_ctx_*.log ${log_path}/start_logs/ 2>/dev/null || true
cp ${log_path}/output_gen_*.log ${log_path}/start_logs/ 2>/dev/null || true
tmp_start_logs=/tmp/${SLURM_JOB_ID}/start_logs
mkdir -p ${tmp_start_logs}
cp ${log_path}/3_output_CTX_*.log ${tmp_start_logs}/ 2>/dev/null || true
cp ${log_path}/3_output_GEN_*.log ${tmp_start_logs}/ 2>/dev/null || true

# warmup requests for ucx connections
if [ "${ucx_warmup_requests}" -gt 0 ]; then
@@ -160,8 +169,8 @@ for concurrency in ${concurrency_list}; do
num_prompts=$((concurrency * multi_round))
output_dir="${log_path}/concurrency_${concurrency}"
echo "Benchmarking with concurrency ${concurrency} ... ${num_prompts} prompts"
do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "line"
mkdir -p "${output_dir}"
do_process_all_logs ${log_path}/ ${log_path}/concurrency_${concurrency} "line"

python "${BENCH_SCRIPT}" \
--model "${model_name}" \
8 changes: 8 additions & 0 deletions examples/disaggregated/slurm/benchmark/submit.py
@@ -484,6 +484,14 @@ def submit_job(config, log_dir, dry_run):
f"&> {log_dir}/7_accuracy_eval_{task}.log"
]
client_cmds.append(" ".join(accuracy_prefix + accuracy_cmd))

# record ${SLURM_JOB_NODELIST} to ${log_dir}/8_done_job_id.txt
done_cmd = [
"echo", "${SLURM_JOB_NODELIST}", ">",
f"{log_dir}/8_done_${{SLURM_JOB_ID}}.txt"
]
client_cmds.append(" ".join(done_cmd))

with open(os.path.join(log_dir, "client_cmds.sh"), "w") as f:
f.write("\n".join(client_cmds) + "\n")

4 changes: 2 additions & 2 deletions security_scanning/metadata.json
@@ -1,4 +1,4 @@
{
"commit_hash": "864b61cadd1b112ed3e28391f39def529d7788f0",
"timestamp": "2026-01-20T18:11:54Z"
"commit_hash": "c381790d15585e8f9e014e72218d1fef6945ed5f",
"timestamp": "2026-01-21T02:50:03Z"
}
8 changes: 5 additions & 3 deletions tensorrt_llm/_torch/autotuner.py
@@ -505,9 +505,8 @@ def load_cache(self, file_path: Union[str, Path], rank: int) -> None:
fcntl.flock(f, fcntl.LOCK_SH)
current_cache_contents = json.load(f)
self._deserialize_metadata(current_cache_contents["metadata"])
assert f"rank_{rank}" in current_cache_contents, f"Rank {rank} cache not found in {file_path}"
self.cache = self._deserialize_cache_data(
current_cache_contents[f'rank_{rank}'])
current_cache_contents.get(f'rank_{rank}', {}))
logger.info(
f"[AutoTuner] Successfully loaded cache from {file_path} using JSON format"
)
@@ -1105,7 +1104,10 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):

disable_short_profile = os.environ.get(
"TLLM_AUTOTUNER_DISABLE_SHORT_PROFILE", "0") == "1"
if fewer_repeat_avg_time > short_profile_threshold_ms and not disable_short_profile:

# Disable this feature for merged tuning strategy to avoid potential hang due to asymmetric tuning.
if fewer_repeat_avg_time > short_profile_threshold_ms and not disable_short_profile \
and tuning_config.distributed_tuning_strategy != DistributedTuningStrategy.MERGE:
# directly use the few repeat estimated time to avoid redundant profiling
avg_time = fewer_repeat_avg_time
else:
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/modules/attention.py
@@ -334,11 +334,11 @@ def __init__(
key="sparse_attention_config")

if config.sparse_attention_config.algorithm == "rocket":
logger.info_once("disable rope_fusion for RocketKV.")
logger.warning("disable rope_fusion for RocketKV.")
self.rope_fusion = False

if self.rope_fusion and not attn_cls.support_fused_rope():
logger.info_once(
logger.warning(
"rope_fusion is true but the attention backend does not support it. Will disable rope_fusion."
)
self.rope_fusion = False
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
@@ -167,6 +167,7 @@ def __init__(
swiglu_limit=kwargs.get("swiglu_limit"),
init_load_balancer=False,
without_comm=True,
activation_type=self.activation_type,
)

self.validate_backend(backend)
13 changes: 8 additions & 5 deletions tensorrt_llm/_torch/modules/fused_moe/create_moe.py
@@ -344,7 +344,7 @@ def create_moe(
moe_cls = get_moe_cls(model_config, override_quant_config)

enable_configurable_moe = os.environ.get("ENABLE_CONFIGURABLE_MOE",
"0") == "1"
"1") == "1"
if enable_configurable_moe or moe_cls == CuteDslFusedMoE:
if moe_cls in (DeepGemmFusedMoE, TRTLLMGenFusedMoE, CuteDslFusedMoE,
CutlassFusedMoE):
@@ -365,6 +365,7 @@ def create_moe(
swiglu_alpha=swiglu_alpha,
swiglu_beta=swiglu_beta,
swiglu_limit=swiglu_limit,
activation_type=activation_type,
)
else:
# Check if this is a TRTLLM backend request that fallback to CutlassFusedMoE
@@ -378,10 +379,12 @@ def create_moe(
f"ConfigurableMoE only supports TRTLLMGenFusedMoE and CuteDslFusedMoE backends. "
f"Continuing with legacy MoE backend {moe_cls.__name__}.")
else:
# For other incompatible backends, raise error
raise ValueError(
f"ENABLE_CONFIGURABLE_MOE is set but backend {moe_cls.__name__} is not supported. "
f"ConfigurableMoE only supports TRTLLMGenFusedMoE backend.")
# Other backends are not supported by ConfigurableMoE, fallback to legacy backend
# This is a WAR to make sure all the CI test cases pass.
# TODO: Remove this workaround when ConfigurableMoE is supported by all backends.
logger.warning(
f"ENABLE_CONFIGURABLE_MOE is set but {moe_cls.__name__} is not supported by ConfigurableMoE. "
f"Continuing with legacy MoE backend {moe_cls.__name__}.")

# Use legacy create_moe_backend for other backends or when ConfigurableMoE is disabled
return create_moe_backend(