9 changes: 9 additions & 0 deletions src/maxtext/configs/base.yml
@@ -396,6 +396,10 @@ qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128

# Constant-std init for MLA proj; output proj scaled by 1/sqrt(2*num_decoder_layers).
# 0 keeps fan_in scaling.
mla_init_std: 0.0

# QK-Clip (Muon Clip) Configuration
use_qk_clip: False # Enable QK-Clip (supported in MLA with DotProduct or Tokamax Splash)
qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper)
@@ -847,6 +851,11 @@ diloco_outer_momentum: 0.9
# You may disable clipping by setting gradient_clipping_threshold to zero.
gradient_clipping_threshold: 1.0

# Per-token gradient mask at the decoder-layer boundary (DeepSeek-V3 only).
# Tokens whose feature-axis RMS exceeds threshold are zeroed in backward;
# healthy tokens pass through unchanged. 0 disables.
grad_mask_threshold: 0.0

# Instead of updating the weights every step, you may effectively use a larger
# batch by accumulating the gradient over a set of steps.
gradient_accumulation_steps: 1
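For reference, a minimal sketch of the two init regimes the base.yml comment describes, assuming the default is fan_in variance scaling with a truncated normal (as hinted by "0 keeps fan_in scaling"; shapes are illustrative, not MaxText's actual call sites):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
shape = (1024, 512)  # hypothetical (fan_in, fan_out) of one MLA projection

# mla_init_std == 0: keep the default fan_in variance scaling, std ~ 1/sqrt(fan_in).
fan_in_init = jax.nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
w_default = fan_in_init(key, shape, jnp.float32)

# mla_init_std > 0: constant std (0.001 here), independent of fan_in.
w_const = jax.random.normal(key, shape, dtype=jnp.float32) * 0.001

print(float(w_default.std()), float(w_const.std()))  # roughly 0.031 vs roughly 0.001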
6 changes: 6 additions & 0 deletions src/maxtext/configs/models/deepseek3-671b.yml
@@ -44,6 +44,12 @@ qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
mscale: 1.0
# Initialize MLA projections with N(0, std); output proj further scaled
# by 1/sqrt(2*num_decoder_layers). No effect when loading a checkpoint.
mla_init_std: 0.001
# Mask tokens whose backward-gradient RMS exceeds threshold at each
# decoder-layer boundary. Defensive against bf16 overflow; rarely fires.
grad_mask_threshold: 100.0
# RoPE
rope_type: "yarn"
rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
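To make the scaling concrete: with mla_init_std: 0.001, the non-output MLA projections draw from a zero-mean normal with std 0.001, while the output projection's std is divided by sqrt(2 * num_decoder_layers). A quick check, assuming DeepSeek-V3 671B's 61 decoder layers:

import math

mla_init_std = 0.001
num_decoder_layers = 61  # assumption: DeepSeek-V3 671B depth
out_proj_std = mla_init_std / math.sqrt(2.0 * max(1, num_decoder_layers))
print(out_proj_std)  # ~9.05e-05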
15 changes: 15 additions & 0 deletions src/maxtext/configs/types.py
@@ -589,6 +589,13 @@ class MlaAttention(BaseModel):
qk_nope_head_dim: NonNegativeInt = Field(128, description="Dimension for non-RoPE part of QK heads in MLA.")
qk_rope_head_dim: NonNegativeInt = Field(64, description="Dimension for RoPE part of QK heads in MLA.")
v_head_dim: NonNegativeInt = Field(128, description="Dimension of V heads in MLA.")
mla_init_std: NonNegativeFloat = Field(
0.0,
description=(
"Constant-std init for MLA projections; output proj scaled by "
"1/sqrt(2*num_decoder_layers). 0 keeps fan_in scaling."
),
)


class AttentionIndexer(BaseModel):
@@ -1347,6 +1354,14 @@ class Optimizer(BaseModel):
gradient_clipping_threshold: NonNegativeFloat = Field(
1.0, description="The threshold for gradient clipping. 0 disables clipping."
)
grad_mask_threshold: NonNegativeFloat = Field(
0.0,
description=(
"Per-token gradient mask at the decoder-layer boundary "
"(DeepSeek-V3 only). Forward identity; backward zeros tokens "
"whose feature-axis RMS exceeds threshold. 0 disables."
),
)
learning_rate: NonNegativeFloat = Field(3.0e-5, description="The peak learning rate.")
lr_schedule_type: LearningRateScheduleType = Field(
LearningRateScheduleType.COSINE,
11 changes: 10 additions & 1 deletion src/maxtext/layers/attention_mla.py
@@ -65,7 +65,7 @@

from maxtext.layers import nnx_wrappers
from maxtext.layers.attentions import Attention
from maxtext.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned
from maxtext.layers.initializers import nd_dense_init, nd_normal_const_std, NdInitializer, variable_to_logically_partitioned
from maxtext.layers.linears import DenseGeneral
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
@@ -726,6 +726,9 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
assert self.num_query_heads == self.num_kv_heads, "MLA requires equal number of query and kv heads"
assert not self.config.fused_qkv, "Fused QKV is not supported for MLA"

# Constant-std init for MLA projections; output proj rescaled below.
if self.config.mla_init_std > 0.0:
self.kernel_init = nd_normal_const_std(self.config.mla_init_std)
if self.q_lora_rank == 0:
# Standard Q projection (without LoRA).
self.query = DenseGeneral(
@@ -823,6 +826,12 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
mscale = 0.1 * self.mscale * math.log(self.rope_factor) + 1.0
self.softmax_scale = self.softmax_scale * mscale * mscale

# Output-proj residual scaling: std / sqrt(2 * num_decoder_layers).
if self.config.mla_init_std > 0.0:
self.kernel_init = nd_normal_const_std(
self.config.mla_init_std / math.sqrt(2.0 * max(1, self.config.num_decoder_layers))
)

self.out = self.init_out_w(output_dim=inputs_q_shape[-1])

# Setup paged attention op
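In effect, the block above builds two constant-std initializers: one for the Q/K/V (and LoRA) projections and a smaller one for the output projection. A rough sketch of the resulting values, assuming mla_init_std=0.001 and 61 decoder layers (not the module's actual wiring):

import math
from maxtext.layers.initializers import nd_normal_const_std

proj_init = nd_normal_const_std(0.001)                                # Q/K/V and LoRA projections
out_init = nd_normal_const_std(0.001 / math.sqrt(2.0 * max(1, 61)))   # output projection, std ~9.05e-05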
15 changes: 15 additions & 0 deletions src/maxtext/layers/initializers.py
@@ -34,6 +34,21 @@
default_scalar_init = jax.nn.initializers.constant(0.01)


def nd_normal_const_std(std: float):
"""Creates a constant-std normal initializer with the NdInitializer signature.

Returns an initializer that samples from a zero-mean normal with standard deviation std, regardless of fan_in/fan_out;
useful when a layer needs a fixed-stddev init independent of input shape
(e.g. scaled init for residual output projections).
"""

def init_fn(key, shape, dtype, in_axis, out_axis):
del in_axis, out_axis
return jax.random.normal(key, shape, dtype=dtype) * std

return init_fn


def nd_dense_init(scale, mode, distribution):
"""Creates a variance-scaling initializer with dynamic in/out axes.

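A small usage sketch of the new initializer (values arbitrary): it follows the same (key, shape, dtype, in_axis, out_axis) calling convention as the nd_dense_init initializers, with the axis arguments ignored.

import jax
import jax.numpy as jnp
from maxtext.layers.initializers import nd_normal_const_std

init = nd_normal_const_std(0.001)
w = init(jax.random.PRNGKey(0), (1024, 512), jnp.float32, in_axis=0, out_axis=1)
print(w.shape, float(w.std()))  # (1024, 512), roughly 0.001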
3 changes: 3 additions & 0 deletions src/maxtext/models/deepseek.py
@@ -40,6 +40,7 @@
from maxtext.layers.normalizations import RMSNorm
from maxtext.models import deepseek_batchsplit
from maxtext.models import deepseek_batchsplit_fp8
from maxtext.utils import grad_mask_utils
from maxtext.utils import max_utils
from maxtext.utils.sharding import create_sharding
from maxtext.utils.sharding import maybe_shard_with_logical
@@ -260,6 +261,8 @@ def post_process(self, layer_output, load_balance_loss, moe_bias_updates, kv_cac
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)

layer_output = grad_mask_utils.maybe_grad_mask(layer_output, self.config)

if self.config.scan_layers:
return layer_output, None
return layer_output, kv_cache
48 changes: 48 additions & 0 deletions src/maxtext/utils/grad_mask_utils.py
@@ -0,0 +1,48 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Per-token gradient mask applied at a layer boundary.

Forward is identity; backward zeros tokens whose feature-axis RMS exceeds
the configured threshold. Healthy tokens pass through unchanged. Used at
decoder-layer boundaries to bound per-layer cotangent magnitudes
(see deepseek model usage)."""

import jax
import jax.numpy as jnp


@jax.custom_vjp
def _grad_mask(x: jax.Array, threshold: jax.Array) -> jax.Array:
return x


def _grad_mask_fwd(x: jax.Array, threshold: jax.Array):
return x, threshold


def _grad_mask_bwd(threshold: jax.Array, g: jax.Array):
rms = jnp.sqrt(jnp.mean(jnp.square(g.astype(jnp.float32)), axis=-1, keepdims=True))
mask = rms <= threshold
return (jnp.where(mask, g, jnp.zeros_like(g)), jnp.zeros_like(threshold))


_grad_mask.defvjp(_grad_mask_fwd, _grad_mask_bwd)


def maybe_grad_mask(x: jax.Array, cfg) -> jax.Array:
"""Per-token gradient mask if cfg.grad_mask_threshold > 0; else identity."""
if cfg.grad_mask_threshold > 0.0:
return _grad_mask(x, jnp.asarray(cfg.grad_mask_threshold, jnp.float32))
return x
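A quick numerical illustration of the backward rule, mirroring the unit tests below (token and threshold values arbitrary): the forward pass is identity, and in the backward pass only tokens whose feature-axis RMS exceeds the threshold have their gradient zeroed.

import jax
import jax.numpy as jnp
from maxtext.utils.grad_mask_utils import _grad_mask

x = jnp.zeros((2, 4), dtype=jnp.float32)          # two tokens, four features
upstream = jnp.array([[10.0, 10.0, 10.0, 10.0],   # feature-axis RMS = 10  -> zeroed
                      [0.5, 0.5, 0.5, 0.5]])      # feature-axis RMS = 0.5 -> kept
y, vjp = jax.vjp(lambda t: _grad_mask(t, jnp.float32(1.0)), x)
(g,) = vjp(upstream)
print(g)  # [[0. 0. 0. 0.] [0.5 0.5 0.5 0.5]]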
124 changes: 124 additions & 0 deletions tests/unit/grad_mask_utils_test.py
@@ -0,0 +1,124 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for per-layer per-token gradient mask (grad_mask_utils)."""

import unittest
from collections import namedtuple

import jax
import jax.numpy as jnp
import numpy as np

from maxtext.utils.grad_mask_utils import _grad_mask, maybe_grad_mask


class GradMaskTest(unittest.TestCase):

def setUp(self):
self.rng = jax.random.PRNGKey(0)

def test_forward_is_identity(self):
"""Forward pass must return input unchanged regardless of threshold."""
x = jax.random.normal(self.rng, (2, 8, 16))
for thr in [0.5, 1.0, 100.0]:
y = _grad_mask(x, jnp.float32(thr))
np.testing.assert_array_equal(np.asarray(y), np.asarray(x))

def test_backward_below_threshold_passthrough(self):
"""When per-token RMS <= threshold, backward must return g unchanged."""
x = jnp.ones((2, 4, 8), dtype=jnp.bfloat16)
threshold = jnp.float32(1e6)  # huge → never masks

def loss_fn(x):
return jnp.sum(_grad_mask(x, threshold) * 0.5)

g = jax.grad(loss_fn)(x)
expected = jnp.full_like(x, 0.5)
np.testing.assert_allclose(np.asarray(g), np.asarray(expected), atol=1e-3)

def test_backward_outlier_tokens_are_masked(self):
"""Tokens with RMS > threshold get zero gradient; healthy tokens unchanged."""
x = jnp.zeros((2, 3, 8), dtype=jnp.float32)
threshold = jnp.float32(1.0)
upstream = jnp.ones_like(x)
# Make token (0, 0) an outlier (RMS = 10) and token (1, 2) an outlier (RMS = 100).
upstream = upstream.at[0, 0].set(10.0)
upstream = upstream.at[1, 2].set(100.0)

def fn(x):
return _grad_mask(x, threshold)

_, vjp = jax.vjp(fn, x)
(g_masked,) = vjp(upstream)
g_masked = np.asarray(g_masked)
# Outlier tokens zeroed.
np.testing.assert_array_equal(g_masked[0, 0], np.zeros(8, dtype=np.float32))
np.testing.assert_array_equal(g_masked[1, 2], np.zeros(8, dtype=np.float32))
# Healthy tokens (RMS = 1.0 == threshold) pass through.
np.testing.assert_array_equal(g_masked[0, 1], np.ones(8, dtype=np.float32))
np.testing.assert_array_equal(g_masked[1, 0], np.ones(8, dtype=np.float32))

def test_backward_threshold_grad_is_zero(self):
"""Threshold arg must receive a zero gradient (it's not differentiable)."""
x = jnp.ones((2, 4, 8), dtype=jnp.float32)

def fn(x, threshold):
return _grad_mask(x, threshold)

threshold = jnp.float32(1.0)
_, vjp = jax.vjp(fn, x, threshold)
upstream = jnp.ones_like(x)
_, g_threshold = vjp(upstream)
self.assertEqual(float(g_threshold), 0.0)

def test_maybe_grad_mask_threshold_zero_is_noop(self):
"""maybe_grad_mask with threshold=0 returns input unchanged and inserts no boundary."""
Cfg = namedtuple("Cfg", ["grad_mask_threshold"])
cfg = Cfg(grad_mask_threshold=0.0)
x = jax.random.normal(self.rng, (2, 4, 8))
y = maybe_grad_mask(x, cfg)
self.assertIs(y, x) # exact identity, no jnp.array wrapping

def test_maybe_grad_mask_threshold_positive_applies_mask(self):
"""maybe_grad_mask with threshold > 0 zeros tokens whose RMS exceeds threshold."""
Cfg = namedtuple("Cfg", ["grad_mask_threshold"])
cfg = Cfg(grad_mask_threshold=0.5)
x = jnp.zeros((2, 4, 8), dtype=jnp.float32)

def fn(x):
return maybe_grad_mask(x, cfg)

# All tokens have RMS = 10.0 (every element = 10.0); threshold = 0.5 → all masked.
upstream = jnp.full_like(x, 10.0)
_, vjp = jax.vjp(fn, x)
(g,) = vjp(upstream)
np.testing.assert_array_equal(np.asarray(g), np.zeros_like(np.asarray(x)))

def test_dtype_preserved_in_backward(self):
"""Backward must preserve the gradient's dtype (bf16 in, bf16 out)."""
x = jnp.zeros((2, 4, 8), dtype=jnp.bfloat16)
threshold = jnp.float32(0.1)

def fn(x):
return _grad_mask(x, threshold)

upstream = (jax.random.normal(self.rng, x.shape, dtype=jnp.float32) * 10.0).astype(jnp.bfloat16)
_, vjp = jax.vjp(fn, x)
(g,) = vjp(upstream)
self.assertEqual(g.dtype, jnp.bfloat16)


if __name__ == "__main__":
unittest.main()