Fused Cross Entropy Triton - Loss Scaling and Vanishing Grads Bugfix #336
sarthak-amd wants to merge 4 commits into dev
Conversation
@sarthak-amd Could you post the PR for the upstream fix?
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
There is no real change in this file. Let's keep this file intact and then we don't need to add the AMD copyright statement.
Another fix came from the upstream PR NVIDIA/TransformerEngine#1879. Is the change of test in that PR also reflected?
For the fix for
wenchenvincent left a comment
@sarthak-amd Could you refactor the PR as 3 commits:
- 2 commits would be cherry-picking from the upstream PRs.
- 1 commit for the `ignore_idx` fix, with a test to cover it.

This way the PR would be very clear and easy to understand.
@sarthak-amd Could you address the comments? Also, please rebase upon latest dev so that hot fixes for sgpu tests could pass.
@sarthak-amd Could you rebase the PR and update it per reviewer comments?
Description
The Fused Cross Entropy Triton kernel currently has 2 bugs:
- When `ignore_idx` is not None, the loss should be computed only over valid tokens and not all tokens (new fix).
- Incorrect loss scaling when `reduce_loss=False` (this is already fixed in upstream):
  - With `reduce_loss=False`, we should compute the per-token loss and not reduce it; otherwise it would shrink the gradients by 1/N, giving a wrong (higher) loss.
  - With `reduce_loss=False`, `grad_output` is a tensor, not a scalar, so we need to load one value per row instead of just a scalar.

This fix is validated on the Llama 3.1 8B model for pre-training.
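For reference, a minimal PyTorch sketch (not the actual Triton kernel; the tensor shapes and the `-100` ignore value are assumptions) of the semantics the fused kernel should match after the fix:

```python
# Reference semantics only -- a sketch of what the fused Triton kernel should reproduce.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 32000, requires_grad=True)  # (tokens, vocab); shapes are illustrative
labels = torch.randint(0, 32000, (8,))
ignore_idx = -100                                    # assumed ignore value
labels[2] = ignore_idx                               # one token to ignore

# Bug 1: with ignore_idx set, the mean is taken over valid tokens only, not all tokens.
loss_mean = F.cross_entropy(logits, labels, ignore_index=ignore_idx, reduction="mean")

# Bug 2: with reduce_loss=False the kernel should return the per-token loss as-is,
# not the per-token loss divided by N (which would shrink the gradients by 1/N).
loss_per_token = F.cross_entropy(logits, labels, ignore_index=ignore_idx, reduction="none")

# Bug 2 (backward): in the unreduced case, grad_output has one entry per row,
# so the backward pass must load one value per row rather than a single scalar.
grad_output = torch.ones_like(loss_per_token)
loss_per_token.backward(grad_output)

print(loss_mean.item(), loss_per_token.shape, logits.grad.shape)
```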
Type of change