[Fix] Muon optimizer per-expert orthogonalization for MoE models #1582
CyCle1024 wants to merge 4 commits into InternLM:main
Conversation
Fix Muon optimizer to apply orthogonalization per expert matrix instead of to the concatenated large matrix for MoE models.
Add guidelines for git commit messages and PR descriptions.
@claude plz edit title and PR message.
xtuner/v1/config/optim.py
Outdated
num_adamw = 0

# Get MoE config if available
num_experts = getattr(model, "n_routed_experts", 1) or 1
I believe it should be model.config? n_routed_experts is not an attribute of model per se.
- num_experts = getattr(model, "n_routed_experts", 1) or 1
+ num_experts = getattr(model.config, "n_routed_experts", 1) or 1
xtuner/v1/config/optim.py
Outdated
if is_moe_model:
    logger.info(
        f"Muon params: {(num_muon + num_muon_moe) // 1e6}M "
        f"(regular: {num_muon // 1e6}M, MoE expert: {num_muon_moe // 1e6}M), "
int // 1e6 might introduce precision loss and hence confusion here. For example, 9e5 // 1e6 would result in 0.0, and 5.99e6 // 1e6 would be 5.0.
f"{xxx / 1e6:.2f}" might be better here.
- f"(regular: {num_muon // 1e6}M, MoE expert: {num_muon_moe // 1e6}M), "
+ f"(regular: {num_muon / 1e6:.2f}M, MoE expert: {num_muon_moe / 1e6:.2f}M), "
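A quick self-contained illustration of the precision loss (the parameter counts are made up):

```python
# Floor division truncates before formatting, so sub-million counts
# collapse to 0 and 5.99M reads as 5.0M; true division with :.2f does not.
small, big = 900_000, 5_990_000

print(small // 1e6)          # 0.0
print(big // 1e6)            # 5.0
print(f"{big / 1e6:.2f}M")   # 5.99M
```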
xtuner/v1/optim/muon.py
Outdated
# For regular case (num_experts=1), this adds a batch dimension of size 1
M = X.size(-2) // num_experts
N = X.size(-1)
X = X.view(num_experts, M, N)
X.view(num_experts, -1, N) would suffice?
- X = X.view(num_experts, M, N)
+ X = X.view(num_experts, -1, N)
xtuner/v1/optim/muon.py
Outdated
# Transpose if needed (when rows > cols) for numerical stability in NS iteration
# This ensures X @ X.mT produces a smaller square matrix
need_transpose = G.size(-2) > G.size(-1)
Now that X has been reshaped, need_transpose should be conditioned on the shape of X instead of that of G?
- need_transpose = G.size(-2) > G.size(-1)
+ need_transpose = X.size(-2) > X.size(-1)
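To see why the condition must move from G to X, consider hypothetical w1/w3-like dimensions (4 experts used for brevity):

```python
num_experts, M, N = 4, 768, 2048
g_shape = (num_experts * M, N)   # concatenated gradient: (3072, 2048)
x_shape = (num_experts, M, N)    # after the per-expert view

# On G the concatenated rows dominate, so the check wrongly reports "tall":
wrong = g_shape[-2] > g_shape[-1]    # True: 3072 > 2048
# On X each expert matrix is wide, so no transpose is actually needed:
right = x_shape[-2] > x_shape[-1]    # False: 768 < 2048
print(wrong, right)
```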
X = X.view(num_experts, M, N)

# Transpose if rows > cols for numerical stability
need_transpose = G.size(-2) > G.size(-1)
The same concern from the non-Triton NS impl applies here.
- need_transpose = G.size(-2) > G.size(-1)
+ need_transpose = X.size(-2) > X.size(-1)
xtuner/v1/optim/muon.py
Outdated
return newton_schulz_func(X, epsilon=epsilon, num_experts=num_experts).reshape(original_shape)


def adjust_lr_rms_norm(lr, param_shape):
Since we split up the expert weights, the adjust_lr functions should also use the per-expert shape instead of the tensor shape. The current implementation would result in sqrt(max(2 * n_routed_experts * expert_dim, model_dim)) times base lr for fused_w1w3, which is way larger than sqrt(max(expert_dim, model_dim)).
def adjust_lr_rms_norm(lr, param_shape, num_experts=1):
    # Adjust learning rate for constant element-wise RMS norm
    # https://arxiv.org/abs/2502.16982
    A = param_shape[-2] // num_experts
    B = param_shape[-1]
    adjusted_ratio = 0.2 * math.sqrt(max(A, B))
    adjusted_lr = lr * adjusted_ratio
    return adjusted_lr

Commit message:
- Use model.config instead of model for n_routed_experts attribute
- Fix logging precision: use f-string formatting (:.2f) instead of // 1e6
- Simplify X.view(num_experts, -1, N) instead of computing M explicitly
- Fix need_transpose to use reshaped X shape instead of original G shape
- Fix adjust_lr functions to use per-expert shape (divide by num_experts)
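A back-of-the-envelope check of the magnitude gap described above, using the Qwen3-30B-A3B-like dimensions that appear in the tests (the exact numbers are illustrative):

```python
import math

n_routed_experts, expert_dim, model_dim = 128, 768, 2048

# Pre-fix: ratio computed from the fused w1w3 tensor shape
fused_rows = 2 * n_routed_experts * expert_dim          # 196608
ratio_fused = 0.2 * math.sqrt(max(fused_rows, model_dim))

# Post-fix: ratio computed from the per-expert shape
ratio_expert = 0.2 * math.sqrt(max(expert_dim, model_dim))

# The fused shape inflates the adjusted lr by sqrt(196608 / 2048) = sqrt(96)
print(ratio_fused / ratio_expert)
```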
Claude: Updated the branch with nil0x9's review comments. Changes made (commit

@claude review
num_adamw = 0

# Get MoE config if available
num_experts = getattr(model.config, "n_routed_experts", 1) or 1
Claude: Warning: model.config will raise AttributeError if the model object doesn't have a config attribute (e.g., non-HuggingFace models, or models wrapped in custom classes). Consider a safer access pattern:
- num_experts = getattr(model.config, "n_routed_experts", 1) or 1
+ num_experts = getattr(getattr(model, "config", None), "n_routed_experts", 1) or 1
This way if model lacks a config attribute, it gracefully falls back to 1 instead of crashing.
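A minimal sketch of the fallback behavior (the class names here are made up for illustration):

```python
class BareModel:                 # no config attribute at all
    pass

class MoEModel:
    class config:
        n_routed_experts = 128

def get_num_experts(model):
    # Inner getattr yields None when config is missing; the outer getattr
    # then falls back to 1, and `or 1` also catches None/0 values.
    return getattr(getattr(model, "config", None), "n_routed_experts", 1) or 1

print(get_num_experts(BareModel()))   # 1
print(get_num_experts(MoEModel()))    # 128
```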
for name, p in model.named_parameters():
    n = p.numel()
    num_total += n
    if name in trainable_names:
        num_total_requires_grad += n
        is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
        if is_muon_tensor:
            # Check if this is an MoE expert parameter
            if is_moe_model and any(pattern in name for pattern in all_expert_patterns):
                num_muon_moe += n
            else:
                num_muon += n
        else:
            num_adamw += n
    else:
Claude: Nit: The logic for classifying parameters into muon/adamw/moe groups iterates over model.named_parameters() twice — once here for counting, and again below (lines 105-117) for building the lists. Consider combining both loops to avoid the redundant iteration over all parameters.
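A hypothetical single-pass version of the counting/grouping logic (parameter objects are mocked as dicts; names, patterns, and the helper itself are illustrative, not the repo's actual code):

```python
def classify_params(named_params, trainable_names, expert_patterns):
    # One pass: accumulate counts and build the per-optimizer name lists together.
    counts = {"total": 0, "muon": 0, "muon_moe": 0, "adamw": 0}
    groups = {"muon": [], "muon_moe": [], "adamw": []}
    for name, p in named_params:
        counts["total"] += p["numel"]
        if name not in trainable_names:
            continue
        is_muon = p["ndim"] >= 2 and "embed_tokens" not in name and "lm_head" not in name
        if is_muon and any(pat in name for pat in expert_patterns):
            key = "muon_moe"
        elif is_muon:
            key = "muon"
        else:
            key = "adamw"
        counts[key] += p["numel"]
        groups[key].append(name)
    return counts, groups

params = [
    ("model.embed_tokens.weight", {"numel": 100, "ndim": 2}),
    ("layers.0.experts.fused_w1w3", {"numel": 400, "ndim": 2}),
    ("layers.0.attn.q_proj.weight", {"numel": 200, "ndim": 2}),
    ("layers.0.norm.weight", {"numel": 10, "ndim": 1}),
]
counts, groups = classify_params(params, {n for n, _ in params}, ["experts."])
print(counts)   # embed_tokens and the 1-D norm fall to adamw
```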
# Copyright (c) OpenMMLab. All rights reserved.
"""Test Muon optimizer Newton-Schulz functions with/without torch.compile.

Test shapes are based on Qwen3-30B-A3B model config:
- hidden_size: 2048
- num_experts: 128
- moe_intermediate_size: 768
- intermediate_size: 6144 (for shared expert)

MoE expert weight shapes:
- w1/w3: (num_experts * moe_intermediate_size, hidden_size) = (98304, 2048)
  per expert: (768, 2048)
- w2: (hidden_size, num_experts * moe_intermediate_size) = (2048, 98304)
  per expert: (2048, 768)

For testing, we use scaled-down versions to keep tests fast while maintaining
representative shapes.

================================================================================
IMPORTANT: DTensor Compatibility Note
================================================================================

The zeropower_via_newtonschulz5 function supports DTensor input, but with a
known limitation when M > N (e.g., w2 weights where hidden_size > moe_intermediate_size).

Root Cause Analysis (verified by /tmp/test_dtensor_root_cause_detailed.py):
---------------------------------------------------------------------------
When M > N, the Newton-Schulz algorithm transposes the input matrix:
    X = G.view(1, M, N).mT  # becomes (1, N, M)

For a DTensor sharded on dim 0 (M dimension):
1. After view(1, M, N): placements become Shard(dim=1)
2. After mT: placements become Shard(dim=2)  # the M dimension moves to dim 2
3. X @ X.mT produces Partial(sum) DTensor  # contraction dim is sharded
4. Partial values are not correctly reduced in subsequent operations
5. Error accumulates across 5 Newton-Schulz iterations:
   Iter 1: X max ~0.016
   Iter 2: X max ~0.060
   Iter 3: X max ~0.099
   Iter 4: X max ~0.29
   Iter 5: X max ~47.5 (EXPLOSION!)
6. Final result is completely wrong (e.g., 0.1 -> 47.5)

Verification Results:
- M < N (w1/w3): ✓ PASS - A @ A.mT produces Shard(dim=1), results match exactly
- M > N (w2): ✗ FAIL - A @ A.mT produces Partial(sum), results explode
- M = N (square): ✓ PASS - A @ A.mT produces Shard(dim=1), results match exactly

Workaround:
For DTensor with M > N (w2 weights), convert to local tensor:
    result = zeropower_via_newtonschulz5(dtensor.to_local(), num_experts=1)

Note:
This is NOT a torch.compile issue. The same problem occurs with or without
torch.compile. It's a fundamental limitation of DTensor's Partial placement
handling in complex matrix operation chains.

newton_schulz_triton:
Does not support DTensor at all due to direct Triton kernel usage.
Must use .to_local() to convert before calling.
Claude: Warning: The module docstring contains ~40 lines of debug analysis (DTensor root cause, iteration-by-iteration error growth, references to /tmp/test_dtensor_root_cause_detailed.py). This reads like investigation notes rather than test documentation.
Consider trimming this to a brief summary of what's tested and why, with a link to an issue or doc if the DTensor limitation needs to be tracked. The reference to /tmp/test_dtensor_root_cause_detailed.py is particularly problematic since that file doesn't exist in the repo.
    G, epsilon=self.epsilon, num_experts=num_experts
)

max_diff = (result1 - result2).abs().max().item()
Claude: Nit: The tolerance was relaxed from 2e-2 to 3e-2 without explanation. The larger test matrix sizes likely explain this, but a brief comment on why would help (e.g., "Larger matrices amplify numerical differences between PyTorch and Triton kernels").
@@ -0,0 +1,59 @@
# AGENTS.md - Project Guidelines for AI Assistants
Claude: Warning: Per project guidelines ("One logical change per PR. Do not mix bug fixes with features or refactors"), this documentation file should be in a separate PR. Also, this repo already has .claude/CLAUDE.md with project conventions — having a second set of guidelines in AGENTS.md at the repo root creates a risk of divergence. Consider whether this is needed at all, or if it should extend the existing CLAUDE.md.
Claude: ## Summary
Fixes Muon optimizer to apply Newton-Schulz orthogonalization per-expert instead of on the concatenated matrix for MoE models. The core approach — reshaping to

Issues
Warning
Nit

Verdict
COMMENT — The core algorithmic fix is correct and the
Summary
Fix the Muon optimizer to apply Newton-Schulz orthogonalization per expert matrix instead of on the concatenated large matrix for MoE models. Previously, all expert weights were concatenated into a single matrix before orthogonalization, which is mathematically incorrect — each expert should be orthogonalized independently with its own spectral norm.
Changes
- Add a num_experts parameter to zeropower_via_newtonschulz5 and newton_schulz_triton to support per-expert orthogonalization via batch matrix operations.
- Split parameter groups in optim.py into regular and MoE expert params, passing num_experts through the optimizer pipeline.
- Add tests comparing results with/without torch.compile and between PyTorch/Triton implementations.
Root Cause
The Muon optimizer treated MoE expert weight matrices as a single concatenated matrix during Newton-Schulz orthogonalization. This meant the spectral norm was computed over all experts jointly, and the iterative orthogonalization mixed information across experts — producing incorrect gradient updates for MoE models.
Fix
Reshape the concatenated expert matrix from (num_experts * M, N) to (num_experts, M, N) and use batch matrix multiplication to orthogonalize each expert independently. For non-MoE models (num_experts=1), this is mathematically equivalent to the original code path.
Test Plan
Added tests/optim/test_muon_compile.py with tests for:
- zeropower_via_newtonschulz5 with and without torch.compile for both regular and MoE matrices.
- newton_schulz_triton basic correctness (shape, no NaN/Inf).
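The reshape-and-batch-matmul approach described in the Fix section can be sketched as follows. This is a minimal illustration, not the repo's exact implementation: the function name and step/epsilon defaults are assumptions, while the quintic coefficients are the standard Muon Newton-Schulz values.

```python
import torch

def newton_schulz_per_expert(G, num_experts=1, steps=5, eps=1e-7):
    # Sketch of per-expert Newton-Schulz orthogonalization via batched matmul.
    rows, cols = G.shape
    X = G.reshape(num_experts, -1, cols)        # (E, M, N): one matrix per expert
    need_transpose = X.size(-2) > X.size(-1)    # condition on the reshaped X, not G
    if need_transpose:
        X = X.mT
    # Normalize each expert by its own Frobenius norm, not the joint one
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
    a, b, c = 3.4445, -4.7750, 2.0315           # standard Muon quintic coefficients
    for _ in range(steps):
        A = X @ X.mT                            # batched: experts stay independent
        X = a * X + (b * A + c * A @ A) @ X
    if need_transpose:
        X = X.mT
    return X.reshape(rows, cols)

torch.manual_seed(0)
G = torch.randn(4 * 8, 16)                      # 4 experts, each 8x16
O = newton_schulz_per_expert(G, num_experts=4)
```

With num_experts=1 this reduces to the single-matrix code path; with num_experts > 1 each 8x16 block of O is driven toward semi-orthogonality on its own, which is the point of the fix.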