-
Notifications
You must be signed in to change notification settings - Fork 718
[Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute #2907
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -51,13 +51,22 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( | |||||||||||||||||||
| reinterpret_cast<int *>(sorted_indices_ptr), reinterpret_cast<int *>(row_id_ptr), | ||||||||||||||||||||
| reinterpret_cast<int *>(sorted_row_id_ptr), num_tokens * topK); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // Output buffer alloc | ||||||||||||||||||||
| // Signed radix sort places -1 sentinel entries (e.g. expert-parallel rank mask) | ||||||||||||||||||||
| // at the HEAD of sorted_row_id. Skip that prefix so the kernel sees only the | ||||||||||||||||||||
| // valid suffix, and pre-fill row_id_map with -1 so the dropped slots are marked | ||||||||||||||||||||
| // without the kernel ever dereferencing a sentinel. | ||||||||||||||||||||
| num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; | ||||||||||||||||||||
| NVTE_CHECK(num_out_tokens <= num_tokens * topK, "num_out_tokens (", num_out_tokens, | ||||||||||||||||||||
| ") must not exceed num_tokens*topK (", num_tokens * topK, ")"); | ||||||||||||||||||||
| const int num_minus_ones = num_tokens * topK - num_out_tokens; | ||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is probably going to introduce a regression for the capacity-drop path. This shift assumes the dropped routes are -1 sentinels at the head of sorted_row_id (cub's signed radix sort), which is true for the EP-mask case this PR targets. But the pre-existing capacity-drop path encodes drops as a large positive expert id that sorts to the tail. For that case, the head is valid low-expert-id rows, and shifting past them drops the wrong tokens.(just fyi, capacity-dropping case means no -1 in indices, num_out_tokens < num_tokens * topK because some expert exceeded capacity)) See in this file
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think another solution to this without doing num_tokens * topk - num_out_tokens (or counting the number of -1 on host side) is to sort the keys as
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the careful review. Acknowledging the capacity-drop regression concern and the unsigned-sort suggestion below — both make sense. Waiting on the te_ci result you triggered before I push any code change, so we have a concrete signal on what needs to move.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here is the CI pipeline: https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/50478896 It failed in the expected tests |
||||||||||||||||||||
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + | ||||||||||||||||||||
| static_cast<size_t>(num_minus_ones) * sizeof(int); | ||||||||||||||||||||
|
Comment on lines
+61
to
+63
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This (the greptile command) looks correct. Can you please help cast num_tpkens to int64 before multiplication and - num_out_tokens?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not a huge deal because even with topK=128, youwould need > 16M tokens per rank for the int product to overflow. But better to be consistent, and also, this casting of 1 value on the CPU side probably would not slow dow much |
||||||||||||||||||||
| at::Tensor permuted_output = | ||||||||||||||||||||
| torch::empty({num_out_tokens, num_cols}, | ||||||||||||||||||||
| torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); | ||||||||||||||||||||
| at::Tensor row_id_map = torch::empty( | ||||||||||||||||||||
| {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); | ||||||||||||||||||||
| at::Tensor row_id_map = | ||||||||||||||||||||
| torch::full({num_tokens * topK}, -1, | ||||||||||||||||||||
| torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| auto stream = at::cuda::getCurrentCUDAStream().stream(); | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -71,8 +80,7 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( | |||||||||||||||||||
| static_cast<size_t>(num_cols)}, | ||||||||||||||||||||
| dtype); | ||||||||||||||||||||
| auto sorted_row_id_cu = makeTransformerEngineTensor( | ||||||||||||||||||||
| sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_tokens * topK)}, | ||||||||||||||||||||
| DType::kInt32); | ||||||||||||||||||||
| sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_out_tokens)}, DType::kInt32); | ||||||||||||||||||||
| auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), | ||||||||||||||||||||
|
|
||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is correct here but has an implied prerequisite that host prefills the buffer with -1 and shift the ptr by num_minus_ones (what you did in the other file). Better make it more explicit with a comment so no regression will happen by someone accidentally changing this behavior and mess up the number of blocks here. Something like:
"// row_id_map MUST be pre-initialized to -1; sorted_row_id MUST point past the sentinel prefix"