Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,16 @@ qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper)
fused_qkv: False
fused_mlp: False

# DeepSeek-V4 Compressed Attention parameters
# These defaults mirror the DeepSeekV4AttentionConfig pydantic model in
# configs/types.py; keep the two in sync when changing any value.
# RoPE theta base used by long-range compressor layers.
compress_rope_theta: 160000.0
# Per-layer compressor rates; empty list means no compressed layers.
# NOTE(review): values 0 / 4 / 128 map to standard / CSA / HCA — confirm
# against the attention implementation.
compress_ratios: []
# Indexer: head dim for query/key, number of query heads, and how many
# tokens the indexer selects.
index_head_dim: 128
index_n_heads: 64
index_topk: 512
# Grouped output projection: number of group partitions and the low-rank
# dimension applied before the grouped mix projection.
o_groups: 8
o_lora_rank: 1024
# Sliding window size for attention.
sliding_window: 128

record_internal_nn_metrics: 0

# Output directory
Expand Down
17 changes: 17 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,22 @@ class AttentionIndexer(BaseModel):
indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.")


class DeepSeekV4AttentionConfig(BaseModel):
    """Configuration specific to DeepSeek-V4 stateless compressed attention layers.

    Field defaults mirror the values added to ``configs/base.yml``
    (compress_rope_theta through sliding_window); keep the two in sync.
    This model is mixed into ``MaxTextConfig`` alongside the other
    attention-variant configs (e.g. ``AttentionIndexer``).
    """

    # RoPE theta used by compressor layers only; standard layers presumably
    # keep the model-wide rope theta -- NOTE(review): confirm in the layer code.
    compress_rope_theta: float = Field(160000.0, description="Theta base frequency for long-range compressor layers.")
    # One entry per layer; empty (the default) disables compression everywhere.
    compress_ratios: list[int] = Field(
        default_factory=list,
        description="Layer-by-layer compressor rates (0: standard, 4: CSA, 128: HCA).",
    )
    # Indexer sub-module dimensions: q/k head dim, query head count, and the
    # number of tokens the indexer keeps (top-k selection).
    index_head_dim: int = Field(128, description="Head dim for indexer query and key.")
    index_n_heads: int = Field(64, description="Number of query heads in indexer.")
    index_topk: int = Field(512, description="Number of tokens selected by indexer.")
    # Grouped low-rank output projection: partition count and the low-rank
    # dimension applied before the grouped mix projection.
    o_groups: int = Field(8, description="Number of group partitions for grouped linear output projection.")
    o_lora_rank: int = Field(1024, description="Low-rank output dimension prior to grouped mix projection.")
    sliding_window: int = Field(128, description="Sliding window size for attention.")


class Llama4Attention(BaseModel):
"""Configuration specific to Llama4-style models."""

Expand Down Expand Up @@ -2159,6 +2175,7 @@ class MaxTextConfig(
MlaAttention,
MoBa,
AttentionIndexer,
DeepSeekV4AttentionConfig,
Llama4Attention,
SplashAttention,
PagedAttention,
Expand Down
Loading
Loading