Add examples for MoE models - Mixtral in TE #2642
faradawn wants to merge 45 commits into NVIDIA:main from faradawn:add-moe-example
Conversation
Greptile Summary

This PR adds a complete TE-accelerated Mixtral (MoE) tutorial with
Confidence Score: 3/5 — Not safe to merge. Multiple crash paths in the model forward pass remain open from prior review rounds, and the new AllToAllTokenDispatcher uses a non-standard sentinel for num_out_tokens that works by accident today but could silently produce empty expert inputs in a future TE version. Several bugs flagged in earlier rounds are still unaddressed: the decode-step assertion fires whenever attention_mask is absent, the inputs_embeds forward path crashes on input_ids.shape, and _sync_expert_views never unwraps DTensor parameters under EP, causing GroupedLinear to receive unacceptable tensor types. docs/examples/te_mixtral/te_mixtral.py needs the most attention: the AllToAllTokenDispatcher sentinel fix, the decode-step and inputs_embeds crashes, and the _sync_expert_views DTensor check. docs/examples/te_mixtral/requirements.txt is missing flash-attn.
Reviews (25): Last reviewed commit: "Fix padding for normal permute path"
Hi @sudhakarsingh27, can you check if this addresses your comments? Tested on 2x H100.
Agreed! I think any example should show some perf gain and include the whole weight mapping so the user can run the example.
The documentation build is not working; if you fix it, please ping me and I'll review.
Force-pushed d7301a6 to fcc7e35
Force-pushed 72e0e45 to 226173a
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
- Fix m_splits bug: use (selected_experts == i).sum() instead of .any(dim=-1).sum() * top_k, which caused dimension mismatches in GroupedLinear
- Add te_mixtral.py: TEMixtralSparseMoeBlock (GroupedLinear + moe_permute/unpermute), replace_moe_block context manager, TEMixtralForCausalLM with HF weight loading, and replace_params for expert weight packing
- Add utils.py: HyperParameters, data loading, BF16/FP8 model init, Accelerate wrapping, fine-tuning loop — mirrors te_llama/utils.py style
- Add requirements.txt matching te_llama versions
- Expand notebook from bare code snippet to full tutorial covering: architecture overview, HF vs TE comparison table, unit-test cell, baseline BF16 run, BF16 TE improvement, FP8 TE improvement, expert routing/scaling discussion, generalisation guide for other MoE models

Addresses reviewer feedback: fixes the critical runtime bug (greptile) and expands to tutorial quality comparable to the Llama/Gemma examples (sudhakarsingh27), covering the scope requested in issue NVIDIA#2573.

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
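For readers skimming the conversation, here is a minimal sketch of the monkey-patch pattern a replace_moe_block context manager can use (modeled on the te_llama tutorial's replace_decoder; this is an illustration, not the PR's exact implementation):

```python
from contextlib import contextmanager

import transformers.models.mixtral.modeling_mixtral as hf_mixtral


@contextmanager
def replace_moe_block(te_block_cls):
    """Temporarily swap HF's MixtralSparseMoeBlock for a TE-backed block so
    that from_pretrained() builds TE layers. Sketch only; the PR's version
    may patch a different symbol or take no arguments."""
    original = hf_mixtral.MixtralSparseMoeBlock
    hf_mixtral.MixtralSparseMoeBlock = te_block_cls
    try:
        yield
    finally:
        # Always restore the original class, even if model construction fails.
        hf_mixtral.MixtralSparseMoeBlock = original
```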
for more information, see https://pre-commit.ci

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
A few bugs found during code review:
- moe_permute and moe_unpermute were missing map_type='index'. The
selected_experts tensor is [num_tokens, top_k] indices, not a mask,
so without this the routing is completely wrong at runtime.
- num_out_tokens=None should be -1 (the API expects an int).
- moe_unpermute: replaced deprecated probs= with merging_probs= and
kept routing_weights in float32 as TE recommends (see the sketch after this list).
- utils.py: num_warmup_steps was hardcoded to 100 instead of using
hyperparams.num_warmup_steps, which made the benchmark meaningless.
- requirements.txt: transformers==4.57.0 doesn't exist, fixed to 4.47.1.
- Notebook generalisation guide: updated code template with the same fixes.
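A hedged sketch of the corrected permute/unpermute round trip. Argument names follow the commit message above; the exact signatures depend on the installed TE version:

```python
import torch
from transformer_engine.pytorch import moe_permute, moe_unpermute

num_tokens, hidden, num_experts, top_k = 16, 256, 4, 2
x = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)
# [num_tokens, top_k] expert indices, not a mask => map_type="index"
selected_experts = torch.randint(0, num_experts, (num_tokens, top_k), device="cuda")
# Routing weights kept in float32, as TE recommends.
routing_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=torch.float32)

# num_out_tokens=-1 is the documented "no token dropping" sentinel.
permuted, row_id_map = moe_permute(
    x, selected_experts, num_out_tokens=-1, map_type="index"
)
# ... the expert FFN (e.g. GroupedLinear) runs on `permuted` here ...
out = moe_unpermute(
    permuted, row_id_map, merging_probs=routing_weights, map_type="index"
)
```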
Tested on 2xH100 in nvcr.io/nvidia/pytorch:26.01-py3 (PyTorch 2.10.0, CUDA 13.1):
$ python3 -c "
from te_mixtral import TEMixtralSparseMoeBlock
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import torch
cfg = MixtralConfig(hidden_size=256, intermediate_size=512,
num_local_experts=4, num_experts_per_tok=2)
x = torch.randn(2, 8, cfg.hidden_size, device='cuda', dtype=torch.bfloat16)
hf_out, hf_logits = MixtralSparseMoeBlock(cfg).cuda().bfloat16()(x)
te_out, te_logits = TEMixtralSparseMoeBlock(cfg).cuda().bfloat16()(x)
assert hf_out.shape == te_out.shape
assert hf_logits.shape == te_logits.shape
print('PASS')"
Input shape : torch.Size([2, 8, 256])
Output shape : torch.Size([2, 8, 256]) (matches HF: True)
Logits shape : torch.Size([16, 4]) (matches HF: True)
Output dtype : torch.bfloat16
PASS
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
The docs CI was failing because nbsphinx auto mode tries to execute notebooks with no stored outputs. Add explicit execute:never metadata so the docs builder renders the notebook as-is without running cells. Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
…otebook

te_mixtral.py:
- Replace per-expert .item() loop with torch.bincount + .tolist() (8 GPU syncs → 1), eliminating the main cause of GroupedLinear speedup regression (sketched below)
- Remove unnecessary num_out_tokens/max_token_num args from moe_permute
- Fix replace_params to use .copy_() instead of fragile .data[] assignment, and load_state_dict from a fully-populated te_state in one shot
- Add device_map="auto" support in from_hf_model via accelerate.dispatch_model
- Rename replace_params -> _pack_expert_weights for clarity

tutorial_accelerate_hf_mixtral_with_te.ipynb:
- Remove shape-check section (redundant)
- Add FP8 prefill and decode-regime benchmark cells with summary table
- Restructure training sections to match te_llama pattern: restart at top of each cell, combined hyperparams + init + wrap + finetune, concise markdown
- Bump benchmark SEQ 512 -> 2048 for realistic H100 workload

utils.py:
- Apply user batch size and sequence length adjustments

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
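The bincount change from the first bullet, sketched under assumed names (selected_experts as [num_tokens, top_k] indices):

```python
import torch

num_experts = 8
selected_experts = torch.randint(0, num_experts, (1024, 2), device="cuda")

# Before: one .item() per expert => one GPU sync each (8 in total).
m_splits_slow = [(selected_experts == i).sum().item() for i in range(num_experts)]

# After: a single bincount plus one .tolist() => one sync in total.
m_splits = torch.bincount(selected_experts.flatten(), minlength=num_experts).tolist()

assert m_splits == m_splits_slow
```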
for more information, see https://pre-commit.ci

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
I was not reviewing it in detail; I would like to discuss the high-level flow of the tutorial first.

Apart from that, it would be nice to prepare documentation (not a tutorial) explaining why grouped linear is needed and which permute kernels we provide. I will work on that in a different PR and it will land in the Features section of our docs.
Pin transformers, accelerate, datasets, safetensors, huggingface_hub, and tokenizers in requirements.txt to the exact versions validated on 8x B300 inside NGC pytorch-25.12-py3. torch, torchao, and transformer_engine are left unpinned because they come from the NGC container and pinning them would tie the tutorial to a specific container build.

Add a setup-cell note in the tutorial notebook calling out the tested GPU (8x Blackwell B300) and container (pytorch-25.12-py3).

Addresses pggPL review comments on requirements.txt:
- "pin the versions of some of the frameworks"
- "specify GPU and maybe container to make sure the users can rerun it"

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
    if isinstance(past_key_values, InferenceParams):
        # input_ids is None when the caller supplies inputs_embeds directly.
        _ref = input_ids if input_ids is not None else inputs_embeds
        lengths = (
            attention_mask.sum(dim=1).tolist()
            if attention_mask.shape[:2] == _ref.shape[:2]
            else [1] * _ref.shape[0]
        )
        past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))
attention_mask is None crash inside InferenceParams branch
attention_mask is an optional parameter with no None-guard before line 852. In the decode step of model.generate() the mask is frequently omitted, and attention_mask.shape[:2] will raise AttributeError: 'NoneType' object has no attribute 'shape', crashing every auto-regressive generation call.
    if isinstance(past_key_values, InferenceParams):
        _ref = input_ids if input_ids is not None else inputs_embeds
        lengths = (
            attention_mask.sum(dim=1).tolist()
            if attention_mask is not None and attention_mask.shape[:2] == _ref.shape[:2]
            else [1] * _ref.shape[0]
        )
        past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))

Adds a configurable expert FFN execution mode to NVMixtralSparseMoeBlock:
- "grouped" (default): existing TE GroupedLinear path
- "loop": Python loop over experts with one F.linear per expert
(HF-style), running against the same stacked weight tensor
so swapping between modes is a runtime config flag, no
checkpoint rework
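A sketch of what the "loop" mode can look like against stacked expert weights. Tensor names and the gate/up packing are assumptions drawn from the surrounding discussion, not the PR's exact code:

```python
import torch
import torch.nn.functional as F


def expert_ffn_loop(expert_input, tokens_per_expert, gate_up_w, down_w):
    """expert_input: [sum(tokens_per_expert), hidden], already grouped by expert.
    gate_up_w: [num_local_experts, 2*ffn, hidden]; down_w: [num_local_experts, hidden, ffn]."""
    outs = []
    # One F.linear per local expert, over that expert's contiguous token slice.
    for i, chunk in enumerate(expert_input.split(tokens_per_expert, dim=0)):
        gate, up = F.linear(chunk, gate_up_w[i]).chunk(2, dim=-1)
        outs.append(F.linear(F.silu(gate) * up, down_w[i]))  # Mixtral's SwiGLU
    return torch.cat(outs, dim=0)
```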
Adds a new "Tier 2" to the tutorial's progressive optimization chain that
isolates EP+TE primitives without GroupedLinear, so the GroupedLinear
contribution can be measured independently. Renumbers run_finetune_ep.py:
1 = HF baseline BF16 (device_map="auto")
2 = TE EP BF16 -- Python loop, F.linear per expert [new]
3 = TE EP BF16 -- GroupedLinear (was Tier 2)
4 = TE EP BF16 -- GroupedLinear + Fused DeepEP (was Tier 3)
5 = TE EP FP8 -- Float8CurrentScaling + grouped + DeepEP
Tier 5 switches the FP8 recipe from MXFP8BlockScaling to
Float8CurrentScaling -- the per-tensor scaling avoids the 32-aligned
tile constraint that MXFP8 imposes on grouped-matmul, so Tier 5 runs
cleanly at EP < num_experts where MXFP8 trips the per-expert padding
assertion.
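For illustration, the recipe switch amounts to something like the following, assuming a TE version that exposes Float8CurrentScaling (model and inp are placeholders):

```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

# Per-tensor current scaling; avoids MXFP8's 32-aligned tile constraint
# on grouped GEMMs, so it runs at EP < num_experts.
fp8_recipe = Float8CurrentScaling()

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)
```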
Also threads the configuration through the harness:
- HyperParameters.expert_ffn_mode
- --ep-size default changed from 8 to 2 (4 experts/rank gives a
non-degenerate loop-vs-grouped contrast; pass --ep-size 8 for the
original 1-expert/rank config)
- Allow WORLD_SIZE > EP via a 2D (DP, EP) DeviceMesh so EP=2 with 8
GPUs (DP=4) works end-to-end
- replace_params: use EP-rank-within-group (rank % ep_size) instead of
global rank for the expert slicing -- otherwise ranks beyond ep_size
slice past the end of the HF stacked weight tensor
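A minimal sketch of the 2D (DP, EP) mesh and the rank-within-group computation described above; the mesh dim names and ep_size value are assumptions, not the PR's exact code:

```python
import os

from torch.distributed.device_mesh import init_device_mesh

world_size = int(os.environ["WORLD_SIZE"])
ep_size = 2  # new default; pass --ep-size 8 for the 1-expert/rank config

# 2D mesh: WORLD_SIZE = DP * EP, e.g. 8 GPUs with EP=2 gives DP=4.
mesh = init_device_mesh(
    "cuda", (world_size // ep_size, ep_size), mesh_dim_names=("dp", "ep")
)
ep_group = mesh.get_group("ep")

# Expert slicing must use the rank within the EP group, not the global rank,
# or ranks >= ep_size index past the end of the HF stacked weight tensor.
ep_rank = int(os.environ["RANK"]) % ep_size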
Misc: cycle the dataloader (small benchmark datasets like
openassistant-guanaco exhaust at large batch * num_steps), and trim the
fine-tune summary output to median + last-step (was mean + median +
min/max/p99).
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
for more information, see https://pre-commit.ci
| """ | ||
| gate_up_w = self.experts_gate_up_weight | ||
| if isinstance(gate_up_w, DTensor): | ||
| gate_up_w = gate_up_w.to_local() | ||
| for i in range(self.num_local_experts): | ||
| object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i]) | ||
|
|
||
| down_w = self.experts_down_weight | ||
| if isinstance(down_w, DTensor): | ||
| down_w = down_w.to_local() |
Wrong DTensor check — GroupedLinear receives DTensor views under EP
self.experts_gate_up_weight is an nn.Parameter. After set_ep_group() wraps it with nn.Parameter(DTensor.from_local(...)), the parameter's type is still nn.Parameter (not DTensor), so isinstance(gate_up_w, DTensor) is always False. The guard never fires, to_local() is never called, and every subsequent gate_up_w[i] slice handed to object.__setattr__ is a DTensor. TE's GroupedLinear CUDA kernels cannot accept DTensor views and will either raise a type error or silently use incorrect data.
The _expert_ffn loop path already applies the correct pattern with .data:

    if isinstance(gate_up_w.data, DTensor):
        gate_up_w = gate_up_w.data.to_local()

Suggested change:

    gate_up_w = self.experts_gate_up_weight
    if isinstance(gate_up_w.data, DTensor):
        gate_up_w = gate_up_w.data.to_local()
    for i in range(self.num_local_experts):
        object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i])
    down_w = self.experts_down_weight
    if isinstance(down_w.data, DTensor):
        down_w = down_w.data.to_local()
    for i in range(self.num_local_experts):
        object.__setattr__(self.experts_down, f"weight{i}", down_w[i])
Adds EXPERIMENT_LOG.md with the GroupedLinear vs naive-loop investigation
(EP=2 batch sweep, batch=8 nsys + NVTX per-phase breakdown) and the
my_expert_ffn_{loop|grouped} / my_grouped_linear_{gate_up,down} NVTX
ranges that produced the per-phase numbers.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
for more information, see https://pre-commit.ci
    remainder = orig_num_tokens % 32
    if remainder != 0:
        pad_count = 32 - remainder
        expert_input = torch.nn.functional.pad(expert_input, (0, 0, 0, pad_count))
        tokens_per_expert = list(tokens_per_expert)
        tokens_per_expert[-1] += pad_count
Padding appended to the last expert's slot regardless of which expert has remaining tokens
When orig_num_tokens % 32 != 0, pad_count tokens are appended to expert_input and tokens_per_expert[-1] is incremented by pad_count. This is only correct if the padded tokens happen to logically belong to the last expert. If the last expert already has 0 tokens on this rank (a routine EP scenario), GroupedLinear will schedule a GEMM for pad_count zero-padded tokens against that expert's weights, inflating compute and potentially triggering a misaligned-size error rather than a clean no-op. A safer approach is to distribute the padding tokens as a trailing dummy entry (tokens_per_expert.append(pad_count)) or to pad at batch-preparation time before dispatch so padding is always aligned with a real expert slot.
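As an illustration, the trailing-dummy-entry variant could look like the sketch below. A dummy split would still need matching handling on the weight side, so treat this as a sketch of the suggestion, not a drop-in fix:

```python
import torch

expert_input = torch.randn(50, 256, device="cuda")  # 50 % 32 != 0
tokens_per_expert = [30, 20]

remainder = expert_input.shape[0] % 32
if remainder != 0:
    pad_count = 32 - remainder
    expert_input = torch.nn.functional.pad(expert_input, (0, 0, 0, pad_count))
    # Padding gets its own trailing slot instead of inflating the last
    # expert's real token count.
    tokens_per_expert = list(tokens_per_expert) + [pad_count]
```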
Will wait on this PR: #2923
…p_proj/down_proj naming

Adds three drawio + SVG diagram pairs under docs/examples/te_mixtral/media/:
- dense_to_sparse: dense Transformer block vs sparse MoE block (router + 8 experts).
- mixtral_decoder_swap: Llama-tutorial-style API key->key map from HF MixtralDecoderLayer to TE NVMixtralDecoderLayer (fused QKV, packed expert tensors).
- moe_loop_vs_grouped: zoom into the MoE block; HF per-expert loop vs TE GroupedLinear with variable per-expert token counts (m_splits).

Adds a new "Architecture Overview" cell at the top of the tutorial notebook with the two main figures, and renames the MoE weight references throughout the notebook prose and tables from the older w1/w2/w3 form to the newer packed mlp.experts.gate_up_proj / mlp.experts.down_proj keys (the form used by BioNemo's bionemo-recipes/models/mixtral/convert.py and accepted by replace_params() in te_mixtral.py).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`/lustre/share/coreai_prod_infbench/exp3_nvtx_per_phase/`.

---

All experiments run on a **SLURM cluster**. The login node has no CUDA, so any experiment must execute inside an allocation on a compute node with 8 GPUs.

### Reserving a fresh node (8x GPU, 4h, exclusive)

```bash
srun -A coreai_prod_infbench \
  -p batch \
  -N 1 \
  -J myjob \
  -t 04:00:00 \
  --exclusive \
  --mpi=pmix \
  --container-image=/lustre/fsw/coreai_prod_infbench/faradawny/docker/pytorch-25.12-py3.sqsh \
  --container-save=/lustre/fsw/coreai_prod_infbench/faradawny/docker/pytorch-25.12-py3.sqsh \
  --container-name=mycontainer \
  --container-mounts=/lustre:/lustre \
  --pty bash
```

This drops you into an interactive shell inside the NGC `pytorch-25.12-py3` container with `/lustre` mounted and the container saved as `mycontainer`.
Internal cluster details in a public repository
EXPERIMENT_LOG.md is a personal working-notes file that contains NVIDIA-internal infrastructure details that should not be committed to a public GitHub repo:
- A developer username embedded in paths (`faradawny`): `/lustre/fsw/coreai_prod_infbench/faradawny/docker/pytorch-25.12-py3.sqsh`
- An internal SLURM account: `coreai_prod_infbench`
- Internal shared storage paths: `/lustre/share/coreai_prod_infbench/exp3_nvtx_per_phase/`
This file looks like a benchmark log intended for internal coordination, not documentation for external tutorial users. Consider removing it entirely or replacing it with a trimmed BENCHMARKS.md that shows only public results (median step times, speedup numbers) without cluster-specific details.
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Add reproducible Mixtral EP tiers for TE Sequential Ops GroupedLinear, including FP8/MXFP8 variants and the dependency lockfile needed to rerun the experiments. Co-authored-by: Cursor <cursoragent@cursor.com>
for more information, see https://pre-commit.ci
Notebook Tier bash commands misaligned with run_finetune_ep.py improvement indices

The bash commands shown for Tiers 2–4 reference --improvement values that map to different configurations in run_finetune_ep.py than the notebook describes. Concretely:

| Notebook Tier | Described as | Bash uses | run_finetune_ep.py actually does |
|---|---|---|---|
| 2 | TE GroupedLinear EP | `--improvement 2` | `expert_ffn_mode="loop"` (Python F.linear, not GroupedLinear) |
| 3 | Fused DeepEP dispatcher (BF16) | `--improvement 3` | `expert_ffn_mode="grouped_op"`, no DeepEP |
| 4 | MXFP8 precision | `--improvement 4` | DeepEP BF16 GroupedLinear, not MXFP8 |
Additionally, every multi-GPU bash command omits --ep-size 8, so the default --ep-size 2 (4 experts/rank) is used, whereas the Python cells in each tier all set hp.expert_parallel_size = 8 (1 expert/rank). The benchmark step-time numbers in the results tables were measured at EP=8, so users following the bash commands will reproduce different performance characteristics.
The correct mapping from notebook tier to --improvement flag appears to be: Tier 2 → --improvement 6, Tier 3 → --improvement 4, Tier 4 → --improvement 8 (all with --ep-size 8 added).
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
…ne into add-moe-example
for more information, see https://pre-commit.ci
    torch
    torchao!=0.14.0
    transformer_engine[pytorch]

    transformers
    accelerate
    datasets
    safetensors
    huggingface_hub
    tokenizers
Both init_baseline_model and init_te_mixtral_model in utils.py explicitly set config._attn_implementation = "flash_attention_2". HuggingFace transformers validates this at from_pretrained time and raises ImportError if the flash-attn package is absent. Every tutorial user following the baseline or TE fine-tuning paths hits this error before the first training step. Add flash-attn to requirements.txt (the flash-attn PyPI package).
…to add-moe-example
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
    model.load_state_dict(te_state_dict, strict=False)
    del hf_model
Silent weight-mapping failures in init_te_mixtral_model
load_state_dict(strict=False) swallows all missing and unexpected keys, so if replace_params fails to map any parameter (e.g., due to a checkpoint key-name mismatch), the model silently trains with random initialization. test_accuracy.py shows the correct pattern: capture the return value and assert that the only missing keys are TE-internal _extra_state entries.
Suggested change:

    missing, unexpected = model.load_state_dict(te_state_dict, strict=False)
    if unexpected:
        raise RuntimeError(f"Unexpected keys when loading TE state dict: {unexpected}")
    non_extra_missing = [k for k in missing if not k.endswith("_extra_state")]
    if non_extra_missing:
        raise RuntimeError(f"Missing non-extra-state keys in TE model: {non_extra_missing}")
    del hf_model
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
    should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd"

    if should_pack_inputs:
        assert (
            attention_mask is not None
        ), "Attention mask is required when packing BSHD inputs."
Decode-step crash with default attn_input_format="thd" when attention_mask is omitted
should_pack_inputs is True whenever attn_input_format == "thd" and no explicit THD kwargs are supplied, which is the normal case for every decode step during generation. In that path the assertion fires immediately because HF generate() typically does not propagate attention_mask on decode steps (sequence length 1). Adding a guard like past_key_values is None (or hidden_states.size(1) > 1) to should_pack_inputs would skip packing for single-token decode steps where the sequence is already effectively packed.
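Sketched against the quoted snippet (should_pack_inputs, has_thd_input, and past_key_values are the PR's names; the added condition is the reviewer's suggestion, not the PR's current code):

```python
# Skip packing for single-token decode steps, where no BSHD->THD conversion
# is needed; alternatively gate on hidden_states.size(1) > 1.
should_pack_inputs = (
    not any(has_thd_input)
    and self.config.attn_input_format == "thd"
    and past_key_values is None
)
```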
    routing_weights_for_unpermute = routing_weights
    map_type = "index"

    if self._ep_group is not None:
Non-standard sentinel for no-token-dropping in moe_permute
The TE API documents -1 as the sentinel meaning "no token dropping"; the value used here (0) only happens to work because the fake/shape-inference path guards with num_out_tokens > 0, and 0 fails that check. If a future TE CUDA kernel instead treats 0 as a literal output-token count, permuted_hidden becomes empty and every all-to-all dispatch silently sends no tokens to any expert. Use -1 to match the documented contract.
Summary
This PR adds a complete tutorial for integrating HuggingFace Mixtral (MoE) with Transformer Engine, addressing the gap identified in #2573.
What's included
- te_mixtral.py — Drop-in TEMixtralSparseMoeBlock that replaces HF's loop-over-experts with TE's GroupedLinear (batched GEMM) + moe_permute/moe_unpermute. Includes the replace_moe_block context manager, TEMixtralForCausalLM with HF weight loading, and replace_params for expert weight packing.
- utils.py — Data loading, BF16/FP8 model init, Accelerate wrapping, fine-tuning loop — mirrors the te_llama/utils.py style.
- requirements.txt — Pinned dependencies matching the Llama/Gemma tutorials.
- A full tutorial notebook following the structure of te_llama and te_gemma.

Bug fix
Corrected the m_splits calculation flagged by the automated review: use (selected_experts == i).sum() instead of .any(dim=-1).sum() * top_k.

Scope
Covers all topics requested in #2573: