Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions transformer_engine/common/permutation/permutation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,12 @@ static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id
const int tid = threadIdx.x;
const int idx = bid * blockDim.x + tid;

if (idx >= num_rows * topK) return;
if (idx >= num_out_tokens) return;

int source_row = sorted_row_id[idx];
int source_token_id = source_row / topK;
int source_topK_id = source_row % topK;

if (idx >= num_out_tokens) {
// Set the indices of dropped tokens to -1
row_id_map[source_topK_id * num_rows + source_token_id] = -1;
} else {
// Create a row id map for subsequent unpermute operation
row_id_map[source_topK_id * num_rows + source_token_id] = idx;
}
row_id_map[source_topK_id * num_rows + source_token_id] = idx;
}

template <typename T, typename TCompute, bool hasProb>
Expand All @@ -42,7 +35,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const
TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);

// Each block corresponds to one dest token
const int source_token = blockIdx.x;
const int64_t source_token = blockIdx.x;
const int tid = threadIdx.x;

if (hasProb) {
Expand All @@ -65,7 +58,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const
TCompute frag_elem[kElementsPerAccess];
TCompute frag_sum[kElementsPerAccess];

int source_row = row_id_map[source_token];
int64_t source_row = row_id_map[source_token];

// source_row == -1 represents a dropped token
if (source_row != -1) {
Expand Down Expand Up @@ -134,7 +127,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac
TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);

// Each block corresponds to one source token
const int source_token = blockIdx.x;
const int64_t source_token = blockIdx.x;
const int tid = threadIdx.x;

if (hasProb) {
Expand Down Expand Up @@ -172,7 +165,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac
for (int k = 0; k < topKTile; k++) {
if (k == topK) break;

int dest_row = row_id_map[index];
int64_t dest_row = row_id_map[index];
index += num_rows;

if (dest_row != -1) {
Expand Down Expand Up @@ -239,7 +232,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
// moe_permute_fwd

int threads = 64;
int blocks = (num_rows * topK + threads - 1) / threads;
int blocks = (num_out_tokens + threads - 1) / threads;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is correct here but has an implied prerequisite that host prefills the buffer with -1 and shift the ptr by num_minus_ones (what you did in the other file). Better make it more explicit with a comment so no regression will happen by someone accidentally changing this behavior and mess up the number of blocks here. Something like:

"// row_id_map MUST be pre-initialized to -1; sorted_row_id MUST point past the sentinel prefix"


moe_permute_row_map<<<blocks, threads, 0, stream>>>(sorted_row_id, row_id_map, num_rows, topK,
num_out_tokens);
Expand Down
18 changes: 13 additions & 5 deletions transformer_engine/pytorch/csrc/extensions/permutation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,22 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
reinterpret_cast<int *>(sorted_indices_ptr), reinterpret_cast<int *>(row_id_ptr),
reinterpret_cast<int *>(sorted_row_id_ptr), num_tokens * topK);

// Output buffer alloc
// Signed radix sort places -1 sentinel entries (e.g. expert-parallel rank mask)
// at the HEAD of sorted_row_id. Skip that prefix so the kernel sees only the
// valid suffix, and pre-fill row_id_map with -1 so the dropped slots are marked
// without the kernel ever dereferencing a sentinel.
num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
NVTE_CHECK(num_out_tokens <= num_tokens * topK, "num_out_tokens (", num_out_tokens,
") must not exceed num_tokens*topK (", num_tokens * topK, ")");
const int num_minus_ones = num_tokens * topK - num_out_tokens;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably going to introduce a regression for the capacity-drop path. This shift assumes the dropped routes are -1 sentinels at the head of sorted_row_id (cub's signed radix sort), which is true for the EP-mask case this PR targets. But the pre-existing capacity-drop path encodes drops as a large positive expert id that sorts to the tail. For that case, the head is valid low-expert-id rows, and shifting past them drops the wrong tokens.(just fyi, capacity-dropping case means no -1 in indices, num_out_tokens < num_tokens * topK because some expert exceeded capacity))

See in this file tests/pytorch/test_permutation.py, in pytorch_permute_index_map, we have:

sorted_indices[:num_out_tokens] (keeps the head),
so I'd expect test_permutation_index_map[..., num_out_tokens=2039, ...] to fail. We can run the te_ci to confirm it.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think another solution to this without doing num_tokens * topk - num_out_tokens (or counting the number of -1 on host side) is to sort the keys as uint32_t instead of int32_t. So, -1 becomes UINT_MAX and sorts to the tail, unifying both capacity-dropping and dropless under the original idx >= num_out_tokens --> drop logic. That removes the need for the prefix shift you did, and the row_id_map pre-fill. This just needs expert_id to be <= UINT_MAX, which I do not think we are reaching there anytime soon

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the careful review. Acknowledging the capacity-drop regression concern and the unsigned-sort suggestion below — both make sense. Waiting on the te_ci result you triggered before I push any code change, so we have a concrete signal on what needs to move.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
static_cast<size_t>(num_minus_ones) * sizeof(int);
Comment on lines +61 to +63
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 num_tokens * topK still computed as int * int

num_tokens and topK are both int, so num_tokens * topK on line 61 is evaluated in 32-bit arithmetic before the result feeds the int64_t subtraction. The same expression appears twice in the NVTE_CHECK on lines 59–60. If num_tokens * topK wraps to a negative int (possible when, e.g., num_tokens ≥ 2^31 / topK), the NVTE_CHECK would either spuriously reject a valid num_out_tokens, or the error-message value would be wrong. Casting to int64_t before the multiplication closes this gap:

Suggested change
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
static_cast<size_t>(num_minus_ones) * sizeof(int);
const int64_t total_tokens = static_cast<int64_t>(num_tokens) * topK;
NVTE_CHECK(num_out_tokens <= total_tokens, "num_out_tokens (", num_out_tokens,
") must not exceed num_tokens*topK (", total_tokens, ")");
const int num_minus_ones = static_cast<int>(total_tokens - num_out_tokens);
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
static_cast<size_t>(num_minus_ones) * sizeof(int);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This (the greptile command) looks correct. Can you please help cast num_tpkens to int64 before multiplication and - num_out_tokens?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a huge deal because even with topK=128, youwould need > 16M tokens per rank for the int product to overflow. But better to be consistent, and also, this casting of 1 value on the CPU side probably would not slow dow much

at::Tensor permuted_output =
torch::empty({num_out_tokens, num_cols},
torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
at::Tensor row_id_map = torch::empty(
{num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
at::Tensor row_id_map =
torch::full({num_tokens * topK}, -1,
torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));

auto stream = at::cuda::getCurrentCUDAStream().stream();

Expand All @@ -71,8 +80,7 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
static_cast<size_t>(num_cols)},
dtype);
auto sorted_row_id_cu = makeTransformerEngineTensor(
sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_tokens * topK)},
DType::kInt32);
sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_out_tokens)}, DType::kInt32);
auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);

nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(),
Expand Down
Loading