[DeepSeek v3] Add grad mask and update MLA init #3864

Open

gagika wants to merge 1 commit into main from agagik-deepseek-conv

Conversation

@gagika (Collaborator) commented May 10, 2026

Description

Training DeepSeek-V3 671B with load_balance_loss_weight > 0 deterministically NaNs on step 1 (the per-layer MLA backward gradient overflows bf16). This PR adds two opt-in fixes that default to no-ops:

  1. Per-token gradient mask (grad_mask_threshold): zeros the gradient of any token whose backward-pass RMS along the feature axis exceeds the threshold (see the first sketch below).
  2. Constant-std MLA init (mla_init_std): initializes MLA projections from N(0, std), with the output projection scaled by 1/sqrt(2*num_decoder_layers); has no effect when loading a checkpoint (see the second sketch below).
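
A minimal sketch of the gradient-mask idea, wired up with jax.custom_vjp; the function names here are illustrative, not the PR's actual grad_mask_utils API:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def token_grad_mask(x, threshold):
  """Identity on the forward pass; masks oversized per-token grads on the backward pass."""
  return x


def _token_grad_mask_fwd(x, threshold):
  return x, None  # no residuals to save


def _token_grad_mask_bwd(threshold, _residuals, g):
  # Per-token RMS over the feature axis, computed in float32 for precision.
  rms = jnp.sqrt(jnp.mean(jnp.square(g.astype(jnp.float32)), axis=-1, keepdims=True))
  # Zero the entire token's gradient when its RMS exceeds the threshold.
  masked = jnp.where(rms > threshold, jnp.zeros_like(g), g)
  return (masked,)


token_grad_mask.defvjp(_token_grad_mask_fwd, _token_grad_mask_bwd)
```

With a large (or disabled) threshold the backward pass is the identity, which matches the opt-in, default no-op behavior described above.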

The DeepSeek-V3 model config defaults these to 0.001 and 100.0.
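
A similar sketch of the constant-std init, written as a Flax-style initializer factory; the parameter names (num_decoder_layers, is_output_projection) and the example values are assumptions for illustration:

```python
import math

import jax
import jax.numpy as jnp


def mla_init(std, num_decoder_layers, is_output_projection=False):
  """N(0, std) initializer; output projections are scaled by 1/sqrt(2*num_decoder_layers)."""
  if is_output_projection:
    std = std / math.sqrt(2.0 * num_decoder_layers)

  def init(key, shape, dtype=jnp.float32):
    return std * jax.random.normal(key, shape, dtype)

  return init


# Example: kernel init for an MLA output projection (shapes and values illustrative).
out_proj_init = mla_init(std=0.001, num_decoder_layers=61, is_output_projection=True)
kernel = out_proj_init(jax.random.PRNGKey(0), (128, 7168))
```

Since initializers only run on freshly created parameters, restoring from a checkpoint bypasses this entirely, as noted above. The 1/sqrt(2*num_decoder_layers) scaling on the output projection is the usual trick for keeping residual-stream variance roughly constant with depth.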

Tests

Checklist

Before submitting this PR, please make sure (put an X in the square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the docs if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov (bot) commented May 10, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@gagika gagika force-pushed the agagik-deepseek-conv branch from ab75f0c to 766dc25 on May 10, 2026 at 22:02
@github-actions

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.

@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions bot left a comment

📋 Review Summary

This Pull Request introduces two stability-focused improvements for DeepSeek-V3 training: a per-token gradient mask and a constant-standard-deviation initialization for MLA projections. These changes are well-integrated into the existing configuration and layer structure, with comprehensive unit tests provided for the new gradient masking utility.

🔍 General Feedback

  • Stability: The opt-in gradient mask is a robust defensive mechanism against gradient overflows in bf16, particularly useful for large-scale training.
  • Precision: Using float32 for RMS calculation in the gradient mask is a good practice to maintain precision.
  • Initialization: The specialized initialization for MLA projections correctly follows the DeepSeek-V3 architecture's requirements.
  • Testing: The new unit tests cover edge cases and dtype preservation effectively.

[Outdated comment thread on src/maxtext/utils/grad_mask_utils.py]
@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.
