[DeepSeek v3] Add grad mask and update MLA init#3864
Conversation
Codecov Report ✅ All modified and coverable lines are covered by tests.
This Pull Request introduces two stability-focused improvements for DeepSeek-V3 training: a per-token gradient mask and a constant-standard-deviation initialization for MLA projections. These changes are well-integrated into the existing configuration and layer structure, with comprehensive unit tests provided for the new gradient masking utility.
🔍 General Feedback
- Stability: The opt-in gradient mask is a robust defensive mechanism against gradient overflows in bf16, particularly useful for large-scale training.
- Precision: Using `float32` for the RMS calculation in the gradient mask is a good practice to maintain precision.
- Initialization: The specialized initialization for MLA projections correctly follows the DeepSeek-V3 architecture's requirements.
- Testing: The new unit tests cover edge cases and dtype preservation effectively.
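The constant-std initialization praised above can be sketched roughly as follows. This is a hypothetical illustration, not the PR's actual code: the function name `mla_init`, its signature, and the use of a plain (untruncated) normal are all assumptions; only the N(0, std) draw and the 1/sqrt(2 * num_decoder_layers) output-projection scaling come from the PR description.

```python
import math

import jax
import jax.numpy as jnp


def mla_init(key, shape, std, is_output_proj=False, num_decoder_layers=1):
    """Hypothetical sketch: draw weights from N(0, std); scale the output
    projection's std by 1/sqrt(2 * num_decoder_layers), per the PR description."""
    if is_output_proj:
        std = std / math.sqrt(2 * num_decoder_layers)
    return std * jax.random.normal(key, shape, dtype=jnp.float32)
```

With a large enough weight matrix, the sample standard deviation should sit close to the configured `std`, and close to `std / sqrt(2 * num_decoder_layers)` for the output projection (61 layers is used below purely as an example value).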
Description
When training DeepSeek-V3 671B with `load_balance_loss_weight > 0`, training deterministically NaNs on step 1 (per-layer MLA backward gradients overflow bf16). This PR adds two opt-in fixes that default to no-op:
- Gradient mask (`grad_mask_threshold`): zeros tokens whose feature-axis backward RMS exceeds the threshold.
- MLA init (`mla_init_std`): initializes MLA projections from N(0, std), with the output projection scaled by 1/sqrt(2 * num_decoder_layers). Has no effect when loading a checkpoint.

The DeepSeek-V3 model config defaults these to 0.001 and 100.0.
Tests
Unit tests for the gradient masking utility are in `tests/unit/grad_mask_utils_test.py`.

Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [ ] `gemini-review` label.