
Support option to skip the optimizer for training step#3490

Open
RissyRan wants to merge 1 commit into main from skip_optimizer

Conversation

@RissyRan
Collaborator

@RissyRan RissyRan commented Mar 24, 2026

Description

This PR introduces a mechanism to skip training steps during severe loss or gradient anomalies (b/489540436). Reference implementation at OLMo-core.

  • Add configs in base.yml & types.py
  • Implement skip_step_on_spikes as an optax.GradientTransformationExtraArgs wrapper
  • Integrate the optimizer update into the training loop

Tests

  • Add a unit test in tests/unit/optimizers_test.py
  • End-to-end training functional comparison: link

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@RissyRan RissyRan force-pushed the skip_optimizer branch 3 times, most recently from 54b4b54 to 8682ecf Compare March 24, 2026 04:25
@codecov

codecov bot commented Mar 24, 2026

Codecov Report

❌ Patch coverage is 87.03704% with 7 lines in your changes missing coverage. Please review.

Files with missing lines                 | Patch % | Lines
src/maxtext/trainers/pre_train/train.py  | 28.57%  | 4 Missing and 1 partial ⚠️
src/maxtext/optimizers/optimizers.py     | 95.74%  | 1 Missing and 1 partial ⚠️


@github-actions

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions bot left a comment


📋 Review Summary

This PR successfully implements an optimizer wrapper to skip training steps during severe loss or gradient anomalies, effectively porting the OLMo-core logic to MaxText using JAX. The core logic elegantly computes rolling statistics and appropriately bypasses the inner optimizer during a spike to prevent momentum poisoning.

🔍 General Feedback

  • JAX Idioms: The usage of jax.lax.cond to defer and conditionalize the inner optimizer step is very cleanly implemented.
  • Resilience: Added a few critical suggestions to explicitly handle NaN or Inf loss cases. Preventing buffer poisoning and explicitly skipping on non-finite metrics will make this logic foolproof against catastrophic anomalies.
  • Kwargs Forwarding: Recommended using .pop() on **extra_args to ensure consumed arguments like loss aren't passed downstream, guaranteeing better compatibility with inner optimizers.

