Support option to skip the optimizer for training step#3490
Open
Support option to skip the optimizer for training step#3490
Conversation
54b4b54 to
8682ecf
Compare
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
d2f5c75 to
a51468c
Compare
|
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This PR successfully implements an optimizer wrapper to skip training steps during severe loss or gradient anomalies, effectively porting the OLMo-core logic to MaxText using JAX. The core logic elegantly computes rolling statistics and appropriately bypasses the inner optimizer during a spike to prevent momentum poisoning.
🔍 General Feedback
- JAX Idioms: The usage of
jax.lax.condto defer and conditionalize the inner optimizer step is very cleanly implemented. - Resilience: Added a few critical suggestions to explicitly handle
NaNorInfloss cases. Preventing buffer poisoning and explicitly skipping on non-finite metrics will make this logic foolproof against catastrophic anomalies. - Kwargs Forwarding: Recommended using
.pop()on**extra_argsto ensure consumed arguments likelossaren't passed downstream, guaranteeing better compatibility with inner optimizers.
a51468c to
5b8835d
Compare
5b8835d to
960398a
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR introduces a mechanism to skip training steps during severe loss or gradient anomalies (b/489540436). Reference implementation at OLMo-core.
base.yml&types.pyskip_step_on_spikesas anoptax.GradientTransformationExtraArgswrapperTests
tests/unit/optimizers_test.pyChecklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.