-
Notifications
You must be signed in to change notification settings - Fork 668
Optimize naive top-k masking in fused router #2783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8ad26ca
d6dfdcf
28844c1
1958bb4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
| #ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ | ||
| #define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ | ||
|
|
||
| #include <cassert> | ||
|
|
||
| #include "transformer_engine/transformer_engine.h" | ||
|
|
||
| namespace transformer_engine { | ||
|
|
@@ -203,50 +205,62 @@ __device__ inline void apply_softmax_on_float(float *scores, int data_size, int | |
| __syncwarp(); | ||
| } | ||
|
|
||
| __device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int topk, | ||
| int *topk_indices, CompType *topk_scores, int lane_id) { | ||
| // Check if the index is masked by the later iteration | ||
| auto is_masked = [&topk_indices](int k, int index) { | ||
| if (k == 0) return false; | ||
| for (int i = 0; i < k; i++) { | ||
| if (topk_indices[i] == index) return true; | ||
| } | ||
| return false; | ||
| }; | ||
| // Topk Times: Find the max value and its index | ||
| // Then mask it, and record the index in the topk_indices | ||
| // After looping topk times, the topk_indices will be the topk indices | ||
| template <typename T> | ||
| __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices, | ||
| T *topk_scores, int lane_id) { | ||
| // Bit i indicates whether the i-th local element (lane_id + i * warp_size) was selected. | ||
| uint32_t local_mask = 0; | ||
| assert(data_size <= static_cast<int>(sizeof(local_mask) * 8 * kThreadsPerWarp) && | ||
| "local_mask too small for data_size > 1024"); | ||
|
|
||
| for (int k = 0; k < topk; k++) { | ||
| // Find the max value and its index | ||
| CompType val = (lane_id < data_size && !is_masked(k, lane_id)) | ||
| ? scores[lane_id] | ||
| : -std::numeric_limits<CompType>::infinity(); | ||
| int index = (lane_id < data_size) ? lane_id : 0; | ||
| // Some value is handled in local thread | ||
| // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... | ||
| // Reduce the value in local thread | ||
| for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { | ||
| CompType cur_val = (is_masked(k, i)) ? -std::numeric_limits<CompType>::infinity() : scores[i]; | ||
| if (cur_val > val) { | ||
| val = cur_val; | ||
| index = i; | ||
| CompType local_max_val = -std::numeric_limits<CompType>::infinity(); | ||
| int local_max_idx = -1; | ||
|
|
||
| // 1) Per-lane local max on unmasked elements. | ||
| int bit_idx = 0; | ||
| for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { | ||
| CompType cur_val = 0.0f; | ||
| uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u); | ||
| uint32_t x_bits = __float_as_uint(static_cast<CompType>(scores[i])); | ||
| uint32_t result_bits = (~full_mask & x_bits) | (full_mask & 0xFF800000u); | ||
| cur_val = __uint_as_float(result_bits); | ||
| if (cur_val > local_max_val) { | ||
| local_max_val = cur_val; | ||
| local_max_idx = i; | ||
| } | ||
| bit_idx++; | ||
| } | ||
| // Warp shuffle between threads | ||
| for (int s = 16; s > 0; s /= 2) { | ||
| auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s); | ||
| auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s); | ||
| if (shuffled_val > val) { | ||
| val = shuffled_val; | ||
| index = shuffled_index; | ||
|
|
||
| // 2) Warp reduction to find global max and index. | ||
| CompType global_max_val = local_max_val; | ||
| int global_max_idx = local_max_idx; | ||
| for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) { | ||
| CompType shuffled_val = __shfl_down_sync(0xffffffff, global_max_val, s); | ||
| int shuffled_idx = __shfl_down_sync(0xffffffff, global_max_idx, s); | ||
| if (shuffled_val > global_max_val) { | ||
| global_max_val = shuffled_val; | ||
| global_max_idx = shuffled_idx; | ||
| } | ||
| } | ||
| global_max_idx = __shfl_sync(0xffffffff, global_max_idx, 0); | ||
| global_max_val = __shfl_sync(0xffffffff, global_max_val, 0); | ||
|
|
||
| // 3) Write top-k result. | ||
| if (lane_id == 0) { | ||
| topk_indices[k] = index; | ||
| topk_scores[k] = val; | ||
| topk_indices[k] = global_max_idx; | ||
| topk_scores[k] = static_cast<T>(global_max_val); | ||
| } | ||
|
|
||
| // 4) Mark selected element in owning lane's local mask. | ||
| if (global_max_idx >= 0 && (global_max_idx % kThreadsPerWarp) == lane_id) { | ||
| int local_bit_pos = global_max_idx / kThreadsPerWarp; | ||
| if (local_bit_pos < 32) { | ||
| local_mask |= (1u << local_bit_pos); | ||
| } | ||
|
Comment on lines
+222
to
+260
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
When `data_size > 1024`, two problems appear. 1. Undefined-behavior shift in the inner loop (line 226 / 221): in `uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u);`, once `bit_idx` reaches 32 the right shift of the 32-bit `local_mask` by 32 or more is undefined behavior in C++.
2. Silent double-selection via the guard on line 261: if (local_bit_pos < 32) {
local_mask |= (1u << local_bit_pos);
}

When `local_bit_pos >= 32` this guard silently skips recording the selection, so the same element can be chosen again on a later iteration (silent double-selection). The original implementation handled arbitrary `data_size` correctly. Suggested mitigation — at function entry, or as a `static_assert` at the call site:

assert(data_size <= static_cast<int>(sizeof(local_mask) * 8 * kThreadsPerWarp) &&
       "local_mask too small for data_size > 1024");

Or switch to a wider mask type (e.g. `uint64_t`, or an array of words) so all elements can be marked.
||
| } | ||
| __syncwarp(); | ||
| } | ||
| __syncwarp(); | ||
| } | ||
|
|
||
| // Current TE only support float32/bf16/fp16, float64 probs should be considered in the future | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `__shfl_down_sync` reduction does not guarantee a stable winner on ties. The original code used a butterfly (`__shfl_xor_sync`) reduction pattern where all threads naturally converge to the same result in 5 rounds. The new `__shfl_down_sync` tree reduction still correctly delivers the global maximum to lane 0 (then broadcast via `__shfl_sync`), but the tie-breaking behaviour when two lanes hold the same maximum value is different from the original: with `shuffled_val > global_max_val` (strict `>`), the lower-indexed lane's index is preserved on a tie, which subtly differs from the XOR-butterfly order. This is unlikely to matter for a floating-point router where exact ties are rare, but it is a behavioural change not called out in the PR description. If determinism across refactors matters for reproducibility testing, this should be documented, or a consistent tie-breaking rule (e.g., prefer smaller index) should be explicitly enforced in both the value comparison and the index comparison:
if (shuffled_val > global_max_val || (shuffled_val == global_max_val && shuffled_idx < global_max_idx)) { global_max_val = shuffled_val; global_max_idx = shuffled_idx; }