[JAX] Add an MoE Block (Layer) that combines router, permutation, grouped GEMM and communication #2912

tdophung wants to merge 8 commits into NVIDIA:main
Conversation
Greptile Summary

This PR introduces a self-contained MoEBlock Flax-Linen module that combines the router, token permutation, grouped GEMM expert FFN, and expert-parallel communication into one layer.
Confidence Score: 4/5. Safe to merge for non-grouped-topk configs, but silently produces an incorrect auxiliary training objective for DeepSeek-style (num_groups/group_topk) models that also enable aux_loss_coeff: the aux-loss routing map uses a clean standard top-k instead of the actual grouped-topk routing, making tokens_per_expert inconsistent with the real routing decisions when num_groups/group_topk are set with aux_loss_coeff > 0. See transformer_engine/jax/flax/moe.py, specifically _compute_aux_loss, and the test_group_topk_deepseek test, which does not exercise the aux-loss path.

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Input as Input [B,S,H]
    participant Gate as _gate (einsum)
    participant Router as _route_topk (fused_topk)
    participant AuxLoss as _compute_aux_loss
    participant Perm as _global_permute
    participant A2A_Fwd as ragged_all_to_all (fwd)
    participant LocalPerm as local_permute_after_a2a
    participant FFN as _expert_ffn (grouped_dense x3)
    participant LocalUnperm as local_unpermute_before_a2a
    participant A2A_Rev as ragged_all_to_all (rev)
    participant Combine as _global_combine
    participant Output as Output [B,S,H]
    Input->>Gate: inputs [B,S,H]
    Gate->>Router: gate_logits [B,S,E]
    Router->>AuxLoss: logits_2d (aux branch, parallel)
    Router->>Perm: sparse_probs, routing_map
    Perm->>A2A_Fwd: sorted_inputs [T*k,H] + group_sizes [E]
    A2A_Fwd->>LocalPerm: x_recv [recv_buf, H]
    LocalPerm->>FFN: sorted_x, local_group_sizes [E_local]
    FFN->>LocalUnperm: expert_outputs [recv_buf, H]
    LocalUnperm->>A2A_Rev: x_send_back
    A2A_Rev->>Combine: y_back [T*k, H]
    Combine->>Output: output [B,S,H] + aux_loss
```
Reviews (3). Last reviewed commit: "address greptile comments".
Commits:
- …ody single GPU vs. multi GPU
- …e and single device initial params in the MoEBlock. Tests should pass now
- for more information, see https://pre-commit.ci
```python
    def _compute_aux_loss(
        self,
        logits_2d: jnp.ndarray,
    ) -> Optional[jnp.ndarray]:
        """Compute the MoE auxiliary load-balancing loss.

        The score-for-aux kernel has no data dependency on the main
        routing kernel, so XLA can overlap them on the GPU.

        ``logits_2d`` should be the *full* logits tensor over the global
        token batch -- under EP the caller is responsible for
        :func:`jax.lax.all_gather` ing the logits before calling this so
        the aux_loss formula
        ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])``
        sees the global ``T`` and the global ``tokens_per_expert``.
        """
        if self.aux_loss_coeff <= 0.0:
            return None
        aux_scores, aux_routing_map = fused_topk_with_score_function(
            logits_2d.astype(jnp.float32),
            topk=self.num_experts_per_tok,
            score_function=self.score_function,
            compute_aux_scores=True,
        )
        aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0)
        return fused_moe_aux_loss(
            aux_scores.astype(jnp.float32),
            aux_tokens_per_expert,
            topk=self.num_experts_per_tok,
            coeff=self.aux_loss_coeff,
        )
```
Aux loss tokens_per_expert is inconsistent with actual grouped-topk routing

When num_groups > 0 and group_topk > 0 (DeepSeek-style routing), fused_topk_with_score_function(..., compute_aux_scores=True) intentionally ignores those parameters and runs a clean standard top-k. The returned aux_routing_map therefore reflects different expert selections than the actual routing_map produced by _route_topk, causing aux_tokens_per_expert = sum(aux_routing_map, axis=0) to count a different token-expert distribution. Any user who combines num_groups > 0 + group_topk > 0 + aux_loss_coeff > 0 silently trains with a wrong auxiliary objective. The existing test_group_topk_deepseek test does not catch this because it leaves aux_loss_coeff at its default of 0.0.
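One possible direction, sketched here rather than taken from the PR: reuse the routing decisions that `_route_topk` already produced so `tokens_per_expert` matches the grouped-topk selection. The extra `routing_map` parameter is an assumption about what the caller could pass in.

```python
# Hedged sketch, not the PR's code: count tokens per expert from the actual
# routing_map produced by _route_topk instead of a second, clean top-k.
def _compute_aux_loss(self, logits_2d, routing_map):
    if self.aux_loss_coeff <= 0.0:
        return None
    # Scores can still come from the dedicated aux kernel ...
    aux_scores, _ = fused_topk_with_score_function(
        logits_2d.astype(jnp.float32),
        topk=self.num_experts_per_tok,
        score_function=self.score_function,
        compute_aux_scores=True,
    )
    # ... but the load statistics follow the grouped-topk routing the forward
    # pass actually uses, keeping the aux objective consistent with routing.
    aux_tokens_per_expert = jnp.sum(routing_map.astype(jnp.int32), axis=0)
    return fused_moe_aux_loss(
        aux_scores.astype(jnp.float32),
        aux_tokens_per_expert,
        topk=self.num_experts_per_tok,
        coeff=self.aux_loss_coeff,
    )
```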
```python
sharded_variables = jax.jit(sharded_block.init, out_shardings=out_shardings)(
    init_key, inputs
)
(sharded_loss, (sharded_output, sharded_aux)), sharded_grads = jax.value_and_grad(
```
SR: Do we want to jax.jit this function before calling it?
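For reference, a minimal sketch of the jitted variant the comment is asking about; `loss_fn` is assumed to be the function being differentiated in the snippet above:

```python
# Sketch only: jit the whole value_and_grad transform before calling it.
grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
(sharded_loss, (sharded_output, sharded_aux)), sharded_grads = grad_fn(
    sharded_variables, inputs
)
```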
```python
assert aux_loss is None, "aux_loss should be None when aux_loss_coeff=0"


@pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"])
def test_backward_grad(self, permutation_backend):
```
nit SR: rename to test_backward_grad_is_finite_and_nonzero or something similar to indicate this test doesn't compare with a reference impl
```python
for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
    g_pj = _unwrap_partitioned(grads_pj["params"][name])
    g_tr = _unwrap_partitioned(grads_tr["params"][name])
    assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), (
```
1e-1 is a pretty high tolerance for most of our tests. What atol and rtol error values do you typically get from these tests, and is that error difference expected between the JAX and Triton backends?
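A small sketch of how the test could report the observed error alongside the assertion, so the tolerance choice can be justified (uses the same `g_pj` / `g_tr` tensors as above):

```python
# Sketch: report the actual max absolute/relative error per parameter.
abs_err = jnp.max(jnp.abs(g_pj - g_tr))
rel_err = abs_err / (jnp.max(jnp.abs(g_tr)) + 1e-12)
print(f"{name}: max abs err = {abs_err:.3e}, max rel err = {rel_err:.3e}")
```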
```python
@pytest.mark.xfail(
    reason=(
        "TE grouped_dense FFI asserts sum(group_sizes) == M at "
```
This assertion is only for the V1 grouped GEMM; for the V2 grouped GEMM, sum(group_sizes) < M is supported. Can you try the following?
- Enforce the V2 grouped GEMM with NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1, similar to what we do in TransformerEngine/tests/jax/utils.py (line 1681 at 4b6923d).
- Remove the functools cache on this function so its result can change.
- Run this test within a try/except: if it runs, great; if the except catches a runtime error containing that string, then pytest.skip("V2 grouped gemm is not supported") (see the sketch below).
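A rough sketch of that pattern (the helper name `run_moe_block_forward`, the fixture usage, and the matched error string are assumptions, not existing repo code):

```python
import pytest


def test_uneven_group_sizes(self, monkeypatch):
    # Force the V2 grouped GEMM path, which allows sum(group_sizes) < M.
    monkeypatch.setenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "1")
    try:
        run_moe_block_forward()  # hypothetical helper running the block under test
    except RuntimeError as err:
        if "sum(group_sizes)" in str(err):
            pytest.skip("V2 grouped GEMM is not supported in this build")
        raise
```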
```
* Fused, Triton-backed ``token_dispatch`` / ``token_combine`` - uses the
  Triton kernels in ``transformer_engine.jax.triton_extensions.permutation``.
* Unfused, pure-JAX ``unfused_token_dispatch`` / ``unfused_token_combine`` -
```
I feel like "unfused" implies slower, when in practice this approach is faster, at least in MaxText. Would "triton" / "pure_jax", like we use in the tests, fit better? What do you think?
```python
inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes)

_, _, hidden_size = inputs.shape
params = self._make_params(hidden_size)
```
Do we need this "params" dictionary? With nn.compact we can define each param inline right before its usage, and nn.compact will handle collecting all params for us. Most of our other modules in module.py define params inline instead of in a dictionary upfront.
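For illustration, a minimal toy module showing the inline style (names and shapes here are illustrative, not the PR's actual gate code):

```python
import flax.linen as nn
import jax.numpy as jnp


class ToyGate(nn.Module):
    num_experts: int

    @nn.compact
    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        hidden_size = inputs.shape[-1]
        # Param is declared right where it is used; nn.compact collects it
        # into the module's variables automatically.
        gate_kernel = self.param(
            "gate_kernel",
            nn.initializers.lecun_normal(),
            (hidden_size, self.num_experts),
        )
        return jnp.einsum("bsh,he->bse", inputs, gate_kernel)
```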
```python
# Gate
# ------------------------------------------------------------------

def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray:
```
Is this gating projection GEMM something we should quantize?
```python
    inputs_2d: jnp.ndarray,
    sparse_probs: jnp.ndarray,
    routing_map: jnp.ndarray,
) -> dict:
```
Can you replace the dict with a dataclass with the same fields?
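Something along these lines, using flax.struct so the container stays a pytree; the field names below are guesses based on the surrounding code and sequence diagram, not the actual keys of the dict:

```python
from flax import struct
import jax.numpy as jnp


@struct.dataclass
class PermuteResult:
    # Hypothetical fields; mirror whatever the current dict actually carries.
    sorted_inputs: jnp.ndarray   # [T*k, H] tokens grouped by destination expert
    group_sizes: jnp.ndarray     # [E] rows routed to each expert
    sort_indices: jnp.ndarray    # permutation reused later by _global_combine
```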
```python
sparse_probs, routing_map = self._route_topk(logits_2d, params.get("expert_bias"))
aux_loss = self._compute_aux_loss(logits_2d)
perm = self._global_permute(inputs_2d, sparse_probs, routing_map)
expert_outputs = self._expert_ffn(
```
Since the grouped quant and grouped GEMM do not have custom partitioning rules, I think having this outside of shard_map will either raise an error about missing partitioning rules or silently replicate.
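One hedged way to tell which of the two it is: inspect the compiled HLO for collectives (sketch; `loss_fn`, `sharded_variables`, and `inputs` are assumed from the test above, and the logits all_gather for the aux loss is expected to appear):

```python
# Sketch: count collectives in the compiled HLO to spot silent replication.
compiled = jax.jit(loss_fn).lower(sharded_variables, inputs).compile()
hlo_text = compiled.as_text()
print("all-gathers in HLO:", hlo_text.lower().count("all-gather"))
```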
```python
captured[name] = params[name]
in_specs[name] = P(ep_axis, None)


def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]:
```
This is a fairly large function inside another big function. Can we move it to an outer scope, or is it required to capture something?
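A sketch of hoisting the body to module scope and binding what it needs explicitly (names here are illustrative, not the PR's actual signature):

```python
from functools import partial


def _a2a_body(local: dict, *, ep_axis: str, num_experts: int):
    """Module-level shard_map body: ragged A2A fwd, local permute,
    grouped-GEMM FFN, local unpermute, ragged A2A reverse."""
    ...  # same logic as the current nested _a2a_fn

# Inside MoEBlock.__call__ only the bindings would then stay local, e.g.:
#   a2a_fn = partial(_a2a_body, ep_axis=ep_axis, num_experts=self.num_experts)
```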
Description
Most of the MoE building-block integration work has been deeply coupled with MaxText development. Creating this MoE block isolates the work from MaxText and provides more room for experimentation.

`MoEBlock` is a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton `sort_chunks_by_index`), the `grouped_dense`-based expert FFN, and ragged all-to-all (A2Av) expert parallelism via `shard_map`. This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt grouped GEMM, and two permutation backends: pure JAX or Triton kernels. The block also exposes pluggable knobs for weight layout (`wi_kernel_axes` / `wo_kernel_axes`), permutation backend, A2A vs. no-EP (single-GPU) path, data-parallelism axes for true FSDP (batch sharded across `(ep, fsdp)` simultaneously), top-k with optional grouped/sigmoid scoring (for the DSv3 workload), and optional auxiliary load-balancing loss.

Fixes #2895
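For context, a rough single-device usage sketch; the constructor arguments below are inferred from the description and review comments and may not match the final signature exactly:

```python
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax.moe import MoEBlock

# Hypothetical configuration; argument names and values are assumptions.
block = MoEBlock(
    num_experts=8,
    num_experts_per_tok=2,
    permutation_backend="pure_jax",  # or "triton"
    aux_loss_coeff=0.0,
)

x = jnp.zeros((4, 128, 1024), jnp.bfloat16)       # [B, S, H]
variables = block.init(jax.random.PRNGKey(0), x)  # single-GPU (no-EP) path
output, aux_loss = block.apply(variables, x)      # aux_loss is None when coeff == 0
```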
Type of change
Changes

- `transformer_engine/jax/flax/moe.py` -- `MoEBlock` Linen module: gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine.
- `transformer_engine/jax/permutation.py` -- adds A2A param helpers (`compute_ragged_all_to_all_params`, `compute_reverse_ragged_all_to_all_params`, `local_permute_after_a2a`, `local_unpermute_before_a2a`) and the pure-JAX `unfused_token_dispatch` / `unfused_token_combine` paths with custom VJPs.
- `tests/jax/test_moe_block.py` -- single-device shape, backward, cross-backend equivalence, aux-loss, group-topk, and JIT-determinism tests.
- `tests/jax/test_distributed_moe_block.py` -- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...)) and `data_parallelism_axes=("fsdp",)` to exercise true FSDP (batch sharded across both axes).

Checklist: