Skip to content

fix(transform): cast updates to grad dtype for mixed-precision support#1687

Open
nileshpatil6 wants to merge 1 commit into
google-deepmind:mainfrom
nileshpatil6:fix/optimizer-dtype-tracks-grads
Open

fix(transform): cast updates to grad dtype for mixed-precision support#1687
nileshpatil6 wants to merge 1 commit into
google-deepmind:mainfrom
nileshpatil6:fix/optimizer-dtype-tracks-grads

Conversation

@nileshpatil6
Copy link
Copy Markdown

Fixes #1098

Problem

In mixed-precision training, params are float32 while grads are bfloat16. Optimizer
accumulator states (mu, nu) are initialized from params so they start as float32. When
update is called with bfloat16 grads, JAX type promotion causes the output updates to
be float32 instead of bfloat16. The cast back to param dtype should happen at the
apply_updates step, not inside the transform.

This is the gap identified in #1098 after #1060 was merged: #1060 ensured updates
match params dtype, but in mixed precision grads and params have different dtypes and
updates should track the grad dtype.

Fix

Each affected update_fn now casts the output updates to match the dtype of the
incoming gradients. Accumulator states are left in their computed dtype.

Transforms changed: scale_by_adam, scale_by_amsgrad, scale_by_adamax,
scale_by_rms, scale_by_stddev, scale_by_belief, scale_by_yogi,
scale_by_radam, scale_by_lion, scale_by_adan, scale_by_novograd.

Test

Added test_mixed_precision_dtype in transform_test.py. It gives each transform
float32 params and bfloat16 grads and checks that updates come back bfloat16.

All existing tests pass.

In mixed-precision training params are float32 while grads are bfloat16.
Optimizer accumulator states are initialized from params (float32), so
arithmetic with bfloat16 grads promotes the result to float32. The
cast back to param dtype should happen at the apply_updates step, not
inside the transform.

This adds an explicit cast of the output updates to match the input
grad dtype in scale_by_adam, scale_by_amsgrad, scale_by_adamax,
scale_by_rms, scale_by_stddev, scale_by_belief, scale_by_yogi,
scale_by_radam, scale_by_lion, scale_by_adan, and scale_by_novograd.

Adds a parameterized test that verifies each affected transform
produces bfloat16 updates when given float32 params and bfloat16 grads.

Fixes google-deepmind#1098
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Updates dtype do not need to match params dtype but only grads dtype a priori

1 participant