[JAX] Add an MoE Block (Layer) that composes router, permutation, grouped GEMM and communication #2912

Open

tdophung wants to merge 8 commits into NVIDIA:main from tdophung:teddy/moe_block

Conversation

@tdophung (Collaborator) commented Apr 21, 2026

Description

Most of the MoE building-block integration work so far has been deeply coupled with MaxText development. This MoE block isolates that work from MaxText and provides more room for experimentation. MoEBlock is a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton sort_chunks_by_index), a grouped_dense-based expert FFN, and ragged all-to-all (A2Av) expert parallelism via shard_map.

This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt grouped GEMM, and two permutation backends: pure JAX or Triton kernels. The block also exposes pluggable knobs for weight layout (wi_kernel_axes/wo_kernel_axes), permutation backend, A2A vs. no-EP (single-GPU) path, data-parallelism axes for true FSDP (batch sharded across (ep, fsdp) simultaneously), top-K with optional grouped/sigmoid scoring (for the DSv3 workload), and an optional auxiliary load-balancing loss.
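For orientation, a minimal single-GPU (no-EP) usage sketch; the constructor fields are inferred from names mentioned in this PR (num_experts is an assumption) and may not match the final API:

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.flax import MoEBlock

    # Field names taken from this PR's discussion; values are illustrative.
    block = MoEBlock(
        num_experts=8,                   # assumed field name
        num_experts_per_tok=2,           # top-K
        score_function="sigmoid",        # grouped/sigmoid scoring is optional
        permutation_backend="pure_jax",  # or "triton"
        aux_loss_coeff=0.0,              # > 0 enables the auxiliary load-balancing loss
    )

    x = jnp.zeros((4, 128, 1024), jnp.bfloat16)   # [B, S, H]
    variables = block.init(jax.random.PRNGKey(0), x)
    output, aux_loss = block.apply(variables, x)  # aux_loss is None when coeff == 0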

Fixes #2895

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • New transformer_engine/jax/flax/moe.py -- MoEBlock Linen module:
    gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine.
  • Extended transformer_engine/jax/permutation.py with A2A param helpers (compute_ragged_all_to_all_params, compute_reverse_ragged_all_to_all_params, local_permute_after_a2a, local_unpermute_before_a2a) and the pure-JAX unfused_token_dispatch / unfused_token_combine paths
    with custom VJPs.
  • tests/jax/test_moe_block.py -- single-device shape, backward, cross-backend-equivalence, aux-loss, group-topk, and JIT-determinism tests.
  • tests/jax/test_distributed_moe_block.py -- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...), sketched below) and data_parallelism_axes=("fsdp",) to exercise true FSDP (batch sharded across both axes).
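A condensed sketch of that sharded-init pattern, assuming sharded_block, init_key, inputs, mesh, and logical_axis_rules come from the test setup:

    import jax
    import flax.linen as nn

    # eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, ...)
    abstract_variables = jax.eval_shape(sharded_block.init, init_key, inputs)
    pspecs = nn.get_partition_spec(abstract_variables)
    out_shardings = nn.logical_to_mesh_sharding(pspecs, mesh, logical_axis_rules)
    sharded_variables = jax.jit(sharded_block.init, out_shardings=out_shardings)(
        init_key, inputs
    )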

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@tdophung tdophung marked this pull request as ready for review May 5, 2026 21:47
greptile-apps Bot (Contributor) commented May 5, 2026

Greptile Summary

This PR introduces MoEBlock, a new self-contained Flax Linen layer that wires together TE's fused router, two permutation backends (pure-JAX argsort and Triton), grouped-GEMM expert FFN, and optional ragged-all-to-all expert parallelism via jax.shard_map. It also adds the unfused_token_dispatch/unfused_token_combine family, UnfusedPermState, and the compute_ragged_all_to_all_params helpers needed for the A2A EP path.

  • New MoEBlock: Decomposes the MoE forward into _route, _global_permute, _expert_ffn, _global_combine; the _forward_a2a_ep variant wraps the body in shard_map and inserts forward/reverse ragged_all_to_all + local permute around the FFN.
  • Unfused pure-JAX backend: unfused_token_dispatch/unfused_token_combine use jnp.argsort-based gather with a custom VJP; align_size > 0 is implemented but gated behind xfail tests.
  • EP helpers: compute_ragged_all_to_all_params and compute_reverse_ragged_all_to_all_params translate the gathered [num_ep, num_experts] token-count matrix into the four ragged_all_to_all offset/size arrays (see the sketch after this list).
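To make that translation concrete, here is an illustrative pure-JAX sketch of the idea (not the PR's actual compute_ragged_all_to_all_params), assuming each EP rank owns a contiguous slice of experts:

    import jax.numpy as jnp

    def a2a_params_from_counts(counts: jnp.ndarray, my_rank: int):
        """counts: [num_ep, num_experts] global token counts per (sender, expert)."""
        num_ep, num_experts = counts.shape
        per_rank = num_experts // num_ep
        # send_matrix[r, p]: tokens rank r sends to rank p (p owns a contiguous
        # block of per_rank experts).
        send_matrix = counts.reshape(num_ep, num_ep, per_rank).sum(axis=-1)
        send_sizes = send_matrix[my_rank]     # chunk sizes we send to each peer
        recv_sizes = send_matrix[:, my_rank]  # chunk sizes each peer sends to us
        # Offsets of our outgoing chunks in the locally sorted send buffer
        # (exclusive cumulative sum).
        input_offsets = jnp.cumsum(send_sizes) - send_sizes
        # Offset in each peer's receive buffer where our chunk lands: everything
        # that lower-ranked senders deliver to that peer comes first.
        output_offsets = jnp.cumsum(send_matrix, axis=0)[my_rank] - send_sizes
        return input_offsets, send_sizes, output_offsets, recv_sizes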

Confidence Score: 4/5

Safe to merge for non-grouped-topk configs, but the block silently produces an incorrect auxiliary training objective for DeepSeek-style (num_groups/group_topk) models that also enable aux_loss_coeff.

The aux loss routing map uses a clean standard top-k instead of the actual grouped-topk routing, making tokens_per_expert inconsistent with real routing decisions when num_groups/group_topk are set with aux_loss_coeff > 0.

transformer_engine/jax/flax/moe.py -- specifically _compute_aux_loss and the test_group_topk_deepseek test, which does not exercise the aux-loss path.

Important Files Changed

  • transformer_engine/jax/flax/moe.py -- new 974-line MoEBlock implementing the no-EP and ragged-A2A-EP paths; the aux-loss routing map is inconsistent with the actual grouped-topk routing when num_groups/group_topk are configured
  • transformer_engine/jax/permutation.py -- adds pure-JAX unfused dispatch/combine, ragged-A2A parameter helpers, and local-permute utilities; logic is sound and well documented
  • tests/jax/test_moe_block.py -- good single-device coverage; the align_size > 0 xfail is intentional; missing a test combining num_groups + aux_loss_coeff that would expose the tokens_per_expert mismatch
  • tests/jax/test_distributed_moe_block.py -- single EP=2 x FSDP=2 test with gradient comparison; tolerances are wide but appropriate for bfloat16 A2A-EP
  • transformer_engine/jax/flax/__init__.py -- correctly exports MoEBlock to the public flax API

Sequence Diagram

sequenceDiagram
    participant Input as Input [B,S,H]
    participant Gate as _gate (einsum)
    participant Router as _route_topk (fused_topk)
    participant AuxLoss as _compute_aux_loss
    participant Perm as _global_permute
    participant A2A_Fwd as ragged_all_to_all (fwd)
    participant LocalPerm as local_permute_after_a2a
    participant FFN as _expert_ffn (grouped_dense x3)
    participant LocalUnperm as local_unpermute_before_a2a
    participant A2A_Rev as ragged_all_to_all (rev)
    participant Combine as _global_combine
    participant Output as Output [B,S,H]

    Input->>Gate: inputs [B,S,H]
    Gate->>Router: gate_logits [B,S,E]
    Router->>AuxLoss: logits_2d (aux branch, parallel)
    Router->>Perm: sparse_probs, routing_map
    Perm->>A2A_Fwd: sorted_inputs [T*k,H] + group_sizes [E]
    A2A_Fwd->>LocalPerm: x_recv [recv_buf, H]
    LocalPerm->>FFN: sorted_x, local_group_sizes [E_local]
    FFN->>LocalUnperm: expert_outputs [recv_buf, H]
    LocalUnperm->>A2A_Rev: x_send_back
    A2A_Rev->>Combine: y_back [T*k, H]
    Combine->>Output: output [B,S,H] + aux_loss

Reviews (3). Last reviewed commit: "address greptile comments"

Comment thread transformer_engine/pytorch/permutation.py
Comment thread transformer_engine/jax/flax/moe.py
tdophung added 6 commits May 5, 2026 16:35
Signed-off-by: tdophung <tdophung@nvidia.com>
…ody single GPU vs. multi GPU

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…e and single device initial params in the MoEBlock. Tests should pass now

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the teddy/moe_block branch from 8a838f3 to 6aeb491 Compare May 5, 2026 23:44
pre-commit-ci Bot and others added 2 commits May 5, 2026 23:45
Signed-off-by: tdophung <tdophung@nvidia.com>
Comment on lines +427 to +457
    def _compute_aux_loss(
        self,
        logits_2d: jnp.ndarray,
    ) -> Optional[jnp.ndarray]:
        """Compute the MoE auxiliary load-balancing loss.

        The score-for-aux kernel has no data dependency on the main
        routing kernel, so XLA can overlap them on the GPU.

        ``logits_2d`` should be the *full* logits tensor over the global
        token batch -- under EP the caller is responsible for
        :func:`jax.lax.all_gather` ing the logits before calling this so
        the aux_loss formula
        ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])``
        sees the global ``T`` and the global ``tokens_per_expert``.
        """
        if self.aux_loss_coeff <= 0.0:
            return None
        aux_scores, aux_routing_map = fused_topk_with_score_function(
            logits_2d.astype(jnp.float32),
            topk=self.num_experts_per_tok,
            score_function=self.score_function,
            compute_aux_scores=True,
        )
        aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0)
        return fused_moe_aux_loss(
            aux_scores.astype(jnp.float32),
            aux_tokens_per_expert,
            topk=self.num_experts_per_tok,
            coeff=self.aux_loss_coeff,
        )
P1 Aux loss tokens_per_expert is inconsistent with actual grouped-topk routing

When num_groups > 0 and group_topk > 0 (DeepSeek-style routing), fused_topk_with_score_function(..., compute_aux_scores=True) intentionally ignores those parameters and runs a clean standard top-k. The returned aux_routing_map therefore reflects different expert selections than the actual routing_map produced by _route_topk, causing aux_tokens_per_expert = sum(aux_routing_map, axis=0) to count a different token–expert distribution. Any user who combines num_groups > 0 + group_topk > 0 + aux_loss_coeff > 0 silently trains with a wrong auxiliary objective. The existing test_group_topk_deepseek test does not catch this because it leaves aux_loss_coeff at its default of 0.0.
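A hypothetical regression test along these lines (not part of the PR; constructor fields mirror names used elsewhere in this thread) would surface the mismatch:

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.flax import MoEBlock

    def test_group_topk_with_aux_loss():
        # Grouped-topk routing plus a nonzero aux loss: the combination the
        # existing test leaves uncovered.
        block = MoEBlock(
            num_experts=8,           # assumed field name
            num_experts_per_tok=2,
            num_groups=4,
            group_topk=2,
            score_function="sigmoid",
            aux_loss_coeff=0.01,     # nonzero, unlike test_group_topk_deepseek
        )
        x = jax.random.normal(jax.random.PRNGKey(1), (2, 16, 64), jnp.bfloat16)
        variables = block.init(jax.random.PRNGKey(0), x)
        _, aux_loss = block.apply(variables, x)
        # A correct fix would derive tokens_per_expert from the actual
        # grouped-topk routing_map; compare against that reference here.
        assert aux_loss is not None and jnp.isfinite(aux_loss)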

sharded_variables = jax.jit(sharded_block.init, out_shardings=out_shardings)(
    init_key, inputs
)
(sharded_loss, (sharded_output, sharded_aux)), sharded_grads = jax.value_and_grad(
SR: Do we want to jax.jit this function before calling it?
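For reference, the suggestion would look roughly like this, assuming loss_fn is the test's loss closure with an auxiliary output:

    # Jit the whole value-and-grad computation once, then call it.
    loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
    (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = loss_and_grad_fn(
        sharded_variables, inputs
    )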

assert aux_loss is None, "aux_loss should be None when aux_loss_coeff=0"

@pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"])
def test_backward_grad(self, permutation_backend):
nit SR: rename to test_backward_grad_is_finite_and_nonzero or something similar to indicate this test doesn't compare with a reference impl

for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
    g_pj = _unwrap_partitioned(grads_pj["params"][name])
    g_tr = _unwrap_partitioned(grads_tr["params"][name])
    assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), (
1e-1 is a pretty high tolerance for most of our tests. What atol and rtol error values do you typically get from these tests, and is that error difference expected between the JAX and Triton backends?


@pytest.mark.xfail(
    reason=(
        "TE grouped_dense FFI asserts sum(group_sizes) == M at "
This assertion is only for the V1 grouped GEMM. For the V2 grouped GEMM, sum(group_sizes) < M is supported. Can you try the following? (A sketch of the pattern follows.)

  1. Enforce the V2 grouped GEMM with NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1, similar to what we do in use_jax_gemm(enabled=False).
  2. Remove the functools cache on _should_enforce_v2_grouped_gemm() so it can change.
  3. Run this test within a try/except. If it runs, great. If the except catches a runtime error that contains the string here, then pytest.skip("V2 grouped gemm is not supported").
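A minimal sketch of steps 1 and 3, assuming a hypothetical run_moe_block_with_alignment() test body and matching on the "sum(group_sizes)" substring from the V1 assertion message quoted above:

    import pytest

    def test_align_size(monkeypatch):
        # Step 1: enforce the V2 grouped GEMM for this test only.
        monkeypatch.setenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "1")
        # Step 3: skip gracefully on builds where only the V1 grouped GEMM
        # (which asserts sum(group_sizes) == M) is available.
        try:
            run_moe_block_with_alignment()  # hypothetical test body
        except RuntimeError as err:
            if "sum(group_sizes)" in str(err):
                pytest.skip("V2 grouped gemm is not supported")
            raise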


* Fused, Triton-backed ``token_dispatch`` / ``token_combine`` - uses the
  Triton kernels in ``transformer_engine.jax.triton_extensions.permutation``.
* Unfused, pure-JAX ``unfused_token_dispatch`` / ``unfused_token_combine`` -
I feel like "unfused" implies slower, when in practice this approach is faster, at least in MaxText. Would "triton" or "pure_jax", like we have in the tests, fit better? What do you think?

inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes)

_, _, hidden_size = inputs.shape
params = self._make_params(hidden_size)
Do we need this "params" dictionary? With nn.compact we can define each param inline right before its usage, and nn.compact will handle collecting all params for us. Most of our other modules in module.py define params inline instead of in a dictionary upfront (see the sketch below).
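A sketch of the inline style this comment describes (illustrative module, not the PR's code):

    import flax.linen as nn
    import jax.numpy as jnp

    class GateOnly(nn.Module):
        num_experts: int

        @nn.compact
        def __call__(self, x):
            # Declared right where it is used; nn.compact registers it under "params".
            gate_kernel = self.param(
                "gate_kernel",
                nn.initializers.lecun_normal(),
                (x.shape[-1], self.num_experts),
                jnp.float32,
            )
            return jnp.einsum("...h,he->...e", x, gate_kernel)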

# Gate
# ------------------------------------------------------------------

def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray:
Is this gating proj GEMM something we should quantize?

    inputs_2d: jnp.ndarray,
    sparse_probs: jnp.ndarray,
    routing_map: jnp.ndarray,
) -> dict:
Can you replace the dict with a dataclass with the same fields? (A sketch follows.)
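One way to do this while keeping the result a pytree (so it still flows through jit/shard_map) is flax.struct.dataclass; the field names below are invented for illustration:

    from flax import struct
    import jax.numpy as jnp

    @struct.dataclass
    class GlobalPermuteResult:
        sorted_inputs: jnp.ndarray  # [T*k, H] tokens sorted by destination expert
        sorted_probs: jnp.ndarray   # [T*k] matching routing probabilities
        group_sizes: jnp.ndarray    # [E] token counts per expert
        row_ids: jnp.ndarray        # [T*k] inverse permutation used by the combine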

sparse_probs, routing_map = self._route_topk(logits_2d, params.get("expert_bias"))
aux_loss = self._compute_aux_loss(logits_2d)
perm = self._global_permute(inputs_2d, sparse_probs, routing_map)
expert_outputs = self._expert_ffn(
Since the grouped quant and grouped GEMM do not have custom partitioning rules, I think this being outside of shard_map will either raise an error about missing partitioning rules or silently replicate.

captured[name] = params[name]
in_specs[name] = P(ep_axis, None)

def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]:
This is a fairly large function inside another big function. Can we move this to an outer scope or is this required to capture something?


Development

Successfully merging this pull request may close these issues.

[JAX] Create initial MoE Block
