fix(transform): cast updates to grad dtype for mixed-precision support #1687
Open
nileshpatil6 wants to merge 1 commit into
In mixed-precision training params are float32 while grads are bfloat16. Optimizer accumulator states are initialized from params (float32), so arithmetic with bfloat16 grads promotes the result to float32. The cast back to param dtype should happen at the apply_updates step, not inside the transform. This adds an explicit cast of the output updates to match the input grad dtype in scale_by_adam, scale_by_amsgrad, scale_by_adamax, scale_by_rms, scale_by_stddev, scale_by_belief, scale_by_yogi, scale_by_radam, scale_by_lion, scale_by_adan, and scale_by_novograd. Adds a parameterized test that verifies each affected transform produces bfloat16 updates when given float32 params and bfloat16 grads. Fixes google-deepmind#1098
Fixes #1098
Problem
In mixed-precision training, params are float32 while grads are bfloat16. Optimizer
accumulator states (mu, nu) are initialized from params, so they start as float32. When
update is called with bfloat16 grads, JAX type promotion causes the output updates to
be float32 instead of bfloat16. The cast back to param dtype should happen at the
apply_updates step, not inside the transform.

This is the gap identified in #1098 after #1060 was merged: #1060 ensured updates
match params dtype, but in mixed precision grads and params have different dtypes and
updates should track the grad dtype.
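
The promotion described above can be reproduced in a few lines of plain JAX. This is a minimal sketch, not optax code: the names mu, grads and b1 are illustrative, and the arithmetic only mimics the shape of a first-moment update.

```python
import jax.numpy as jnp

# Accumulator as if initialized from float32 params; grads arrive as bfloat16.
mu = jnp.zeros((3,), dtype=jnp.float32)
grads = jnp.ones((3,), dtype=jnp.bfloat16)

b1 = 0.9  # Python scalars are weakly typed and do not affect the result dtype.

# Mixing a float32 array with a bfloat16 array promotes the result to float32.
new_mu = b1 * mu + (1 - b1) * grads
promoted_dtype = new_mu.dtype
print(promoted_dtype)
```

Any update computed from such an accumulator inherits float32 unless it is cast explicitly.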
Fix
Each affected update_fn now casts the output updates to match the dtype of the
incoming gradients. Accumulator states are left in their computed dtype.
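
The cast might look roughly like the sketch below. The helper name cast_updates_like_grads is hypothetical and is not taken from the diff; it only illustrates casting each update leaf to the dtype of the matching grad leaf while leaving the accumulator untouched.

```python
import jax
import jax.numpy as jnp


def cast_updates_like_grads(updates, grads):
    """Hypothetical helper: cast each update leaf to its grad leaf's dtype."""
    return jax.tree_util.tree_map(lambda u, g: u.astype(g.dtype), updates, grads)


grads = {"w": jnp.ones((2,), dtype=jnp.bfloat16)}
mu = {"w": jnp.zeros((2,), dtype=jnp.float32)}

# The accumulator update promotes to float32 and is kept in that dtype.
new_mu = jax.tree_util.tree_map(lambda m, g: 0.9 * m + 0.1 * g, mu, grads)

# Only the returned updates are cast back to the grad dtype.
updates = cast_updates_like_grads(new_mu, grads)
```

Casting only the returned updates keeps the higher-precision state for later steps while letting callers receive the dtype they passed in.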
Transforms changed:
scale_by_adam, scale_by_amsgrad, scale_by_adamax, scale_by_rms, scale_by_stddev, scale_by_belief, scale_by_yogi, scale_by_radam, scale_by_lion, scale_by_adan, scale_by_novograd.
Test
Added test_mixed_precision_dtype in transform_test.py. It gives each transform
float32 params and bfloat16 grads and checks that updates come back bfloat16.
All existing tests pass.
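
For readers who want to see the intended dtype flow end to end, here is a self-contained sketch in plain JAX. It is not the test added in this PR: toy_scale_by_rms is a hypothetical stand-in for a transform's update_fn, and the final step only imitates a cast to param dtype at apply time.

```python
import jax
import jax.numpy as jnp


def toy_scale_by_rms(grads, nu, decay=0.9, eps=1e-8):
    """Hypothetical stand-in for an update_fn; not optax code."""
    # The second-moment accumulator stays in its own (float32) dtype.
    new_nu = jax.tree_util.tree_map(
        lambda n, g: decay * n + (1 - decay) * g.astype(n.dtype) ** 2, nu, grads
    )
    # bfloat16 / float32 promotes to float32, so cast back to the grad dtype.
    updates = jax.tree_util.tree_map(
        lambda g, n: (g / (jnp.sqrt(n) + eps)).astype(g.dtype), grads, new_nu
    )
    return updates, new_nu


params = {"w": jnp.ones((4,), dtype=jnp.float32)}
grads = {"w": jnp.full((4,), 0.5, dtype=jnp.bfloat16)}
nu = jax.tree_util.tree_map(jnp.zeros_like, params)

updates, nu = toy_scale_by_rms(grads, nu)

# The cast back to param dtype happens when the updates are applied.
new_params = jax.tree_util.tree_map(
    lambda p, u: (p + u).astype(p.dtype), params, updates
)
```

The three dtypes to check are the ones the description names: bfloat16 updates, float32 accumulator state, and float32 params after the apply step.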