
Conversation

@shubhamchandak94
Contributor

Issue #, if available:

None, although closely related to aws-neuron/aws-neuron-sdk#1156.

Description of changes:

The existing dropout implementation in the flash attention backward kernel had a couple of issues:

  1. It used the post-dropout softmax_y when computing softmax_dx_local.
  2. It did not apply dropout to softmax_dy before using it to compute softmax_dx_local (which is subsequently used to compute dq and dk).

The CR updates the implementation to match the reference pseudocode in https://arxiv.org/pdf/2205.14135 (Section B.4, Algorithm 4); a sketch of the corrected ordering follows.
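
For concreteness, here is a minimal NumPy sketch (not NKI code) of the corrected softmax backward with dropout, following Algorithm 4 of the paper. The names softmax_y, softmax_dy, and softmax_dx_local mirror the kernel's variables; dropout_mask and dropout_p are assumed inputs describing the forward dropout, introduced only for illustration.

```python
import numpy as np

def softmax_backward_with_dropout(softmax_y, softmax_dy, dropout_mask, dropout_p):
    """Reference softmax backward with dropout (assumed shapes: all (n, m)).

    softmax_y    : pre-dropout softmax output P
    softmax_dy   : gradient w.r.t. the post-dropout output, i.e. dO @ V.T
    dropout_mask : 0/1 keep mask used in the forward pass
    dropout_p    : dropout probability
    """
    # Fix 2: back-propagate through dropout first, so the softmax backward
    # sees the gradient w.r.t. the pre-dropout probabilities.
    dy_pre_dropout = softmax_dy * dropout_mask / (1.0 - dropout_p)

    # Fix 1: the softmax Jacobian uses the pre-dropout softmax_y,
    # not the dropped-out values.
    row_dot = np.sum(softmax_y * dy_pre_dropout, axis=-1, keepdims=True)
    softmax_dx_local = softmax_y * (dy_pre_dropout - row_dot)
    return softmax_dx_local
```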

Testing:

Please see detailed unit test requirements in the CONTRIBUTING.md

  • The change is covered by numeric check using nki.baremetal
  • The change is covered by performance benchmark test using nki.benchmark
  • The change is covered by end-to-end integration test

I tested locally against a golden function to verify that the output is accurate and that performance is as expected, both with and without dropout; the shape of that golden check is sketched below.
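
For reference, a minimal single-head NumPy sketch of the kind of golden backward used for that comparison, assuming the dropout mask (or seed) used by the kernel can be reproduced on the host. Softmax scaling and attention masking are omitted for brevity, and the function name is illustrative, not a symbol in this repository.

```python
import numpy as np

def flash_attn_bwd_golden(q, k, v, do, dropout_mask, dropout_p):
    """Dense reference backward for attention with dropout (illustrative only)."""
    # Forward pieces needed by the backward pass.
    s = q @ k.T
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    p_dropped = p * dropout_mask / (1.0 - dropout_p)

    # Backward, following Algorithm 4: dv uses the dropped probabilities,
    # while the softmax backward uses the pre-dropout p and the
    # dropout-corrected gradient dp.
    dv = p_dropped.T @ do
    dp = (do @ v.T) * dropout_mask / (1.0 - dropout_p)
    ds = p * (dp - np.sum(p * dp, axis=-1, keepdims=True))
    dq = ds @ k
    dk = ds.T @ q
    return dq, dk, dv
```

The kernel outputs can then be compared against these references with np.allclose at a tolerance appropriate for the kernel's precision.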

Pull Request Checklist

  • I have filled in all the required fields in the template
  • I have tested locally that all the tests pass
  • By submitting this pull request, I confirm that my contribution is made under the terms of the MIT-0 license.
