Skip to content
1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class DecoderBlockType(enum.Enum):
SIMPLE_MLP = "simple_mlp"
LLAMA4 = "llama4"
OLMO3 = "olmo3"
DEEPSEEK_V4 = "deepseek_v4"

LLAMA2LTI = "llama2_learn_to_init"

Expand Down
11 changes: 11 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,16 @@ qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper)
fused_qkv: False
fused_mlp: False

# DeepSeek-V4 Compressed Attention parameters
# Defaults here mirror DeepSeekV4AttentionConfig in src/maxtext/configs/types.py;
# keep the two in sync when changing any value.
compress_rope_theta: 160000.0
# Per-layer compressor rates (0: standard, 4: CSA, 128: HCA); empty [] disables.
compress_ratios: []
index_head_dim: 128
index_n_heads: 64
index_topk: 512
o_groups: 8
o_lora_rank: 1024
sliding_window: 128

record_internal_nn_metrics: 0

# Output directory
Expand Down Expand Up @@ -1216,6 +1226,7 @@ force_q_layout: false
mhc_expansion_rate: 1
# The number of iterations for the Sinkhorn-Knopp algorithm.
sinkhorn_iterations: 20
hc_eps: 1.0e-6

################################## DeepSeek Engram ##################################
# Indices of transformer layers where Engram are integrated; leave empty [] to disable.
Expand Down
68 changes: 68 additions & 0 deletions src/maxtext/configs/models/deepseek_v4-flash.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Default model configs for DeepSeek-V4-Flash (43 Layers)

base_emb_dim: 4096
base_num_query_heads: 64
base_num_kv_heads: 64
base_mlp_dim: 2048
base_moe_mlp_dim: 2048
base_num_decoder_layers: 43
first_num_dense_layers: 2
mlp_activations: ["silu","linear"]
vocab_size: 129280
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 256
num_experts_per_tok: 6
shared_experts: 1
routed_scaling_factor: 1.5
routed_score_func: "sqrtsoftplus"
routed_bias: True
decoder_block: "deepseek_v4"
pure_nnx_decoder: True
enable_nnx: True

# Manifold-Constrained Hyper-Connection configurations
mhc_expansion_rate: 4
sinkhorn_iterations: 20
compress_rope_theta: 160000.0
index_head_dim: 128
index_n_heads: 64
index_topk: 512
o_groups: 8
o_lora_rank: 1024
sliding_window: 128
num_hash_layers: 3
mlp_activations_limit: 10.0
# Per-layer compressor rates (0: standard, 4: CSA, 128: HCA).
# NOTE(review): this list has 44 entries, but base_num_decoder_layers is 43 —
# confirm whether one entry should be dropped or the layer count raised.
compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]

# Compressed Sparse Attention
attention_type: "compressed_sparse_attention"
q_lora_rank: 1024
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
mscale: 1.0

# RoPE
# NOTE(review): 10_000 parses as int 10000 under YAML 1.1 (PyYAML) but as a
# string under strict YAML 1.2 loaders — confirm which loader MaxText uses.
rope_type: "yarn"
rope_max_timescale: 10_000
max_position_embeddings: 1048576
original_max_position_embeddings: 65536
rope_factor: 16
beta_fast: 32
68 changes: 68 additions & 0 deletions src/maxtext/configs/models/deepseek_v4-tiny.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tiny version of DeepSeek-V4 (4 Layers) for local sharding and compilation checks.
# NOTE(review): header says "4 Layers" but base_num_decoder_layers is 6 below —
# update one or the other.

base_emb_dim: 128
base_num_query_heads: 16
base_num_kv_heads: 16
base_mlp_dim: 128
base_moe_mlp_dim: 128
base_num_decoder_layers: 6
first_num_dense_layers: 1
mlp_activations: ["silu","linear"]
vocab_size: 129280
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 8
num_experts_per_tok: 4
shared_experts: 1
routed_scaling_factor: 1.5
routed_score_func: "sigmoid"
routed_bias: True
decoder_block: "deepseek_v4"
pure_nnx_decoder: True
enable_nnx: True

# Manifold-Constrained Hyper-Connection configurations
mhc_expansion_rate: 4
sinkhorn_iterations: 20
compress_rope_theta: 160000.0
index_head_dim: 32
index_n_heads: 16
index_topk: 64
o_groups: 2
o_lora_rank: 64
sliding_window: 32
num_hash_layers: 3
mlp_activations_limit: 10.0
# Per-layer compressor rates (0: standard, 4: CSA, 128: HCA);
# 6 entries, matching base_num_decoder_layers above.
compress_ratios: [0, 4, 128, 4, 128, 0]

# Compressed Attention
# NOTE(review): attention_type is "global" here while the full-size config uses
# "compressed_sparse_attention" — presumably deliberate for this smoke-test
# config, but confirm.
attention_type: "global"
q_lora_rank: 64
kv_lora_rank: 32
qk_nope_head_dim: 32
qk_rope_head_dim: 16
# NOTE(review): v_head_dim is not scaled down like the qk dims (32/16) and
# equals base_emb_dim — confirm this is intentional.
v_head_dim: 128
mscale: 1.0

# RoPE
# NOTE(review): 10_000 parses as int 10000 under YAML 1.1 (PyYAML) but as a
# string under strict YAML 1.2 loaders — confirm which loader MaxText uses.
rope_type: "yarn"
rope_max_timescale: 10_000
max_position_embeddings: 163840
original_max_position_embeddings: 4096
rope_factor: 40
beta_fast: 32
19 changes: 19 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ class ProfilerType(str, Enum):
"deepseek3-test",
"deepseek3-tiny",
"deepseek3.2-671b",
"deepseek_v4-tiny",
"deepseek-custom",
"kimi-k2-1t",
"gemma-7b",
Expand Down Expand Up @@ -605,6 +606,22 @@ class AttentionIndexer(BaseModel):
indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.")


class DeepSeekV4AttentionConfig(BaseModel):
    """Configuration specific to DeepSeek-V4 stateless compressed attention layers.

    Defaults mirror the values declared in src/maxtext/configs/base.yml; keep the
    two in sync when changing any default here.
    """

    # RoPE theta base used by the long-range compressor path (separate from the
    # model's main RoPE settings).
    compress_rope_theta: float = Field(160000.0, description="Theta base frequency for long-range compressor layers.")
    # NOTE(review): presumably one entry per decoder layer, with [] disabling
    # compression; the length is not validated against num_decoder_layers here —
    # confirm the consuming decoder checks it.
    compress_ratios: list[int] = Field(
        default_factory=list,
        description="Layer-by-layer compressor rates (0: standard, 4: CSA, 128: HCA).",
    )
    # Indexer sub-module: scores tokens and selects the top-k to attend over.
    index_head_dim: int = Field(128, description="Head dim for indexer query and key.")
    index_n_heads: int = Field(64, description="Number of query heads in indexer.")
    index_topk: int = Field(512, description="Number of tokens selected by indexer.")
    # Grouped low-rank output projection: o_lora_rank is the low-rank dim,
    # o_groups the number of partitions it is mixed across.
    o_groups: int = Field(8, description="Number of group partitions for grouped linear output projection.")
    o_lora_rank: int = Field(1024, description="Low-rank output dimension prior to grouped mix projection.")
    sliding_window: int = Field(128, description="Sliding window size for attention.")


class Llama4Attention(BaseModel):
"""Configuration specific to Llama4-style models."""

Expand Down Expand Up @@ -1315,6 +1332,7 @@ class ManifoldConstrainedHyperConnections(BaseModel):

mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
hc_eps: float = Field(1e-6, description="The epsilon fallback value for numerical stability in mHC.")


class DilocoParams(BaseModel):
Expand Down Expand Up @@ -2159,6 +2177,7 @@ class MaxTextConfig(
MlaAttention,
MoBa,
AttentionIndexer,
DeepSeekV4AttentionConfig,
Llama4Attention,
SplashAttention,
PagedAttention,
Expand Down
Loading