Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 50 additions & 36 deletions transformer_engine/common/fused_router/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_

#include <cassert>

#include "transformer_engine/transformer_engine.h"

namespace transformer_engine {
Expand Down Expand Up @@ -203,50 +205,62 @@ __device__ inline void apply_softmax_on_float(float *scores, int data_size, int
__syncwarp();
}

// Warp-cooperative top-k selection with per-lane masking.
//
// Repeats `topk` times: each lane scans its strided slice of `scores`
// (lane_id, lane_id + kThreadsPerWarp, ...), skipping already-selected
// elements, then a warp-wide reduction finds the global maximum. Lane 0
// records the winner in topk_indices/topk_scores; the owning lane marks it
// in its local bitmask so it is excluded from later iterations.
//
// Preconditions:
//  - Must be executed by all kThreadsPerWarp lanes of one warp (the
//    full-mask *_sync shuffles below require every lane to participate).
//  - data_size <= 32 * kThreadsPerWarp: each lane tracks at most 32 owned
//    elements in its uint32_t bitmask (asserted below; silently exceeding
//    this would allow duplicate selections).
//  - topk_indices / topk_scores must hold at least `topk` entries.
//
// Ties on equal scores deterministically prefer the smaller valid index,
// independent of the reduction order.
template <typename T>
__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices,
                                           T *topk_scores, int lane_id) {
  // Bit i indicates whether this lane's i-th owned element
  // (index lane_id + i * kThreadsPerWarp) was already selected.
  uint32_t local_mask = 0;
  assert(data_size <= static_cast<int>(sizeof(local_mask) * 8 * kThreadsPerWarp) &&
         "local_mask too small for data_size > 1024");

  for (int k = 0; k < topk; k++) {
    CompType local_max_val = -std::numeric_limits<CompType>::infinity();
    int local_max_idx = -1;  // -1 == "no candidate in this lane"

    // 1) Per-lane local max over unmasked elements. Already-selected entries
    //    are branchlessly replaced with -inf (0xFF800000u is the float32 bit
    //    pattern of -infinity).
    int bit_idx = 0;
    for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
      // all-ones when this element is masked, all-zeros otherwise
      uint32_t selected = -(uint32_t)((local_mask >> bit_idx) & 1u);
      uint32_t x_bits = __float_as_uint(static_cast<CompType>(scores[i]));
      CompType cur_val = __uint_as_float((~selected & x_bits) | (selected & 0xFF800000u));
      // Strict '>' keeps the first (smallest-index) occurrence within the lane.
      if (cur_val > local_max_val) {
        local_max_val = cur_val;
        local_max_idx = i;
      }
      bit_idx++;
    }

    // 2) Warp tree reduction delivering the global max (and its index) to
    //    lane 0.
    CompType global_max_val = local_max_val;
    int global_max_idx = local_max_idx;
    for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) {
      CompType shuffled_val = __shfl_down_sync(0xffffffff, global_max_val, s);
      int shuffled_idx = __shfl_down_sync(0xffffffff, global_max_idx, s);
      // On value ties, deterministically prefer the smaller valid (>= 0)
      // index; a plain strict '>' would make the winner depend on the
      // reduction order, and could keep a -1 sentinel over a real index
      // when both lanes sit at -inf.
      bool take_shuffled =
          (shuffled_val > global_max_val) ||
          (shuffled_val == global_max_val && shuffled_idx >= 0 &&
           (global_max_idx < 0 || shuffled_idx < global_max_idx));
      if (take_shuffled) {
        global_max_val = shuffled_val;
        global_max_idx = shuffled_idx;
      }
    }
    // Lane 0 holds the reduction result; broadcast it to the whole warp.
    global_max_idx = __shfl_sync(0xffffffff, global_max_idx, 0);
    global_max_val = __shfl_sync(0xffffffff, global_max_val, 0);

    // 3) Lane 0 records the k-th winner.
    if (lane_id == 0) {
      topk_indices[k] = global_max_idx;
      topk_scores[k] = static_cast<T>(global_max_val);
    }

    // 4) The owning lane marks the winner in its local mask so it is
    //    excluded from subsequent iterations.
    if (global_max_idx >= 0 && (global_max_idx % kThreadsPerWarp) == lane_id) {
      int local_bit_pos = global_max_idx / kThreadsPerWarp;
      if (local_bit_pos < 32) {  // defensive; the assert above guarantees this
        local_mask |= (1u << local_bit_pos);
      }
    }
    __syncwarp();
  }
  __syncwarp();
}

// Current TE only supports float32/bf16/fp16; float64 probs should be considered in the future
Expand Down