
Split logical names in moe module #3473

Open

NuojCheng wants to merge 1 commit into main from chengnuojin-separate-moe

Conversation

@NuojCheng
Collaborator

@NuojCheng NuojCheng commented Mar 20, 2026

Description

Split the following logical names (a short sketch of what this enables follows the list):

  • embed_no_exp into: embed_no_exp, embed_moe
  • activation_embed into: activation_embed, activation_embed_moe
  • activation_norm_length into: activation_norm_length, activation_norm_length_moe
  • activation_length_no_exp into: activation_length_no_exp, activation_length_no_exp_moe
  • activation_batch into: activation_batch, activation_batch_moe
  • activation_batch_no_exp into: activation_batch_no_exp, activation_batch_no_exp_moe
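
For context, a minimal sketch of what the split enables, using Flax's logical axis rules. The rule contents below are assumptions modeled on the rule list quoted later in this thread (MaxText's real rule tables live in its yml configs); the point is that a split-off MoE name can pick up the expert mesh axis without touching the dense-layer rules.

```python
import flax.linen as nn

# Assumed rule contents, modeled on the rules quoted in the review
# discussion below; MaxText's actual rule tables live in its yml configs.
rules = (
    ("embed_no_exp", ("fsdp", "sequence", "context")),
    ("embed_moe", ("fsdp", "sequence", "context", "expert")),  # MoE variant adds "expert"
    ("mlp", ("tensor",)),
)

with nn.logical_axis_rules(rules):
    # Dense layers keep the original logical name ...
    dense_spec = nn.logical_to_mesh_axes(("embed_no_exp", "mlp"))
    # ... while MoE layers use the split-off name, which can also shard over experts.
    moe_spec = nn.logical_to_mesh_axes(("embed_moe", "mlp"))

print(dense_spec)  # PartitionSpec(('fsdp', 'sequence', 'context'), 'tensor')
print(moe_spec)    # PartitionSpec(('fsdp', 'sequence', 'context', 'expert'), 'tensor')
```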

Tests

  1. Losses and performance match using model_name=deepseek3-test on a v5p-8 VM.
  2. vLLM test
     Diff: https://diff.googleplex.com/#key=q4eBkLazREaW
     Script: https://paste.googleplex.com/6329507825975296
  3. 2d-fsdp test
     Diff: https://diff.googleplex.com/#key=sZNNYqZFXNEu

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NuojCheng NuojCheng force-pushed the chengnuojin-separate-moe branch 2 times, most recently from e047c31 to 1259ecc on March 20, 2026 at 20:15
@codecov

codecov bot commented Mar 20, 2026

Codecov Report

❌ Patch coverage is 73.91304% with 6 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/layers/moe.py | 73.91% | 6 Missing ⚠️ |


@NuojCheng NuojCheng force-pushed the chengnuojin-separate-moe branch from 4793567 to 64cd454 on March 23, 2026 at 18:57
@NuojCheng NuojCheng force-pushed the chengnuojin-separate-moe branch from 64cd454 to 935ffb4 on March 23, 2026 at 20:20
```diff
 elif self.config.use_2d_fsdp_sharding:
-  self.wi_kernel_axes = ("embed_no_exp", "mlp", None)
-  self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+  self.wi_kernel_axes = ("embed_no_exp_moe", "mlp", None)
```
Collaborator

I think we don't need to update this for use_2d_fsdp_sharding? embed_no_exp_moe was not defined in the 2d-fsdp yml.

Collaborator Author

"embed_no_exp" shares same logical rule contents with "embed_no_exp_moe", see

['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_no_exp', ['fsdp', 'sequence', 'context']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['embed_moe', ['fsdp', 'sequence', 'context', 'expert']],
['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_no_exp_moe', ['fsdp', 'sequence', 'context']],
. It won't affect performance.
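
One way to sanity-check that claim (a hedged sketch using the same Flax helpers as the earlier example; the rule tuples are copied from the list above for one representative config, and the "mlp" rule is an assumed placeholder): both names should resolve to the identical PartitionSpec, so the compiled sharding, and hence performance, is unchanged.

```python
import flax.linen as nn

# Rule contents copied from the list above (one representative config);
# "mlp" -> "tensor" is an assumed placeholder for the second kernel axis.
rules = (
    ("embed_no_exp", ("fsdp", "sequence", "context")),
    ("embed_no_exp_moe", ("fsdp", "sequence", "context")),
    ("mlp", ("tensor",)),
)

with nn.logical_axis_rules(rules):
    old_spec = nn.logical_to_mesh_axes(("embed_no_exp", "mlp"))
    new_spec = nn.logical_to_mesh_axes(("embed_no_exp_moe", "mlp"))

# Identical specs mean identical sharding, so no performance change.
assert old_spec == new_spec
```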

Collaborator

@suexu1025 suexu1025 left a comment

need some tests before change for all cases.

@NuojCheng
Collaborator Author

> need some tests before change for all cases.

I added tests in the PR description. Could you elaborate on what tests you would like to see?

@NuojCheng NuojCheng force-pushed the chengnuojin-separate-moe branch from 935ffb4 to 79c778a on March 23, 2026 at 21:41
@NuojCheng
Collaborator Author

> need some tests before change for all cases.

I added a 2d-fsdp test in the PR description.
