
Add examples for MoE models - Mixtral in TE #2642

Open

faradawn wants to merge 45 commits into NVIDIA:main from faradawn:add-moe-example

Conversation

@faradawn
Contributor

@faradawn faradawn commented Feb 2, 2026

Summary

This PR adds a complete tutorial for integrating HuggingFace Mixtral (MoE) with Transformer Engine, addressing the gap identified in #2573.

What's included

  • te_mixtral.py — Drop-in TEMixtralSparseMoeBlock that replaces HF's loop-over-experts with TE's GroupedLinear (batched GEMM) + moe_permute/moe_unpermute. Includes replace_moe_block context manager, TEMixtralForCausalLM with HF weight loading, and replace_params for expert weight packing. (A minimal sketch of this data flow follows the list below.)
  • utils.py — Data loading, BF16/FP8 model init, Accelerate wrapping, fine-tuning loop — mirrors te_llama/utils.py style.
  • requirements.txt — Pinned dependencies matching the Llama/Gemma tutorials.
  • Tutorial notebook — Full tutorial matching the quality bar of te_llama and te_gemma, covering:
    1. Architecture overview: Transformer → Mixtral MoE, HF bottleneck, TE approach
    2. Unit-test cell verifying output shape/dtype against the HF block
    3. [Baseline] HF Mixtral in BF16
    4. [Improvement 1] TE GroupedLinear MoE in BF16
    5. [Improvement 2] TE GroupedLinear MoE in FP8
    6. Expert routing considerations with mixed precision (m_splits, per-expert FP8 scaling, aux loss passthrough)
    7. Generalisation guide for other MoE architectures (DeepSeek, Grok-1, etc.)
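
To make the GroupedLinear approach concrete, below is a minimal sketch of the data flow the te_mixtral.py item above describes. It is not the PR's code: names such as gate_up_proj, down_proj, and moe_forward are illustrative, the tutorial itself uses TE's fused moe_permute/moe_unpermute kernels rather than the plain PyTorch sorting shown here, and the GroupedLinear signatures are assumptions to verify against the installed Transformer Engine version (TE modules also require a CUDA device).

```python
# Minimal sketch, not the PR's te_mixtral.py code.
import torch
import torch.nn.functional as F
import transformer_engine.pytorch as te

num_experts, top_k, hidden_size, ffn_size = 8, 2, 4096, 14336  # Mixtral-8x7B-like dims

# One GroupedLinear per projection; num_gemms equals the number of (local) experts.
gate_up_proj = te.GroupedLinear(num_experts, hidden_size, 2 * ffn_size, bias=False)
down_proj = te.GroupedLinear(num_experts, ffn_size, hidden_size, bias=False)

def moe_forward(hidden_states, router_logits):
    # hidden_states: [num_tokens, hidden_size]; router_logits: [num_tokens, num_experts]
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(probs, top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    # Group (token, slot) pairs by expert; m_splits is the per-expert row count.
    flat_experts = selected_experts.reshape(-1)                      # [num_tokens * top_k]
    order = torch.argsort(flat_experts, stable=True)
    m_splits = torch.bincount(flat_experts, minlength=num_experts).tolist()
    permuted = hidden_states.repeat_interleave(top_k, dim=0)[order]  # rows sorted by expert

    # Batched expert FFN: one grouped GEMM per projection instead of a loop over experts.
    gate, up = gate_up_proj(permuted, m_splits).chunk(2, dim=-1)
    expert_out = down_proj(F.silu(gate) * up, m_splits)

    # Un-permute and combine expert outputs with the renormalized routing weights.
    token_idx = order // top_k
    weights = routing_weights.reshape(-1)[order].to(expert_out.dtype).unsqueeze(-1)
    out = torch.zeros_like(hidden_states)
    out.index_add_(0, token_idx, expert_out * weights)
    return out
```

The key point is that the per-expert token counts (m_splits) drive a single batched GEMM group per projection, which is what replaces HF's Python loop over experts.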

Bug fix

Corrected the m_splits calculation flagged by the automated review:

# Before (wrong): double-counts tokens by reducing with .any() then multiplying by top_k
expert_mask = (selected_experts == expert_idx).any(dim=-1)
m_splits.append(expert_mask.sum().item() * self.top_k)

# After (correct): count the actual number of (token, top_k_slot) pairs per expert
m_splits = [(selected_experts == i).sum().item() for i in range(self.num_experts)]
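
Later commits in this PR go further and compute m_splits with a single torch.bincount rather than one .sum().item() call per expert, cutting the number of GPU synchronizations from num_experts to one; a one-line sketch of that variant, reusing the variable names from the snippet above:

```python
# Count (token, top_k_slot) pairs per expert with a single device-to-host sync.
m_splits = torch.bincount(selected_experts.view(-1), minlength=self.num_experts).tolist()
```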

Scope

Covers all topics requested in #2573:

  • How to wrap MoE layers with TE modules ✓
  • FP8 training configuration for MoE ✓
  • Expert routing considerations with mixed precision ✓
  • Generalisation to arbitrary MoE architectures ✓

@greptile-apps
Contributor

greptile-apps Bot commented Feb 2, 2026

Greptile Summary

This PR adds a complete TE-accelerated Mixtral (MoE) tutorial with NVMixtralForCausalLM, an AllToAllTokenDispatcher for expert parallelism, an optional DeepEP-backed FusedTokenRouter, and an eight-tier benchmarking notebook and script.

  • te_mixtral.py: The new AllToAllTokenDispatcher.dispatch uses a non-standard sentinel value for num_out_tokens where the documented no-drop value should be used; several bugs from prior rounds (decode-step assertion crash, null-deref on the inputs-embeds path, DTensor unwrapping in _sync_expert_views) remain open.
  • utils.py / collator.py: Previously flagged class-name and third-party-import issues are fixed; flash-attn, required by the flash-attention-2 attention implementation, is still absent from requirements.txt.
  • fused_token_router.py / fused_a2a.py / fused_indices_converter.py: New DeepEP and Triton dispatch path uses the correct map type and the renamed merging_probs kwarg; no new issues found in these files.

Confidence Score: 3/5

Not safe to merge — multiple crash paths in the model forward pass remain open from prior review rounds, and the new AllToAllTokenDispatcher introduces a non-standard sentinel that could silently zero all expert outputs.

The new AllToAllTokenDispatcher uses a non-standard sentinel for num_out_tokens that works by accident today but could silently produce empty expert inputs in a future TE version. Several bugs flagged in earlier rounds are still unaddressed: the decode-step assertion fires whenever attention_mask is absent, the inputs_embeds forward path crashes on input_ids.shape, and _sync_expert_views never unwraps DTensor parameters under EP, causing GroupedLinear to receive unacceptable tensor types.

docs/examples/te_mixtral/te_mixtral.py needs the most attention: the AllToAllTokenDispatcher sentinel fix, the decode-step and inputs_embeds crashes, and the _sync_expert_views DTensor check. docs/examples/te_mixtral/requirements.txt is missing flash-attn.

Important Files Changed

  • docs/examples/te_mixtral/te_mixtral.py: Core NVMixtralForCausalLM implementation. Several unresolved bugs from prior reviews remain, plus a new issue in AllToAllTokenDispatcher.dispatch: it uses a non-standard sentinel value for num_out_tokens instead of the documented -1 flag, which could silently produce empty expert inputs in future TE versions.
  • docs/examples/te_mixtral/utils.py: Helper utilities for loading/wrapping models and running the fine-tuning loop. Previously flagged issues (hardcoded wandb, broken class import, hardcoded warmup steps) appear fixed. load_state_dict(strict=False) still silently swallows weight-mapping failures, and flash-attn is still absent from requirements.txt despite being required by the flash-attention-2 code path.
  • docs/examples/te_mixtral/collator.py: Local DataCollatorWithFlattening that replaces the removed bionemo_mixtral dependency, producing THD-format batches with cu_seq_lens_q/k for TE fused attention. Logic appears correct: BSHD masking and packed sequence lengths stay consistent, and separator boundaries are applied before padding.
  • docs/examples/te_mixtral/fused_token_router.py: DeepEP-backed TokenDispatcher using fused all-to-all and Triton index conversion. Uses map_type="mask" and the renamed merging_probs= kwarg for both moe_permute and moe_unpermute, correctly addressing the prior-round findings for the FusedTokenRouter path.
  • docs/examples/te_mixtral/requirements.txt: wandb (previously flagged as missing) is no longer needed by any code path, but flash-attn is still absent despite the flash-attention-2 implementation being hardcoded in both init_baseline_model and init_te_mixtral_model; any user running those paths hits an ImportError before the first training step.
  • docs/examples/te_mixtral/test_accuracy.py: Parity test comparing HF and TE model logits on a small synthetic config. Correctly checks for unexpected and non-_extra_state missing keys after load_state_dict, providing the defensive pattern that utils.py should follow.
  • docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb: Tutorial notebook. Prior review rounds flagged class-name mismatches, a missing from_hf_model classmethod, and tier/improvement-index misalignment with run_finetune_ep.py; whether those are resolved in the latest notebook cells was not re-verified here.
  • docs/examples/te_mixtral/run_finetune_ep.py: CLI launcher for all 8 improvement tiers. Internal tier labels and expert_ffn_mode/dispatcher_type assignments now match the script docstring.
  • docs/examples/te_mixtral/fused_a2a.py: DeepEP-backed FusedDispatch and FusedCombine autograd wrappers with correct forward/backward split-size reversal and NVSHMEM fallback detection. No new issues found.

Reviews (25): Last reviewed commit: "Fix padding for normal permute path"

Contributor

@greptile-apps greptile-apps Bot left a comment


1 file reviewed, 2 comments


Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
@sudhakarsingh27
Collaborator

Thanks, @faradawn! Also adding @sbhavani to the discussion. Compared to the other Llama/Gemma tutorials, this one seems quite barebones and looks more like a code example than a tutorial. @sbhavani, do you think that in its current form it covers the scope you requested in #2573?

Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
Comment thread docs/examples/te_mixtral/te_mixtral.py Outdated
Comment thread docs/examples/te_mixtral/te_mixtral.py Outdated
Comment thread docs/examples/te_mixtral/utils.py
@faradawn
Contributor Author

faradawn commented Apr 2, 2026

Hi @sudhakarsingh27, can you check whether this addresses your comments? Tested on 2x H100.

@sbhavani
Collaborator

sbhavani commented Apr 6, 2026

> Thanks, @faradawn! Also adding @sbhavani to the discussion. Compared to the other Llama/Gemma tutorials, this one seems quite barebones and looks more like a code example than a tutorial. @sbhavani, do you think that in its current form it covers the scope you requested in #2573?

Agreed! I think any example should show some perf gain and include the whole weight mapping so the user can run the example.

@pggPL pggPL self-assigned this Apr 13, 2026
@pggPL
Collaborator

pggPL commented Apr 13, 2026

The documentation build is not working; if you fix it, please ping me and I'll review.

Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
Comment thread docs/examples/te_mixtral/utils.py Outdated
Comment thread docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb Outdated
Comment thread docs/examples/te_mixtral/utils.py Outdated
Comment thread docs/examples/te_mixtral/HANDOFF.md Outdated
Comment thread docs/examples/te_mixtral/te_mixtral.py Outdated
faradawn and others added 11 commits April 21, 2026 11:29
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
- Fix m_splits bug: use (selected_experts == i).sum() instead of
  .any(dim=-1).sum() * top_k, which caused dimension mismatches in
  GroupedLinear
- Add te_mixtral.py: TEMixtralSparseMoeBlock (GroupedLinear + moe_permute/
  unpermute), replace_moe_block context manager, TEMixtralForCausalLM with
  HF weight loading, and replace_params for expert weight packing
- Add utils.py: HyperParameters, data loading, BF16/FP8 model init,
  Accelerate wrapping, fine-tuning loop — mirrors te_llama/utils.py style
- Add requirements.txt matching te_llama versions
- Expand notebook from bare code snippet to full tutorial covering:
  architecture overview, HF vs TE comparison table, unit-test cell,
  baseline BF16 run, BF16 TE improvement, FP8 TE improvement, expert
  routing/scaling discussion, generalisation guide for other MoE models

Addresses reviewer feedback: fixes the critical runtime bug (greptile)
and expands to tutorial quality comparable to the Llama/Gemma examples
(sudhakarsingh27), covering the scope requested in issue NVIDIA#2573.

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
for more information, see https://pre-commit.ci

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
A few bugs found during code review:

- moe_permute and moe_unpermute were missing map_type='index'. The
  selected_experts tensor is [num_tokens, top_k] indices, not a mask,
  so without this the routing is completely wrong at runtime.
- num_out_tokens=None should be -1 (the API expects an int).
- moe_unpermute: replaced deprecated probs= with merging_probs= and
  kept routing_weights in float32 as TE recommends.
- utils.py: num_warmup_steps was hardcoded to 100 instead of using
  hyperparams.num_warmup_steps, which made the benchmark meaningless.
- requirements.txt: transformers==4.57.0 doesn't exist, fixed to 4.47.1.
- Notebook generalisation guide: updated code template with the same fixes.

Tested on 2xH100 in nvcr.io/nvidia/pytorch:26.01-py3 (PyTorch 2.10.0, CUDA 13.1):

  $ python3 -c "
  from te_mixtral import TEMixtralSparseMoeBlock
  from transformers import MixtralConfig
  from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
  import torch

  cfg = MixtralConfig(hidden_size=256, intermediate_size=512,
                      num_local_experts=4, num_experts_per_tok=2)
  x = torch.randn(2, 8, cfg.hidden_size, device='cuda', dtype=torch.bfloat16)
  hf_out, hf_logits = MixtralSparseMoeBlock(cfg).cuda().bfloat16()(x)
  te_out, te_logits = TEMixtralSparseMoeBlock(cfg).cuda().bfloat16()(x)
  assert hf_out.shape == te_out.shape
  assert hf_logits.shape == te_logits.shape
  print('PASS')"

  Input  shape : torch.Size([2, 8, 256])
  Output shape : torch.Size([2, 8, 256])  (matches HF: True)
  Logits shape : torch.Size([16, 4])  (matches HF: True)
  Output dtype : torch.bfloat16
  PASS

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
The docs CI was failing because nbsphinx auto mode tries to execute
notebooks with no stored outputs. Add explicit execute:never metadata
so the docs builder renders the notebook as-is without running cells.

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
…otebook

te_mixtral.py:
- Replace per-expert .item() loop with torch.bincount + .tolist() (8 GPU syncs
  → 1), eliminating the main cause of GroupedLinear speedup regression
- Remove unnecessary num_out_tokens/max_token_num args from moe_permute
- Fix replace_params to use .copy_() instead of fragile .data[] assignment,
  and load_state_dict from fully-populated te_state in one shot
- Add device_map="auto" support in from_hf_model via accelerate.dispatch_model
- Rename replace_params -> _pack_expert_weights for clarity

tutorial_accelerate_hf_mixtral_with_te.ipynb:
- Remove shape-check section (redundant)
- Add FP8 prefill and decode-regime benchmark cells with summary table
- Restructure training sections to match te_llama pattern: restart at top of
  each cell, combined hyperparams + init + wrap + finetune, concise markdown
- Bump benchmark SEQ 512 -> 2048 for realistic H100 workload

utils.py:
- Apply user batch size and sequence length adjustments

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
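
A rough sketch of the device_map="auto" path the commit above describes. infer_auto_device_map and dispatch_model are standard accelerate APIs; te_model stands for the TE-converted model, and the NVMixtralDecoderLayer no-split class name is taken from later commits in this PR, so treat both as assumptions rather than the PR's exact code:

```python
from accelerate import dispatch_model, infer_auto_device_map

# Spread the converted model across visible GPUs without splitting a decoder layer.
device_map = infer_auto_device_map(te_model, no_split_module_classes=["NVMixtralDecoderLayer"])
te_model = dispatch_model(te_model, device_map=device_map)
```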
for more information, see https://pre-commit.ci

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Comment thread docs/index.rst
@pggPL
Collaborator

pggPL commented Apr 22, 2026

I have not reviewed it in detail yet; I would like to discuss the high-level flow of the tutorial first.
I see some issues:

  1. We use GroupedLinear with EP=8 and 8 experts, so each GroupedLinear essentially holds a single expert and does not differ from a standard Linear. Can we do EP=2 with 4 experts per GPU, or something like that?
  2. We use different types of parallelism in different tiers, and it was not clear to me why. Maybe we can also use EP in HF if it is supported.
  3. The gains from Tier 1 -> Tier 2 are huge, but Tier 2 -> Tier 3 -> Tier 4 are very small. Do you know why?

Apart from that, it would be nice to prepare documentation (not a tutorial) explaining why GroupedLinear is needed and which permute kernels we provide. I will work on that in a different PR, and it will land in the Features section of our docs.

Pin transformers, accelerate, datasets, safetensors, huggingface_hub, and
tokenizers in requirements.txt to the exact versions validated on 8x B300
inside NGC pytorch-25.12-py3. torch, torchao, and transformer_engine are
left unpinned because they come from the NGC container and pinning them
would tie the tutorial to a specific container build.

Add a setup-cell note in the tutorial notebook calling out the tested
GPU (8x Blackwell B300) and container (pytorch-25.12-py3).

Addresses pggPL review comments on requirements.txt:
- "pin the versions of some of the frameworks"
- "specify GPU and maybe container to make sure the users can rerun it"

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Comment on lines +847 to +855
if isinstance(past_key_values, InferenceParams):
    # input_ids is None when the caller supplies inputs_embeds directly.
    _ref = input_ids if input_ids is not None else inputs_embeds
    lengths = (
        attention_mask.sum(dim=1).tolist()
        if attention_mask.shape[:2] == _ref.shape[:2]
        else [1] * _ref.shape[0]
    )
    past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))
Contributor


P1 attention_mask is None crash inside InferenceParams branch

attention_mask is an optional parameter with no None-guard before line 852. In the decode step of model.generate() the mask is frequently omitted, and attention_mask.shape[:2] will raise AttributeError: 'NoneType' object has no attribute 'shape', crashing every auto-regressive generation call.

if isinstance(past_key_values, InferenceParams):
    _ref = input_ids if input_ids is not None else inputs_embeds
    lengths = (
        attention_mask.sum(dim=1).tolist()
        if attention_mask is not None and attention_mask.shape[:2] == _ref.shape[:2]
        else [1] * _ref.shape[0]
    )
    past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))

faradawn and others added 2 commits April 28, 2026 09:26
Adds a configurable expert FFN execution mode to NVMixtralSparseMoeBlock:
  - "grouped" (default): existing TE GroupedLinear path
  - "loop":   Python loop over experts with one F.linear per expert
              (HF-style), running against the same stacked weight tensor
              so swapping between modes is a runtime config flag, no
              checkpoint rework

Adds a new "Tier 2" to the tutorial's progressive optimization chain that
isolates EP+TE primitives without GroupedLinear, so the GroupedLinear
contribution can be measured independently. Renumbers run_finetune_ep.py:

  1 = HF baseline BF16 (device_map="auto")
  2 = TE EP BF16 -- Python loop, F.linear per expert  [new]
  3 = TE EP BF16 -- GroupedLinear                     (was Tier 2)
  4 = TE EP BF16 -- GroupedLinear + Fused DeepEP      (was Tier 3)
  5 = TE EP FP8  -- Float8CurrentScaling + grouped + DeepEP

Tier 5 switches the FP8 recipe from MXFP8BlockScaling to
Float8CurrentScaling -- the per-tensor scaling avoids the 32-aligned
tile constraint that MXFP8 imposes on grouped-matmul, so Tier 5 runs
cleanly at EP < num_experts where MXFP8 trips the per-expert padding
assertion.

Also threads the configuration through the harness:
  - HyperParameters.expert_ffn_mode
  - --ep-size default changed from 8 to 2 (4 experts/rank gives a
    non-degenerate loop-vs-grouped contrast; pass --ep-size 8 for the
    original 1-expert/rank config)
  - Allow WORLD_SIZE > EP via a 2D (DP, EP) DeviceMesh so EP=2 with 8
    GPUs (DP=4) works end-to-end
  - replace_params: use EP-rank-within-group (rank % ep_size) instead of
    global rank for the expert slicing -- otherwise ranks beyond ep_size
    slice past the end of the HF stacked weight tensor

Misc: cycle the dataloader (small benchmark datasets like
openassistant-guanaco exhaust at large batch * num_steps), and trim the
fine-tune summary output to median + last-step (was mean + median +
min/max/p99).

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
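
A rough sketch contrasting the two expert-FFN modes this commit adds. Both consume the same expert-sorted activations and per-expert m_splits; the function names and the [num_local_experts, out_features, in_features] weight layout are illustrative assumptions, not the PR's code:

```python
import torch
import torch.nn.functional as F

def expert_ffn_loop(x_sorted, m_splits, gate_up_w, down_w):
    # "loop" mode: HF-style Python loop, one F.linear per local expert over its token slice.
    outputs, start = [], 0
    for i, count in enumerate(m_splits):
        chunk = x_sorted[start:start + count]
        gate, up = F.linear(chunk, gate_up_w[i]).chunk(2, dim=-1)
        outputs.append(F.linear(F.silu(gate) * up, down_w[i]))
        start += count
    return torch.cat(outputs, dim=0)

def expert_ffn_grouped(x_sorted, m_splits, grouped_gate_up, grouped_down):
    # "grouped" mode: one GroupedLinear call per projection batches all local expert GEMMs.
    gate, up = grouped_gate_up(x_sorted, m_splits).chunk(2, dim=-1)
    return grouped_down(F.silu(gate) * up, m_splits)
```

Because both modes read the same stacked weights, switching between them stays a runtime flag, which is what makes the new Tier 2 an apples-to-apples measurement of the GroupedLinear contribution.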
Comment on lines +511 to +520
"""
gate_up_w = self.experts_gate_up_weight
if isinstance(gate_up_w, DTensor):
gate_up_w = gate_up_w.to_local()
for i in range(self.num_local_experts):
object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i])

down_w = self.experts_down_weight
if isinstance(down_w, DTensor):
down_w = down_w.to_local()
Contributor


P1 Wrong DTensor check — GroupedLinear receives DTensor views under EP

self.experts_gate_up_weight is an nn.Parameter. After set_ep_group() wraps it with nn.Parameter(DTensor.from_local(...)), the parameter's type is still nn.Parameter (not DTensor), so isinstance(gate_up_w, DTensor) is always False. The guard never fires, to_local() is never called, and every subsequent gate_up_w[i] slice handed to object.__setattr__ is a DTensor. TE's GroupedLinear CUDA kernels cannot accept DTensor views and will either raise a type error or silently use incorrect data.

The _expert_ffn loop path already applies the correct pattern with .data:

if isinstance(gate_up_w.data, DTensor):
    gate_up_w = gate_up_w.data.to_local()
Suggested change

Removed:
    """
    gate_up_w = self.experts_gate_up_weight
    if isinstance(gate_up_w, DTensor):
        gate_up_w = gate_up_w.to_local()
    for i in range(self.num_local_experts):
        object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i])
    down_w = self.experts_down_weight
    if isinstance(down_w, DTensor):
        down_w = down_w.to_local()

Added:
    gate_up_w = self.experts_gate_up_weight
    if isinstance(gate_up_w.data, DTensor):
        gate_up_w = gate_up_w.data.to_local()
    for i in range(self.num_local_experts):
        object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i])
    down_w = self.experts_down_weight
    if isinstance(down_w.data, DTensor):
        down_w = down_w.data.to_local()
    for i in range(self.num_local_experts):
        object.__setattr__(self.experts_down, f"weight{i}", down_w[i])

faradawn and others added 2 commits April 30, 2026 08:04
Adds EXPERIMENT_LOG.md with the GroupedLinear vs naive-loop investigation
(EP=2 batch sweep, batch=8 nsys + NVTX per-phase breakdown) and the
my_expert_ffn_{loop|grouped} / my_grouped_linear_{gate_up,down} NVTX
ranges that produced the per-phase numbers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment thread docs/examples/te_mixtral/te_mixtral.py Outdated
Comment on lines +643 to +648
remainder = orig_num_tokens % 32
if remainder != 0:
    pad_count = 32 - remainder
    expert_input = torch.nn.functional.pad(expert_input, (0, 0, 0, pad_count))
    tokens_per_expert = list(tokens_per_expert)
    tokens_per_expert[-1] += pad_count
Contributor


P1 Padding appended to the last expert's slot regardless of which expert has remaining tokens

When orig_num_tokens % 32 != 0, pad_count tokens are appended to expert_input and tokens_per_expert[-1] is incremented by pad_count. This is only correct if the padded tokens happen to logically belong to the last expert. If the last expert already has 0 tokens on this rank (a routine EP scenario), GroupedLinear will schedule a GEMM for pad_count zero-padded tokens against that expert's weights, inflating compute and potentially triggering a misaligned-size error rather than a clean no-op. A safer approach is to distribute the padding tokens as a trailing dummy entry (tokens_per_expert.append(pad_count)) or to pad at batch-preparation time before dispatch so padding is always aligned with a real expert slot.

@faradawn
Contributor Author

Will wait on this PR: #2923

…p_proj/down_proj naming

Adds three drawio + SVG diagram pairs under docs/examples/te_mixtral/media/:
- dense_to_sparse: dense Transformer block vs sparse MoE block (router + 8 experts).
- mixtral_decoder_swap: Llama-tutorial-style API key->key map from HF MixtralDecoderLayer
  to TE NVMixtralDecoderLayer (fused QKV, packed expert tensors).
- moe_loop_vs_grouped: zoom into the MoE block; HF per-expert loop vs TE GroupedLinear
  with variable per-expert token counts (m_splits).

Adds a new "Architecture Overview" cell at the top of the tutorial notebook with the
two main figures, and renames the MoE weight references throughout the notebook prose
and tables from the older w1/w2/w3 form to the newer packed mlp.experts.gate_up_proj /
mlp.experts.down_proj keys (the form used by BioNemo's bionemo-recipes/models/mixtral/
convert.py and accepted by replace_params() in te_mixtral.py).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment on lines +60 to +85
`/lustre/share/coreai_prod_infbench/exp3_nvtx_per_phase/`.

---

All experiments run on a **SLURM cluster**. The login node has no CUDA, so any
experiment must execute inside an allocation on a compute node with 8 GPUs.

### Reserving a fresh node (8x GPU, 4h, exclusive)

```bash
srun -A coreai_prod_infbench \
-p batch \
-N 1 \
-J myjob \
-t 04:00:00 \
--exclusive \
--mpi=pmix \
--container-image=/lustre/fsw/coreai_prod_infbench/faradawny/docker/pytorch-25.12-py3.sqsh \
--container-save=/lustre/fsw/coreai_prod_infbench/faradawny/docker/pytorch-25.12-py3.sqsh \
--container-name=mycontainer \
--container-mounts=/lustre:/lustre \
--pty bash
```

This drops you into an interactive shell inside the NGC `pytorch-25.12-py3`
container with `/lustre` mounted and the container saved as `mycontainer`.
Contributor


P1 Internal cluster details in a public repository

EXPERIMENT_LOG.md is a personal working-notes file that contains NVIDIA-internal infrastructure details that should not be committed to a public GitHub repo:

  • A developer username embedded in paths (faradawny): /lustre/fsw/coreai_prod_infbench/faradawny/docker/pytorch-25.12-py3.sqsh
  • An internal SLURM account: coreai_prod_infbench
  • Internal shared storage paths: /lustre/share/coreai_prod_infbench/exp3_nvtx_per_phase/

This file looks like a benchmark log intended for internal coordination, not documentation for external tutorial users. Consider removing it entirely or replacing it with a trimmed BENCHMARKS.md that shows only public results (median step times, speedup numbers) without cluster-specific details.

faradawn and others added 3 commits May 6, 2026 16:01
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Add reproducible Mixtral EP tiers for TE Sequential Ops GroupedLinear, including FP8/MXFP8 variants and the dependency lockfile needed to rerun the experiments.

Co-authored-by: Cursor <cursoragent@cursor.com>
Contributor


P1 Notebook Tier bash commands misaligned with run_finetune_ep.py improvement indices

The bash commands shown for Tiers 2–4 reference --improvement values that map to different configurations in run_finetune_ep.py than the notebook describes. Concretely:

  • Notebook Tier 2: described as TE GroupedLinear EP; the bash command uses --improvement 2; run_finetune_ep.py actually runs expert_ffn_mode="loop" (Python F.linear, not GroupedLinear).
  • Notebook Tier 3: described as Fused DeepEP dispatcher (BF16); the bash command uses --improvement 3; run_finetune_ep.py actually runs expert_ffn_mode="grouped_op" with no DeepEP.
  • Notebook Tier 4: described as MXFP8 precision; the bash command uses --improvement 4; run_finetune_ep.py actually runs DeepEP BF16 GroupedLinear, not MXFP8.

Additionally, every multi-GPU bash command omits --ep-size 8, so the default --ep-size 2 (4 experts/rank) is used, whereas the Python cells in each tier all set hp.expert_parallel_size = 8 (1 expert/rank). The benchmark step-time numbers in the results tables were measured at EP=8, so users following the bash commands will reproduce different performance characteristics.

The correct mapping from notebook tier to --improvement flag appears to be: Tier 2 → --improvement 6, Tier 3 → --improvement 4, Tier 4 → --improvement 8 (all with --ep-size 8 added).

faradawn and others added 3 commits May 7, 2026 10:14
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Comment on lines +1 to +10
torch
torchao!=0.14.0
transformer_engine[pytorch]

transformers
accelerate
datasets
safetensors
huggingface_hub
tokenizers
Contributor


P1 Missing flash-attn dependency

Both init_baseline_model and init_te_mixtral_model in utils.py explicitly set config._attn_implementation = "flash_attention_2". HuggingFace transformers validates this at from_pretrained time and raises ImportError if the flash-attn package is absent. Every tutorial user following the baseline or TE fine-tuning paths hits this error before the first training step. Add flash-attn to requirements.txt (the flash-attn PyPI package).

faradawn added 2 commits May 8, 2026 07:24
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Comment on lines +236 to +237
model.load_state_dict(te_state_dict, strict=False)
del hf_model
Contributor


P1 Silent weight-mapping failures in init_te_mixtral_model

load_state_dict(strict=False) swallows all missing and unexpected keys, so if replace_params fails to map any parameter (e.g., due to a checkpoint key-name mismatch), the model silently trains with random initialization. test_accuracy.py shows the correct pattern: capture the return value and assert that the only missing keys are TE-internal _extra_state entries.

Suggested change

Removed:
    model.load_state_dict(te_state_dict, strict=False)
    del hf_model

Added:
    missing, unexpected = model.load_state_dict(te_state_dict, strict=False)
    if unexpected:
        raise RuntimeError(f"Unexpected keys when loading TE state dict: {unexpected}")
    non_extra_missing = [k for k in missing if not k.endswith("_extra_state")]
    if non_extra_missing:
        raise RuntimeError(f"Missing non-extra-state keys in TE model: {non_extra_missing}")
    del hf_model

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Comment on lines +979 to +984
should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd"

if should_pack_inputs:
    assert (
        attention_mask is not None
    ), "Attention mask is required when packing BSHD inputs."
Contributor


P1 Decode-step crash with default attn_input_format="thd" when attention_mask is omitted

should_pack_inputs is True whenever attn_input_format == "thd" and no explicit THD kwargs are supplied, which is the normal case for every decode step during generation. In that path the assertion fires immediately because HF generate() typically does not propagate attention_mask on decode steps (sequence length 1). Adding a guard like past_key_values is None (or hidden_states.size(1) > 1) to should_pack_inputs would skip packing for single-token decode steps where the sequence is already effectively packed.
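
A sketch of the guard proposed above, reusing the names from the quoted snippet; whether past_key_values is in scope at that exact point in the file is an assumption:

```python
# Skip BSHD->THD packing on single-token decode steps, where generate() omits attention_mask.
should_pack_inputs = (
    past_key_values is None
    and not any(has_thd_input)
    and self.config.attn_input_format == "thd"
)
```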

routing_weights_for_unpermute = routing_weights
map_type = "index"

if self._ep_group is not None:
Contributor


P1 Non-standard sentinel for no-token-dropping in moe_permute

The TE API documents -1 as the sentinel meaning "no token dropping"; the value used here (0) only happens to work because the fake/shape-inference path guards with num_out_tokens > 0, and 0 fails that check. If a future TE CUDA kernel instead treats 0 as a literal output-token count, permuted_hidden becomes empty and every all-to-all dispatch silently sends no tokens to any expert. Use -1 to match the documented contract.
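
For reference, a hedged sketch of the corrected call, combining the -1 sentinel recommended here with the map_type="index" form an earlier commit in this PR describes; the import path, positional argument order, and the returned (permuted tokens, row-id map) pair are assumptions to verify against the installed TE version:

```python
from transformer_engine.pytorch import moe_permute  # assumed import path

# -1 is the documented "no token dropping" sentinel for num_out_tokens.
permuted_hidden, row_id_map = moe_permute(
    hidden_states, selected_experts, num_out_tokens=-1, map_type="index"
)
```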
