Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -476,14 +476,13 @@ logical_axis_rules: [
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['embed', ['fsdp', 'sequence', 'context', 'expert']],
['embed', ['expert']], # Instead of using pure FSDP (both FSDP and EP act like FSDP during attn), we replace the FSDP by DP to shard less, otherwise we would be sharding too much (e.g. 512+ ways, and have small shard shape). We use EP since EP is 2D in the target config
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_no_exp', ['fsdp', 'sequence', 'context']],
# For a full solution, embed_no_exp should be renamed to embed_moe.
# That may also require remove_fsdp functionality, though.
['embed_tensor_transpose', ['tensor_transpose']],
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
Expand Down Expand Up @@ -522,6 +521,7 @@ logical_axis_rules: [
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
embed_shard: "expert_only" # Choose how to shard embed (embed_attn): on both fsdp and expert ("both"), on expert only ("expert_only"), or on fsdp only ("fsdp_only")

# sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters.
sharding_tolerance: 0.02
Expand Down
16 changes: 16 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,7 @@ class LayoutAndSharding(BaseModel):
shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer sharding over the data axis.")
internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.")
internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.")
embed_shard: str = Field("expert_only", description="Which axes to shard embed (embed_attention) on")


class DcnParallelism(BaseModel):
Expand Down Expand Up @@ -2278,6 +2279,21 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
if self.expert_shard_attention_option == "context":
cp_size *= self.ici_expert_parallelism * self.dcn_expert_parallelism
self.context_parallel_size = cp_size

# Modify the embed sharding rule - this is a VERY hacky (non-mergeable) implementation, to be replaced with a proper mechanism for sharing logical axis rules.
for rule in self.logical_axis_rules:
if rule and rule[0] == "embed":
if self.embed_shard == "expert_only":
rule[1] = ["expert"]
elif self.embed_shard == "fsdp_only":
rule[1] = ["fsdp"]
elif self.embed_shard == "both":
rule[1] = ["fsdp", "expert"]
else:
# throw value error
raise ValueError(f"Invalid embed_shard: {self.embed_shard}. Must be 'expert_only', 'fsdp_only', or 'both'.")
break

if self.pipeline_parallel_layers == -1:
if self.decoder_block == DecoderBlockType.DEEPSEEK:
moe_layers = self.num_decoder_layers - self.first_num_dense_layers
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,7 @@ def gmm(
w1_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
else:
# embed_tensor_transpose here is unusual, but it does not include FSDP, so we all-gather (AG) over the FSDP axis.
w0_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
w1_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
wo_pspec = self._logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
Expand Down
Loading