Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023–2025 Google LLC
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -790,6 +790,15 @@ gradient_clipping_threshold: 1.0
gradient_accumulation_steps: 1

opt_type: "adamw" # one of "adamw", "adam_pax", "sgd", or "muon"

# If True, skip the training step when a loss or gradient spike is detected:
# neither the weights nor the momentum buffers (if applicable) are updated.
skip_step_on_spikes: False
# The rolling window length, in steps, used to compute the mean and standard deviation
skip_step_interval: 128
# The scaling factor to determine if a spike occurred
skip_step_scaling_factor: 6.0

# List of parameter names/patterns to train.
# If non-empty, all other parameters will be frozen. Example: ['.*indexer.*'].
# If empty (default), all parameters are trained.
Expand Down
9 changes: 8 additions & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023–2025 Google LLC
# Copyright 2023–2026 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1150,6 +1150,13 @@ class Optimizer(BaseModel):
"""Configuration for the optimizer and learning rate schedule."""

opt_type: OptimizerType = Field(OptimizerType.ADAMW, description="The type of optimizer to use.")
skip_step_on_spikes: bool = Field(
False, description="If True, skip the training step when loss or gradient spike is detected."
)
skip_step_interval: PositiveInt = Field(
128, description="The rolling interval to calculate the mean and standard deviation."
)
skip_step_scaling_factor: float = Field(6.0, description="The scaling factor to determine if a spike occurred.")
gradient_accumulation_steps: PositiveInt = Field(
1, description="Number of steps to accumulate gradients before updating."
)
Expand Down
120 changes: 120 additions & 0 deletions src/maxtext/optimizers/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,119 @@ def get_adamw_mask(config):
return _get_path_mask_fn(getattr(config, "adamw_mask", None), match_returns_true=False)


def _compute_rolling_stats(arr: jax.Array, count: jax.Array, interval: int):
    """Return (mean, std) over the populated slots of a rolling buffer.

    Only the first ``min(count, interval)`` entries of ``arr`` are treated as
    valid history; the remainder is masked out. The standard deviation uses
    the unbiased (N - 1) estimator (Bessel's correction) to align with
    ``torch.std``.

    Args:
        arr: Ring buffer of recorded metrics, shape ``(interval,)``.
        count: Total number of metrics recorded so far (may exceed interval).
        interval: Static capacity of the rolling window.

    Returns:
        A ``(mean, std)`` pair of scalar arrays; both are 0 for an empty buffer.
    """
    n_valid = jnp.minimum(count, interval)
    in_window = jnp.arange(interval) < n_valid

    # Guard the divisor so an empty buffer yields 0 instead of NaN.
    window_sum = jnp.sum(jnp.where(in_window, arr, 0.0))
    mean = window_sum / jnp.maximum(1, n_valid)

    squared_dev = jnp.where(in_window, jnp.square(arr - mean), 0.0)
    # Unbiased variance: divide by N - 1, clamped so a single sample gives 0.
    variance = jnp.sum(squared_dev) / jnp.maximum(1, n_valid - 1)
    return mean, jnp.sqrt(variance)


def skip_step_on_spikes(
    inner_opt: optax.GradientTransformation, interval: int, scaling_factor: float
) -> optax.GradientTransformationExtraArgs:
    """Wrap `inner_opt` so that outlier training steps are skipped entirely.

    A rolling mean and standard deviation (unbiased, Bessel-corrected) of the
    loss and the gradient norm are maintained over the last `interval` steps.
    When the current loss or gradient norm exceeds
    ``mean + scaling_factor * std``, the emitted update is all-zeros and the
    inner optimizer state is left untouched, so neither the weights nor any
    momentum buffers absorb the spike.

    Reference implementation:
    https://github.com/allenai/OLMo-core/blob/c757b7c3c15197154c753d883330afbfa4869dcc/src/olmo_core/optim/skip_step_optimizer.py#L12

    Args:
        inner_opt: The inner Optax gradient transformation to wrap.
        interval: Number of recent steps used for the rolling statistics.
        scaling_factor: Std-deviation multiplier defining the spike threshold.

    Returns:
        An optax.GradientTransformationExtraArgs that skips spikes.
    """

    def init_fn(params):
        # State is a plain dict pytree: the inner state plus two metric ring
        # buffers and a step counter.
        return {
            "inner_state": inner_opt.init(params),
            "losses": jnp.zeros(interval, dtype=jnp.float32),
            "grad_norms": jnp.zeros(interval, dtype=jnp.float32),
            "count": jnp.zeros((), dtype=jnp.int32),
        }

    def update_fn(updates, state, params=None, **extra_args):
        # Strip our metrics from the kwargs before they reach the inner
        # optimizer; forwarding them could raise TypeError if the inner
        # transformation does not accept `loss`/`grad_norm`.
        loss = extra_args.pop("loss", None)
        grad_norm = extra_args.pop("grad_norm", None)

        # Without a loss there is nothing to monitor: act as a transparent
        # pass-through and leave the rolling buffers/counter unchanged.
        if loss is None:
            inner_updates, inner_state = inner_opt.update(updates, state["inner_state"], params, **extra_args)
            return inner_updates, dict(state, inner_state=inner_state)

        count = state["count"]
        losses = state["losses"]
        grad_norms = state["grad_norms"]

        loss_mean, loss_std = _compute_rolling_stats(losses, count, interval)
        norm_mean, norm_std = _compute_rolling_stats(grad_norms, count, interval)

        # A metric is acceptable while it stays within mean + k * std.
        within_bounds = (loss - loss_mean) <= scaling_factor * loss_std
        if grad_norm is not None:
            within_bounds = jnp.logical_and(within_bounds, (grad_norm - norm_mean) <= scaling_factor * norm_std)

        # Never skip before enough history exists: at least half the window,
        # and no fewer than two samples.
        min_history = max(2, interval // 2)
        within_bounds = jnp.logical_or((count + 1) < min_history, within_bounds)

        def run_inner():
            return inner_opt.update(updates, state["inner_state"], params, **extra_args)

        def zero_out():
            # Emit zero updates and freeze the inner state so momentum
            # buffers are not poisoned by the spike.
            zeroed = jax.tree_util.tree_map(jnp.zeros_like, updates)
            return zeroed, state["inner_state"]

        inner_updates, new_inner_state = jax.lax.cond(within_bounds, run_inner, zero_out)

        # Record the metrics whether or not the step was skipped: a sustained
        # shift in loss scale should eventually become the new baseline.
        slot = count % interval
        new_losses = losses.at[slot].set(loss)
        new_grad_norms = grad_norms if grad_norm is None else grad_norms.at[slot].set(grad_norm)

        return inner_updates, {
            "inner_state": new_inner_state,
            "losses": new_losses,
            "grad_norms": new_grad_norms,
            "count": count + 1,
        }

    return optax.GradientTransformationExtraArgs(init_fn, update_fn)


def get_optimizer(config, learning_rate_schedule, model=None):
"""Create optimizer."""
if config.opt_type == "adamw":
Expand Down Expand Up @@ -100,6 +213,13 @@ def get_optimizer(config, learning_rate_schedule, model=None):
else:
raise ValueError(f"{config.opt_type=} is not a supported.")

if getattr(config, "skip_step_on_spikes", False):
base_opt = skip_step_on_spikes(
base_opt,
interval=config.skip_step_interval,
scaling_factor=config.skip_step_scaling_factor,
)

# If a whitelist of trainable parameters is provided, freeze everything else.
# When trainable_parameters_mask is empty, freeze_mask_fn is None and all parameters are trained.
trainable_patterns = getattr(config, "trainable_parameters_mask", None)
Expand Down
16 changes: 15 additions & 1 deletion src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from absl import app

import numpy as np
import optax

import pathwaysutils # pylint: disable=unused-import

Expand Down Expand Up @@ -391,7 +392,20 @@ def move(path, value):
jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params),
)
)
new_state = state.apply_gradients(grads=grads)

if getattr(config, "skip_step_on_spikes", False):
grad_norm = max_utils.l2norm_pytree(grads)
# TrainState.apply_gradients doesn't pass **kwargs to tx.update, so we unpack it manually.
updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm)
new_params = optax.apply_updates(state.params, updates)

new_state = state.replace(
step=state.step + 1,
params=new_params,
opt_state=new_opt_state,
)
else:
new_state = state.apply_gradients(grads=grads)

# Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/optimizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import unittest
from unittest.mock import patch
import jax
import optax
import jax.numpy as jnp

import pytest
from absl.testing import parameterized
Expand Down Expand Up @@ -428,5 +430,54 @@ def learning_rate_schedule(step):
self.assertFalse(jax.numpy.all(updates["layer1"]["kernel"] == 0))


class SkipStepOnSpikesTest(parameterized.TestCase):
    """Tests for the skip_step_on_spikes optimizer wrapper."""

    def _run_spike_test(self, spike_kwargs):
        """Warm up with steady metrics, then assert a spike zeroes the update."""
        wrapped = optimizers.skip_step_on_spikes(optax.sgd(0.1), interval=4, scaling_factor=1.0)

        params = {"x": jnp.array([1.0])}
        opt_state = wrapped.init(params)
        grads = {"x": jnp.array([1.0])}

        # Steps 0 and 1 are warmup (count + 1 < max(2, interval // 2) on the
        # first step only, and mean == metric on the second): never skipped.
        steady_kwargs = {name: jnp.array(1.0) for name in spike_kwargs}
        for _ in range(2):
            updates, opt_state = wrapped.update(grads, opt_state, params, **steady_kwargs)
            self.assertFalse(jnp.all(updates["x"] == 0.0))

        # Step 2: history is established; the spiked metrics must trigger a skip.
        spiked_kwargs = {name: jnp.array(value) for name, value in spike_kwargs.items()}
        updates, opt_state = wrapped.update(grads, opt_state, params, **spiked_kwargs)
        self.assertTrue(jnp.all(updates["x"] == 0.0))

    def test_skip_step_on_loss_spike(self):
        self._run_spike_test({"loss": 100.0})

    def test_skip_step_on_grad_norm_spike(self):
        self._run_spike_test({"loss": 1.0, "grad_norm": 100.0})

    def test_skip_step_on_both_spike(self):
        self._run_spike_test({"loss": 100.0, "grad_norm": 100.0})

    def test_no_skip_without_kwargs(self):
        """Without loss/grad_norm kwargs the wrapper is a transparent pass-through."""
        wrapped = optimizers.skip_step_on_spikes(optax.sgd(0.1), interval=4, scaling_factor=1.0)

        params = {"x": jnp.array([1.0])}
        opt_state = wrapped.init(params)

        updates, opt_state = wrapped.update({"x": jnp.array([1.0])}, opt_state, params)
        self.assertFalse(jnp.all(updates["x"] == 0.0))
        # The rolling-window counter must not advance on the fallback path.
        self.assertEqual(opt_state["count"], 0)


if __name__ == "__main__":
unittest.main()
Loading