Changes from all commits (45 commits)
456458e  Add Mixtral MoE integration tutorial with TE GroupedLinear (faradawn, Feb 2, 2026)
1df2360  Add full Mixtral MoE tutorial with TE GroupedLinear integration (faradawn, Apr 2, 2026)
d8a5ee2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 2, 2026)
0cf250b  Fix routing bugs and API issues in TEMixtralSparseMoeBlock (faradawn, Apr 2, 2026)
e6df830  Fix docs build: add nbsphinx execute:never to mixtral notebook (faradawn, Apr 13, 2026)
6a00cd2  Revise te_mixtral: fix perf, add FP8/decode benchmarks, restructure n… (faradawn, Apr 13, 2026)
83358cb  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 13, 2026)
c3c2e18  add memory calculation (faradawn, Apr 14, 2026)
3497ebb  docs(te_mixtral): rename example script to test_accuracy (faradawn, Apr 14, 2026)
37bb13f  add replace param (faradawn, Apr 14, 2026)
cbddd45  smoke test of ep worked (faradawn, Apr 14, 2026)
af1bd54  add expert-parallel mixtral finetune workflow and launcher (faradawn, Apr 14, 2026)
c35ee17  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 14, 2026)
1ae1e38  fix mixtral ep finetune path and distributed setup (faradawn, Apr 14, 2026)
38f06bf  align EP finetune with BioNeMo FP8 data pipeline (faradawn, Apr 15, 2026)
b04b836  both dispatcher tests passed (faradawn, Apr 16, 2026)
bd13e77  add FusedTokenRouter (DeepEP) dispatcher for 3-tier EP performance be… (faradawn, Apr 16, 2026)
0e85484  change to mxfp8 and increase seq len (faradawn, Apr 16, 2026)
6af087b  add mxfp8 data and finalize notebook (faradawn, Apr 16, 2026)
8230caf  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 14, 2026)
f8b8a81  docs(te_mixtral): add collator and update tutorial utils (faradawn, Apr 16, 2026)
4f49eaf  fix: update copyright headers to match repo standard and add Python 3… (faradawn, Apr 17, 2026)
6155d2b  fix: MXFP8 alignment padding and update tutorial with seq=256 results (faradawn, Apr 18, 2026)
1687649  docs(te_mixtral): update notebook results and add presentation (faradawn, Apr 20, 2026)
ecf66e9  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 20, 2026)
590814a  fix(te_mixtral): resolve docs build, vermin, and notebook cleanup (faradawn, Apr 21, 2026)
090219c  docs(te_mixtral): add Mixtral tutorial to Sphinx toctree (faradawn, Apr 21, 2026)
4e88c1d  fix(te_mixtral): remove unused log_with=wandb from Accelerator (faradawn, Apr 21, 2026)
4da6b69  fix(te_mixtral): inputs_embeds null-safety + per-step steady-state ti… (faradawn, Apr 22, 2026)
8fede08  docs(te_mixtral): pin HF framework versions and document container (faradawn, Apr 22, 2026)
7e3656a  feat(te_mixtral): add Tier 2 naive expert loop + 5-tier renumbering (faradawn, Apr 28, 2026)
2449f1a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 28, 2026)
8f207bd  docs(te_mixtral): add experiment log + NVTX markers in _expert_ffn (faradawn, Apr 30, 2026)
6a69bed  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 30, 2026)
5f122e3  docs(te_mixtral): add architecture diagrams + switch to packed gate_u… (faradawn, May 2, 2026)
3639a c9  enlarge font in the graphs (faradawn, May 6, 2026)
ad1ddcc  docs(te_mixtral): add sequential groupedlinear tiers (faradawn, May 7, 2026)
e175c5a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 7, 2026)
1547a30  align the tone to llama tutorial (faradawn, May 7, 2026)
4690198  Merge branch 'add-moe-example' of github.com:faradawn/TransformerEngi… (faradawn, May 7, 2026)
bfc28c4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 7, 2026)
835a865  Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in… (faradawn, May 8, 2026)
10d8176  fix mxfp8 padding (faradawn, May 8, 2026)
f799690  add pad to 128 (faradawn, May 8, 2026)
2118c29  Fix padding for normal permute path (faradawn, May 9, 2026)
338 changes: 338 additions & 0 deletions docs/examples/te_mixtral/EXPERIMENT_LOG.md

Large diffs are not rendered by default.

151 changes: 151 additions & 0 deletions docs/examples/te_mixtral/collator.py
@@ -0,0 +1,151 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Data collator for THD sequence packing (variable-length flash attention).

Adapted from bionemo-recipes. Only the subset needed by this tutorial is included.
"""

import logging
from dataclasses import dataclass
from typing import Any

import torch
from transformers import DataCollatorForLanguageModeling


logger = logging.getLogger(__name__)


def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False):
"""Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths."""
is_labels_provided = "labels" in features[0]
sample_lengths = [len(sample["input_ids"]) for sample in features]

batch = {}
batch["max_length_q"] = batch["max_length_k"] = max(sample_lengths)
batch["input_ids"] = torch.tensor(
[[token for sample in features for token in sample["input_ids"]]], dtype=torch.int64
)
if is_labels_provided:
batch["labels"] = torch.tensor(
[[label for sample in features for label in sample["labels"]]], dtype=torch.int64
)
cu_seq_lens = torch.zeros(len(features) + 1, dtype=torch.int32)
cu_seq_lens[1:] = torch.cumsum(torch.tensor(sample_lengths), dim=0, dtype=torch.int32)
batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens
if "attention_mask" in features[0]:
batch["attention_mask"] = torch.tensor(
[[v for sample in features for v in sample["attention_mask"]]], dtype=torch.int64
)
if return_position_ids:
batch["position_ids"] = torch.hstack(
[torch.arange(sample_len, dtype=torch.int64) for sample_len in sample_lengths]
).unsqueeze(0)

return batch


def _pt_pad_to_multiple_of(
batch: dict[str, Any], pad_to_multiple_of: int, token_pad: int, label_pad: int
):
"""Pad a batch to a multiple of ``pad_to_multiple_of`` by appending a mock sequence."""
remainder = -batch["input_ids"].numel() % pad_to_multiple_of
if remainder == 0:
return batch

batch["input_ids"] = torch.cat(
[batch["input_ids"], torch.full((1, remainder), token_pad, dtype=batch["input_ids"].dtype)],
dim=1,
)
if "labels" in batch:
batch["labels"] = torch.cat(
[batch["labels"], torch.full((1, remainder), label_pad, dtype=batch["labels"].dtype)],
dim=1,
)
if "cu_seq_lens_q" in batch:
batch["cu_seq_lens_q"] = torch.cat(
[
batch["cu_seq_lens_q"],
torch.tensor(
[batch["cu_seq_lens_q"][-1] + remainder], dtype=batch["cu_seq_lens_q"].dtype
),
],
dim=0,
)
batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"]
if "max_length_q" in batch:
batch["max_length_q"] = max(batch["max_length_q"], remainder)
batch["max_length_k"] = batch["max_length_q"]
if "attention_mask" in batch:
batch["attention_mask"] = torch.cat(
[
batch["attention_mask"],
torch.zeros((1, remainder), dtype=batch["attention_mask"].dtype),
],
dim=1,
)
if "position_ids" in batch:
batch["position_ids"] = torch.cat(
[
batch["position_ids"],
torch.arange(remainder, dtype=batch["position_ids"].dtype).unsqueeze(0),
],
dim=1,
)

return batch


@dataclass
class DataCollatorWithFlattening:
"""Data collator that flattens variable-length sequences into a single packed tensor for flash attention.

Wraps a ``DataCollatorForLanguageModeling`` and produces THD-format batches with
``cu_seq_lens_q`` / ``cu_seq_lens_k`` metadata for TE's fused attention kernels.

Args:
collator: The base collator for MLM/CLM masking.
pad_to_multiple_of: If set, pads the total token count to be divisible by this number.
separator_id: Label value inserted at sequence boundaries (typically -100 for causal LM).
"""

collator: DataCollatorForLanguageModeling
pad_to_multiple_of: int | None = None
separator_id: int | None = None

def __call__(self, features, return_tensors=None):
"""Pack features into a single THD batch with flash-attention metadata."""
if return_tensors is not None and return_tensors != "pt":
raise NotImplementedError(
f"Only return_tensors='pt' is supported, got '{return_tensors}'"
)

bshd_batch = self.collator(features, return_tensors=return_tensors)
packed_batch = _pt_flatten_collate(features)

masked_input_ids = bshd_batch["input_ids"][bshd_batch["attention_mask"].bool()].unsqueeze(0)
masked_labels = bshd_batch["labels"][bshd_batch["attention_mask"].bool()].unsqueeze(0)

if self.separator_id is not None:
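            # cu_seq_lens_q[1:-1] are the start offsets of every sequence after
            # the first; writing separator_id there keeps the loss from bridging
            # adjacent packed sequences.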
masked_labels[:, packed_batch["cu_seq_lens_q"][1:-1]] = self.separator_id

packed_batch["input_ids"] = masked_input_ids
packed_batch["labels"] = masked_labels

if self.pad_to_multiple_of is not None:
pad_token_id = self.collator.tokenizer.pad_token_id
if not isinstance(pad_token_id, int):
logger.warning(
f"tokenizer.pad_token_id is not an integer, using 1 instead: {pad_token_id}"
)
pad_token_id = 1
packed_batch = _pt_pad_to_multiple_of(
packed_batch,
self.pad_to_multiple_of,
token_pad=pad_token_id,
label_pad=-100,
)

return packed_batch
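
A minimal usage sketch for the collator above. The tokenizer checkpoint, the
sample texts, and the pad_to_multiple_of value here are illustrative
assumptions, not taken from this PR:

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

from collator import DataCollatorWithFlattening

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token  # ensure a pad token
base = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
collator = DataCollatorWithFlattening(
    collator=base,
    pad_to_multiple_of=128,  # hypothetical alignment target, e.g. for MXFP8 padding
    separator_id=-100,       # mask the loss at packed-sequence boundaries
)

features = [tokenizer(text) for text in ["a short sample", "a somewhat longer sample"]]
batch = collator(features)
# batch["input_ids"] has shape [1, total_tokens]; batch["cu_seq_lens_q"] marks
# where each original sequence starts and ends.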
240 changes: 240 additions & 0 deletions docs/examples/te_mixtral/fused_a2a.py
@@ -0,0 +1,240 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Portions of this code are from DeepSeek DeepEP project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE

"""DeepEP fused dispatch/combine wrapped in differentiable autograd functions."""

import os


try:
from deep_ep import Buffer
from deep_ep.utils import EventHandle, EventOverlap

HAVE_DEEP_EP = True
Buffer.set_num_sms(int(os.environ.get("DEEP_EP_SM_NUMS", "20")))
except ImportError:
HAVE_DEEP_EP = False

import torch


_buffer = None
_nvshmem_available = None


def _is_nvshmem_available() -> bool:
"""Check if DeepEP was compiled with NVSHMEM support."""
global _nvshmem_available # noqa: PLW0603
if _nvshmem_available is None:
try:
config = Buffer.get_dispatch_config(2)
config.get_rdma_buffer_size_hint(256, 2)
_nvshmem_available = True
except RuntimeError:
_nvshmem_available = False
return _nvshmem_available


def get_hidden_bytes(x: torch.Tensor) -> int:
"""Calculate the number of hidden bytes for a tensor."""
return x.size(1) * max(x.element_size(), 2)


def get_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int):
"""Get or create a DeepEP buffer for all-to-all communication."""
global _buffer # noqa: PLW0603
num_nvl_bytes, num_rdma_bytes = 0, 0
nvshmem = _is_nvshmem_available()
for config in (
Buffer.get_dispatch_config(group.size()),
Buffer.get_combine_config(group.size()),
):
num_nvl_bytes = max(
config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
)
if nvshmem:
num_rdma_bytes = max(
config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
)

if (
_buffer is None
or _buffer.group != group
or _buffer.num_nvl_bytes < num_nvl_bytes
or _buffer.num_rdma_bytes < num_rdma_bytes
):
_buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
return _buffer


class FusedDispatch(torch.autograd.Function):
"""Fused dispatch operation for MoE routing combining computation and communication."""

@staticmethod
def forward(
ctx,
x,
token_indices,
token_probs,
num_experts,
group,
async_finish=False,
allocate_on_comm_stream=False,
):
"""Forward pass of fused dispatch."""
previous_event = None
if async_finish:
previous_event = EventOverlap(EventHandle())
buffer = get_buffer(group, get_hidden_bytes(x))
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
is_token_in_rank,
event,
) = buffer.get_dispatch_layout(
token_indices,
num_experts,
previous_event=previous_event,
async_finish=async_finish,
allocate_on_comm_stream=allocate_on_comm_stream,
)

(
recv_x,
recv_token_indices,
recv_token_probs,
num_recv_tokens_per_expert_list,
handle,
after_event_overlap,
) = buffer.dispatch(
x,
topk_idx=token_indices,
topk_weights=token_probs,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=num_tokens_per_expert,
previous_event=event,
async_finish=async_finish,
allocate_on_comm_stream=allocate_on_comm_stream,
)

if async_finish:
after_event_overlap.current_stream_wait()

ctx.group = group
ctx.handle = handle
ctx.async_finish = async_finish
ctx.allocate_on_comm_stream = allocate_on_comm_stream
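        # Per-expert receive counts; downstream these become the split sizes
        # for grouped expert GEMMs (e.g. TE GroupedLinear's m_splits).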
tokens_per_expert = torch.tensor(num_recv_tokens_per_expert_list)

return (recv_x, recv_token_indices, recv_token_probs, tokens_per_expert, handle)

@staticmethod
def backward(
ctx,
grad_output,
grad_token_indices,
grad_token_probs,
grad_tokens_per_expert,
grad_handle,
):
"""Backward pass of fused dispatch."""
buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
handle = ctx.handle
previous_event = None
if ctx.async_finish:
previous_event = EventOverlap(EventHandle())
grad_x, grad_token_probs, after_event = buffer.combine(
grad_output.contiguous(),
handle,
topk_weights=grad_token_probs.float(),
previous_event=previous_event,
async_finish=ctx.async_finish,
allocate_on_comm_stream=ctx.allocate_on_comm_stream,
)
if ctx.async_finish:
after_event.current_stream_wait()
return grad_x, None, grad_token_probs, None, None, None, None


class FusedCombine(torch.autograd.Function):
"""Fused combine operation for MoE output combining computation and communication."""

@staticmethod
def forward(ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=False):
"""Forward pass of fused combine."""
previous_event = None
if async_finish:
previous_event = EventOverlap(EventHandle())
buffer = get_buffer(group, get_hidden_bytes(x))
combined_x, _, after_event = buffer.combine(
x,
handle=handle,
async_finish=async_finish,
previous_event=previous_event,
allocate_on_comm_stream=allocate_on_comm_stream,
)
if async_finish:
after_event.current_stream_wait()

ctx.handle = handle
ctx.group = group
ctx.async_finish = async_finish
ctx.allocate_on_comm_stream = allocate_on_comm_stream
return combined_x, None

    @staticmethod
    def backward(ctx, grad_output, _grad_event=None):
        """Backward pass of fused combine: re-dispatch gradients to the ranks that
        produced the combined tokens. ``_grad_event`` is the unused gradient of the
        event output returned by ``forward``."""
        previous_event = None
if ctx.async_finish:
previous_event = EventOverlap(EventHandle())
buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
grad_x, _, _, _, _, after_event = buffer.dispatch(
grad_output.contiguous(),
handle=ctx.handle,
previous_event=previous_event,
async_finish=ctx.async_finish,
allocate_on_comm_stream=ctx.allocate_on_comm_stream,
)
if ctx.async_finish:
after_event.current_stream_wait()
return grad_x, None, None, None, None


if HAVE_DEEP_EP:

def fused_dispatch(
x,
token_indices,
token_probs,
num_experts,
group,
async_finish=False,
allocate_on_comm_stream=False,
):
"""Perform fused dispatch operation via DeepEP."""
return FusedDispatch.apply(
x.contiguous(),
token_indices,
token_probs,
num_experts,
group,
async_finish,
allocate_on_comm_stream,
)

def fused_combine(x, group, handle, async_finish=False, allocate_on_comm_stream=False):
"""Perform fused combine operation via DeepEP."""
return FusedCombine.apply(x, group, handle, async_finish, allocate_on_comm_stream)

else:
fused_dispatch = None
fused_combine = None
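
A minimal end-to-end sketch of the dispatch/combine pair above. The sizes, the
identity "expert", and the process-group setup are illustrative assumptions;
running it requires DeepEP and a multi-GPU launch (e.g. via torchrun):

import torch
import torch.distributed as dist

from fused_a2a import fused_combine, fused_dispatch

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
ep_group = dist.group.WORLD

num_tokens, hidden, num_experts, top_k = 512, 4096, 8, 2
x = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)
logits = torch.randn(num_tokens, num_experts, device="cuda")
probs, indices = torch.topk(logits.softmax(dim=-1), top_k, dim=-1)

# Route each token to the ranks owning its top-k experts.
recv_x, recv_idx, recv_probs, tokens_per_expert, handle = fused_dispatch(
    x, indices, probs, num_experts, ep_group
)
expert_out = recv_x  # stand-in for the per-expert FFN (e.g. TE GroupedLinear)
# Send expert outputs back to their source ranks, restoring token order
# (router-probability weighting of the outputs is omitted for brevity).
combined, _ = fused_combine(expert_out, ep_group, handle)
assert combined.shape == x.shape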