[Fix] Muon optimizer per-expert orthogonalization for MoE models#1582

Open
CyCle1024 wants to merge 4 commits into InternLM:main from CyCle1024:muon_per_expert_optimizer

Conversation


@CyCle1024 CyCle1024 commented Mar 13, 2026

Summary

Fix the Muon optimizer to apply Newton-Schulz orthogonalization per expert matrix instead of on the concatenated large matrix for MoE models. Previously, all expert weights were concatenated into a single matrix before orthogonalization, which is mathematically incorrect — each expert should be orthogonalized independently with its own spectral norm.

Changes

  • Add num_experts parameter to zeropower_via_newtonschulz5 and newton_schulz_triton to support per-expert orthogonalization via batch matrix operations.
  • Separate Muon parameter groups in optim.py into regular and MoE expert params, passing num_experts through the optimizer pipeline.
  • Add tests verifying Newton-Schulz consistency with/without torch.compile and between PyTorch/Triton implementations.

Root Cause

The Muon optimizer treated MoE expert weight matrices as a single concatenated matrix during Newton-Schulz orthogonalization. This meant the spectral norm was computed over all experts jointly, and the iterative orthogonalization mixed information across experts — producing incorrect gradient updates for MoE models.

Fix

Reshape the concatenated expert matrix from (num_experts * M, N) to (num_experts, M, N) and use batch matrix multiplication to orthogonalize each expert independently. For non-MoE models (num_experts=1), this is mathematically equivalent to the original code path.
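The reshape-plus-batched-matmul approach can be sketched as follows. This is a minimal NumPy illustration written for this description, not the repo's actual muon.py code; the quintic coefficients are the commonly cited Muon Newton-Schulz constants and the function name is hypothetical:

```python
import numpy as np

# Quintic Newton-Schulz coefficients commonly used by Muon (assumed here).
NS_A, NS_B, NS_C = 3.4445, -4.7750, 2.0315

def ns5_per_expert(G, num_experts=1, eps=1e-7, steps=5):
    """Orthogonalize each expert slice of a stacked (num_experts * M, N)
    matrix independently, using batched matrix products."""
    N = G.shape[-1]
    X = G.reshape(num_experts, -1, N).astype(np.float64)
    # Normalize each expert by its OWN Frobenius norm, not the joint norm.
    X = X / (np.linalg.norm(X, axis=(-2, -1), keepdims=True) + eps)
    # Transpose each slice when rows > cols so X @ X^T is the small Gram matrix.
    transpose = X.shape[-2] > X.shape[-1]
    if transpose:
        X = np.swapaxes(X, -1, -2)
    for _ in range(steps):
        A = X @ np.swapaxes(X, -1, -2)   # batched (E, m, m) Gram matrices
        B = NS_B * A + NS_C * (A @ A)
        X = NS_A * X + B @ X
    if transpose:
        X = np.swapaxes(X, -1, -2)
    return X.reshape(G.shape)

rng = np.random.default_rng(0)
G = rng.standard_normal((4 * 8, 16))  # 4 experts, each expert matrix (8, 16)
batched = ns5_per_expert(G, num_experts=4)
looped = np.concatenate([ns5_per_expert(G[i * 8:(i + 1) * 8]) for i in range(4)])
assert np.allclose(batched, looped)  # batched path == per-expert loop
```

With num_experts=1 the reshape only adds a singleton batch dimension, so the computation reduces to the original single-matrix path.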

Test Plan

  • Added tests/optim/test_muon_compile.py with tests for:
    • zeropower_via_newtonschulz5 with and without torch.compile for both regular and MoE matrices.
    • newton_schulz_triton basic correctness (shape, no NaN/Inf).
    • Cross-function consistency between PyTorch and Triton implementations.
    • Transpose case (rows > cols) handling.

Fix Muon optimizer to apply orthogonalization per expert matrix instead of
to the concatenated large matrix for MoE models.
Add guidelines for git commit messages and PR descriptions.
@CyCle1024
Collaborator Author

@claude plz edit title and PR message.

@claude claude bot changed the title Muon per expert optimizer [Fix] Muon optimizer per-expert orthogonalization for MoE models Mar 13, 2026
num_adamw = 0

# Get MoE config if available
num_experts = getattr(model, "n_routed_experts", 1) or 1
Collaborator

I believe it should be model.config? n_routed_experts is not an attribute of model per se.

Collaborator Author

Suggested change
num_experts = getattr(model, "n_routed_experts", 1) or 1
num_experts = getattr(model.config, "n_routed_experts", 1) or 1

if is_moe_model:
logger.info(
f"Muon params: {(num_muon + num_muon_moe) // 1e6}M "
f"(regular: {num_muon // 1e6}M, MoE expert: {num_muon_moe // 1e6}M), "
Collaborator

int // 1e6 might introduce precision loss and hence confusion here. For example,

  • 9e5 // 1e6 would result in 0.0.
  • 5.99e6 // 1e6 would be 5.0.

f"{xxx / 1e6:.2f}" would be better here.
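To make the truncation concrete, a tiny standalone demo (the count is made up for illustration):

```python
num_muon = 5_990_000  # hypothetical parameter count
print(f"{num_muon // 1e6}M")     # "5.0M"  (floor division truncates)
print(f"{num_muon / 1e6:.2f}M")  # "5.99M" (true division, two decimals)
```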

Collaborator Author

Suggested change
f"(regular: {num_muon // 1e6}M, MoE expert: {num_muon_moe // 1e6}M), "
f"(regular: {num_muon // 1e6:.2f}M, MoE expert: {num_muon_moe // 1e6:.2f}M), "

# For regular case (num_experts=1), this adds a batch dimension of size 1
M = X.size(-2) // num_experts
N = X.size(-1)
X = X.view(num_experts, M, N)
Collaborator

X.view(num_experts, -1, N) would suffice?

Collaborator Author

Suggested change
X = X.view(num_experts, M, N)
X = X.view(num_experts, -1, N)


# Transpose if needed (when rows > cols) for numerical stability in NS iteration
# This ensures X @ X.mT produces a smaller square matrix
need_transpose = G.size(-2) > G.size(-1)
Collaborator

Now that X has been reshaped, need_transpose should be conditioned on the shape of X instead of that of G?

Collaborator Author

Suggested change
need_transpose = G.size(-2) > G.size(-1)
need_transpose = X.size(-2) > X.size(-1)

X = X.view(num_experts, M, N)

# Transpose if rows > cols for numerical stability
need_transpose = G.size(-2) > G.size(-1)
Collaborator

The same concern of non-triton NS impl applies here.

Collaborator Author

Suggested change
need_transpose = G.size(-2) > G.size(-1)
need_transpose = X.size(-2) > X.size(-1)

return newton_schulz_func(X, epsilon=epsilon, num_experts=num_experts).reshape(original_shape)


def adjust_lr_rms_norm(lr, param_shape):
Collaborator

Since we split up the expert weights, the adjust_lr functions should also use the per-expert shape instead of the full tensor shape. The current implementation would scale the base lr by sqrt(max(2 * n_routed_experts * expert_dim, model_dim)) for fused_w1w3, which is far larger than sqrt(max(expert_dim, model_dim)).

Collaborator Author

@CyCle1024 CyCle1024 Mar 16, 2026

```python
def adjust_lr_rms_norm(lr, param_shape, num_experts=1):
    # Adjust learning rate for constant element-wise RMS norm
    # https://arxiv.org/abs/2502.16982
    A = param_shape[-2] // num_experts  # per-expert row count
    B = param_shape[-1]
    adjusted_ratio = 0.2 * math.sqrt(max(A, B))
    adjusted_lr = lr * adjusted_ratio
    return adjusted_lr
```

@HAOCHENYE
Collaborator

@claude please update the branch following nil0x9's comments:

- Use model.config instead of model for n_routed_experts attribute
- Fix logging precision: use f-string formatting (:.2f) instead of // 1e6
- Simplify X.view(num_experts, -1, N) instead of computing M explicitly
- Fix need_transpose to use reshaped X shape instead of original G shape
- Fix adjust_lr functions to use per-expert shape (divide by num_experts)

claude bot commented Mar 16, 2026

Claude: Updated the branch with nil0x9's review comments.

Changes made (commit 72e2a42):

  1. optim.py: Use model.config instead of model for n_routed_experts attribute
  2. optim.py: Fix logging precision — use f"{x / 1e6:.2f}" instead of x // 1e6 to avoid precision loss
  3. muon.py: Simplify X.view(num_experts, -1, N) instead of computing M explicitly
  4. muon.py: Fix need_transpose to use reshaped X shape instead of original G shape
  5. newton_schulz_triton.py: Same X.view simplification and need_transpose fix as muon.py
  6. muon.py: Fix adjust_lr_rms_norm and adjust_lr_spectral_norm to use per-expert shape (divide by num_experts) so MoE expert params get correct learning rate scaling

@HAOCHENYE
Collaborator

@claude review

num_adamw = 0

# Get MoE config if available
num_experts = getattr(model.config, "n_routed_experts", 1) or 1

Claude: Warning: model.config will raise AttributeError if the model object doesn't have a config attribute (e.g., non-HuggingFace models, or models wrapped in custom classes). Consider a safer access pattern:

Suggested change
num_experts = getattr(model.config, "n_routed_experts", 1) or 1
num_experts = getattr(getattr(model, "config", None), "n_routed_experts", 1) or 1

This way if model lacks a config attribute, it gracefully falls back to 1 instead of crashing.
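A quick sanity check of that fallback behavior (the Plain class is a hypothetical stand-in for a model without a config attribute):

```python
class Plain:  # hypothetical model object with no `config` attribute
    pass

model = Plain()
# model.config would raise AttributeError; the nested getattr returns
# None for the missing attribute, then the inner default kicks in.
num_experts = getattr(getattr(model, "config", None), "n_routed_experts", 1) or 1
print(num_experts)  # 1
```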

Comment on lines +86 to 101

for name, p in model.named_parameters():
n = p.numel()
num_total += n
if name in trainable_names:
num_total_requires_grad += n
is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
if is_muon_tensor:
num_muon += n
# Check if this is an MoE expert parameter
if is_moe_model and any(pattern in name for pattern in all_expert_patterns):
num_muon_moe += n
else:
num_muon += n
else:
num_adamw += n
else:

Claude: Nit: The logic for classifying parameters into muon/adamw/moe groups iterates over model.named_parameters() twice — once here for counting, and again below (lines 105-117) for building the lists. Consider combining both loops to avoid the redundant iteration over all parameters.

Comment on lines +1 to +60
# Copyright (c) OpenMMLab. All rights reserved.
"""Test Muon optimizer Newton-Schulz functions with/without torch.compile.

Test shapes are based on Qwen3-30B-A3B model config:
- hidden_size: 2048
- num_experts: 128
- moe_intermediate_size: 768
- intermediate_size: 6144 (for shared expert)

MoE expert weight shapes:
- w1/w3: (num_experts * moe_intermediate_size, hidden_size) = (98304, 2048)
per expert: (768, 2048)
- w2: (hidden_size, num_experts * moe_intermediate_size) = (2048, 98304)
per expert: (2048, 768)

For testing, we use scaled-down versions to keep tests fast while maintaining
representative shapes.

================================================================================
IMPORTANT: DTensor Compatibility Note
================================================================================

The zeropower_via_newtonschulz5 function supports DTensor input, but with a
known limitation when M > N (e.g., w2 weights where hidden_size > moe_intermediate_size).

Root Cause Analysis (verified by /tmp/test_dtensor_root_cause_detailed.py):
---------------------------------------------------------------------------
When M > N, the Newton-Schulz algorithm transposes the input matrix:
X = G.view(1, M, N).mT # becomes (1, N, M)

For a DTensor sharded on dim 0 (M dimension):
1. After view(1, M, N): placements become Shard(dim=1)
2. After mT: placements become Shard(dim=2) # the M dimension moves to dim 2
3. X @ X.mT produces Partial(sum) DTensor # contraction dim is sharded
4. Partial values are not correctly reduced in subsequent operations
5. Error accumulates across 5 Newton-Schulz iterations:
Iter 1: X max ~0.016
Iter 2: X max ~0.060
Iter 3: X max ~0.099
Iter 4: X max ~0.29
Iter 5: X max ~47.5 (EXPLOSION!)
6. Final result is completely wrong (e.g., 0.1 -> 47.5)

Verification Results:
- M < N (w1/w3): ✓ PASS - A @ A.mT produces Shard(dim=1), results match exactly
- M > N (w2): ✗ FAIL - A @ A.mT produces Partial(sum), results explode
- M = N (square): ✓ PASS - A @ A.mT produces Shard(dim=1), results match exactly

Workaround:
For DTensor with M > N (w2 weights), convert to local tensor:
result = zeropower_via_newtonschulz5(dtensor.to_local(), num_experts=1)

Note:
This is NOT a torch.compile issue. The same problem occurs with or without
torch.compile. It's a fundamental limitation of DTensor's Partial placement
handling in complex matrix operation chains.

newton_schulz_triton:
Does not support DTensor at all due to direct Triton kernel usage.
Must use .to_local() to convert before calling.

Claude: Warning: The module docstring contains ~40 lines of debug analysis (DTensor root cause, iteration-by-iteration error growth, references to /tmp/test_dtensor_root_cause_detailed.py). This reads like investigation notes rather than test documentation.

Consider trimming this to a brief summary of what's tested and why, with a link to an issue or doc if the DTensor limitation needs to be tracked. The reference to /tmp/test_dtensor_root_cause_detailed.py is particularly problematic since that file doesn't exist in the repo.

G, epsilon=self.epsilon, num_experts=num_experts
)

max_diff = (result1 - result2).abs().max().item()

Claude: Nit: The tolerance was relaxed from 2e-2 to 3e-2 without explanation. The larger test matrix sizes likely explain this, but a brief comment on why would help (e.g., "Larger matrices amplify numerical differences between PyTorch and Triton kernels").

@@ -0,0 +1,59 @@
# AGENTS.md - Project Guidelines for AI Assistants

Claude: Warning: Per project guidelines ("One logical change per PR. Do not mix bug fixes with features or refactors"), this documentation file should be in a separate PR. Also, this repo already has .claude/CLAUDE.md with project conventions — having a second set of guidelines in AGENTS.md at the repo root creates a risk of divergence. Consider whether this is needed at all, or if it should extend the existing CLAUDE.md.


claude bot commented Mar 16, 2026

Claude: Summary

Fixes Muon optimizer to apply Newton-Schulz orthogonalization per-expert instead of on the concatenated matrix for MoE models. The core approach — reshaping to (num_experts, M, N) and using batch matrix multiplication — is mathematically sound and well-implemented.

Issues

Warning

  • xtuner/v1/config/optim.py:78: model.config access is unsafe — will AttributeError if model has no config attribute. Use double getattr for safety.
  • tests/optim/test_muon_compile.py:1-60: Module docstring contains ~40 lines of debug investigation notes and references a non-existent /tmp/ file. Should be trimmed to concise test documentation.
  • AGENTS.md: Unrelated docs change mixed into a bug fix PR. Per project guidelines, should be a separate PR. Also risks diverging from existing .claude/CLAUDE.md.

Nit

Verdict

COMMENT — The core algorithmic fix is correct and the fused_w1w3 handling with 2 * num_experts is a nice catch. The need_transpose and adjust_lr fixes from the latest commit are important correctness improvements. Main concerns are the unsafe model.config access and the unrelated AGENTS.md file that should be split out.
