
Add Tealite: pure-Python TransformerEngine for ROCm/AMD GPUs #581

Open
jayfurmanek wants to merge 103 commits into dev from furmanek/dev-lite

Conversation

@jayfurmanek
Contributor

Description

Introduces tealite — a pure-Python, drop-in replacement for the
transformer_engine_torch C++ extension, targeting ROCm / AMD GPUs and
PyTorch only. Activated by setting NVTE_LITE=1; when unset, none of the
lite code is imported and full TE behavior is unchanged.

Why: the full build compiles hundreds of C++ / CUDA / HIP sources via
CMake, couples tightly to toolchain versions, and slows down rapid kernel
iteration on ROCm. Tealite sidesteps all of that by dispatching to
AITER kernels (CK + Triton, installed via
pip install amd-aiter — no submodule), standalone Triton kernels,
torch._scaled_mm (hipBLASLt-backed) for FP8 GEMMs, and PyTorch-native
fallbacks. Build a wheel in seconds, not minutes.

Note: I will squash commits upon review as needed.

Type of change

  • New feature (non-breaking change which adds functionality)
  • Documentation change

Changes

New _lite/ package (27 files, ~14k LOC) at
transformer_engine/pytorch/_lite/:

  • GEMM (gemm.py, grouped_gemm.py): pluggable backend dispatcher
    (NVTE_LITE_GEMM_BACKEND = pytorch / triton / ck); FP8 routes
    through torch._scaled_mm, with an AITER fallback for shapes hipBLASLt
    can't serve (per-row scale on the reduction axis, K not divisible by 16,
    unsupported dtype combos). See the sketch after this list.
  • Norms (norms.py): RMSNorm/LayerNorm via AITER Triton with TE-Triton
    and PyTorch fallbacks. Per-row dynamic-quant fusion
    (rmsnorm2d_fwd_with_dynamicquant) is a lite-only CurrentScaling
    optimization not available in the full build.
  • Quantization (quantize.py): FP8 / MXFP8 / MXFP4 cast, transpose,
    block-scaling. Fused amax/scale update via a single Triton
    multi-tensor-apply kernel mirroring delayed_scaling.cu::kernel_bulk
    (replaces ~530k tiny elementwise launches/iter).
  • Attention (attention.py): SDPA, AITER CK (fmha_v3_fwd/fmha_v3_bwd
    AOT), and flash-attn package backends.
  • Activations (activations.py): AITER fused gated activations
    (swiglu, geglu, reglu, qgeglu) with PyTorch fallbacks.
  • RoPE (rope.py): AITER Triton kernels with CP support.
  • Compound modules (fused_layernorm_linear.py,
    fused_layernorm_mlp.py): pure-Python torch.autograd.Function
    subclasses with full DelayedScaling / CurrentScaling / MXFP8 support and
    FSDP2 integration via the inherited FSDPAGTensor wrap.
  • Distributed: comm.py, mori_ep.py (MORI-based expert parallelism
    for MoE), context_parallel.py (THD/BSHD helpers; CP is wired in
    attention + RoPE).
  • Other: dropout, softmax, padding, transpose, multi-tensor optimizer
    ops, router (fused TopK + softmax), permutation, MoE grouped GEMM.
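
The torch._scaled_mm routing mentioned in the GEMM item above can be pictured
roughly as follows. This is a hedged sketch, not the PR's gemm.py:
torch._scaled_mm is a private PyTorch API whose keyword names and return type
have varied across releases, and the real fallback goes to AITER kernels rather
than the generic dequantize-and-matmul shown here.

```python
import torch

def fp8_gemm_sketch(a_fp8, a_scale_inv, b_fp8_colmajor, b_scale_inv,
                    out_dtype=torch.bfloat16):
    """C = A @ B for per-tensor-scaled FP8 operands, with a dequant fallback.

    Assumes `b_fp8_colmajor` is already column-major, as torch._scaled_mm
    requires for its second operand.
    """
    try:
        return torch._scaled_mm(
            a_fp8, b_fp8_colmajor,
            scale_a=a_scale_inv,   # 0-dim scales select the per-tensor kernel family
            scale_b=b_scale_inv,
            out_dtype=out_dtype,
        )
    except RuntimeError:
        # e.g. "could not find valid hipblaslt solution": dequantize and fall back.
        a = a_fp8.to(out_dtype) * a_scale_inv
        b = b_fp8_colmajor.to(out_dtype) * b_scale_inv
        return a @ b
```

Passing 0-dim scale tensors (rather than broadcasting them to (M, 1) / (1, N))
is what keeps hipBLASLt on its per-tensor kernel family, as one of the later
commits describes.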

Minimal full-build touchpoints (5 files, ~265 LOC) — all env-gated or
inert when NVTE_LITE is unset:

| File | Lines | What |
| --- | --- | --- |
| module/__init__.py | +7 | env-gated swap of LayerNormLinear / LayerNormMLP for lite versions when NVTE_LITE=1 |
| module/base.py | +66/-13 | FSDPAGTensor weight-wrap plumbing (gated on IS_HIP_EXTENSION) |
| quantization.py | +2 | _contig_diag.tick_step() hook for the optional materialize-attribution harness (no-op when env unset) |
| triton/fused_router.py | +183 (new) | fused TopK + softmax router |
| triton_kernels/grouped_gemm.py | +9/-4 | grouped GEMM helper tweaks |

Documentation: in-tree _lite/README.md (full feature matrix, env
vars, status notes per subsystem, known gaps) and _lite/SKILLS.md
(operational tips).

Performance

End-to-end FP8 training on 8×MI300X / LLaMA-3-8B / seq=2048 at the
same TE commit (2026-05-01):

| Mode | iter | tok/GPU/s |
| --- | --- | --- |
| full | 1712 ms | 9567 |
| lite | 1712 ms | 9569 |

At parity. Lite uses the same hipBLASLt FP8 kernels via torch._scaled_mm
for fwd + dgrad, AITER Triton for wgrad shapes hipBLASLt can't serve, and
the fused Triton amax/scale update.

Known limitations (documented in _lite/README.md)

  • No tensor parallelism / Megatron-style sequence parallelism — the
    compound modules accept tp_size/tp_group/parallel_mode/
    sequence_parallel for API compat but hardcode tp_size=1. The
    multi-node story for tealite is FSDP/HSDP-shaped, not TP-shaped.
  • HSDP (2D device mesh) not plumbed — FSDP2 with a 1D mesh works
    and is tested.
  • MoE FP8 path — BF16 grouped works via the tex hot-swap into
    _lite/grouped_gemm.py; FP8 grouped is blocked on a Triton GMM dtype
    mismatch and aiter.fused_moe wiring.
  • No comm-overlap / userbuffers — stubs raise.
  • No cuDNN backend, no pre-tuned hidden sizes — auto-tune only.
  • Per-row CurrentScaling norm+quant fusion — RMSNorm only (AITER doesn't
    ship a layernorm2d_fwd_with_dynamicquant).

Activation

pip install amd-aiter
NVTE_LITE=1 python your_training_script.py

Checklist:

  • The functionality is complete (for the documented scope)
  • I have made corresponding changes to the documentation
    (_lite/README.md, _lite/SKILLS.md)
  • I have added tests that prove my feature works
    (tests/pytorch/test_lite.py, 4967 lines — GEMM backend matrix,
    FP8 dispatch, per-row CurrentScaling end-to-end, MoE smoke,
    FSDP2 wrap verification, attention backends, norms + quant,
    activations, multi-tensor ops)
  • New and existing unit tests pass locally with my changes
  • My changes generate no new warnings (when NVTE_LITE=1)
  • I have read and followed the contributing guidelines

jayfurmanek and others added 30 commits April 6, 2026 15:44
Introduces transformer_engine/pytorch/_lite/ package that provides a drop-in
replacement for the compiled transformer_engine_torch C++ extension module.
When NVTE_LITE=1 is set, the lite module is registered via sys.modules,
transparently replacing all tex.* calls with Triton/AITER/PyTorch-native
implementations. This eliminates the need for C++ compilation and reduces
ROCm/HIP dependencies while retaining functional correctness.

Phase 0 scaffold: 18 files covering enums, activations, norms, GEMM,
softmax, attention (stubbed), RoPE, dropout, transpose, quantization,
permutation, multi-tensor ops, MOE router, comm stubs, and padding.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
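
A minimal sketch of the sys.modules registration idea this commit describes,
assuming the package path named above; the actual activation code in the PR may
differ in detail.

```python
import importlib
import os
import sys

if os.getenv("NVTE_LITE") == "1":
    # Install the pure-Python package under the name the rest of TE imports the
    # C++ extension by, so existing `import transformer_engine_torch as tex`
    # statements transparently resolve to the lite implementations.
    lite = importlib.import_module("transformer_engine.pytorch._lite")
    sys.modules["transformer_engine_torch"] = lite
```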
…ication

- Fix GEMM transpose logic for cuBLAS column-major convention (B @ A order)
- Fix GEMM return signature to 4-tuple (output, bias_grad, gelu_input, extra_output)
- Fix GEMM bias handling: forward adds bias, backward computes bias_grad from grad_output
- Fix rmsnorm_fwd to return 3 values matching C++ signature
- Fix layernorm_bwd/rmsnorm_bwd signatures to include sm_margin parameter
- Route get_fused_attn_backend to No_Backend (unfused SDPA) until Phase 3

Verified on MI300X: forward+backward pass for Linear, LayerNormLinear,
LayerNormMLP, LayerNorm, RMSNorm. TransformerLayer forward works;
backward needs Phase 1 fix for autograd Variable issue.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
20 tests covering import verification, forward pass, forward+backward,
and numerical correctness for Linear, LayerNormLinear, LayerNormMLP,
LayerNorm, RMSNorm, and TransformerLayer under NVTE_LITE=1.

TransformerLayer backward is marked xfail pending Phase 1 fix.

Run with: NVTE_LITE=1 pytest tests/pytorch/test_lite.py -v

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace PyTorch-native norm placeholders with calls to existing Triton
kernels (triton_kernels/norms_common.py) for LayerNorm and RMSNorm
fwd/bwd. Lazy-imports Triton at first call with automatic fallback to
PyTorch if Triton is unavailable. Handles >2D input via reshape.

Also stubs out AITER dispatch paths in _lite/gemm.py for generic_gemm
and te_general_grouped_gemm (wiring deferred to AITER integration phase).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Implement Triton dispatch for quantize/dequantize in lite mode by
importing low-level Triton cast kernels directly (te_cast_transpose_noop,
mxfp8, mxfp4) instead of going through cast.py which has a tex.quantize
fallback that would infinitely recurse in lite mode.

Key fixes:
- Break quantize recursion: old code did quantizer.quantize() ->
  Float8Quantizer.quantize_impl() -> tex.quantize() -> infinite loop
- FP8 dequantize: properly reinterpret uint8 data as FP8 bits via
  .view(fp8_dtype) before casting to target dtype
- Plain tensor dequantize: check isinstance(torch.Tensor) first to
  avoid the no-op .dequantize() trap that ignores otype

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
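
The dequantize fix above amounts to reinterpreting the stored bytes before
casting. A minimal sketch, assuming E4M3 data and a scalar scale_inv
(illustrative, not the PR's quantize.py):

```python
import torch

def dequantize_fp8_sketch(data_u8, scale_inv, fp8_dtype=torch.float8_e4m3fn,
                          otype=torch.bfloat16):
    # Tensor.view(dtype) reinterprets memory when element sizes match
    # (uint8 and float8 are both 1 byte), so no copy happens here.
    fp8 = data_u8.view(fp8_dtype)
    # Cast the reinterpreted FP8 values (not the integer codes), then rescale.
    return fp8.to(otype) * scale_inv
```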
Replace nested Python loops with vectorized PyTorch ops for
fp8_block_scaling_compute_partial_amax and fp8_block_scaling_partial_cast.
Uses pad-reshape-reduce pattern: pad to block-aligned shape, reshape into
(num_blocks_h, block_len, num_blocks_w, block_len), then amax or
broadcast-multiply over block dims. Eliminates O(blocks) kernel launches.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
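
A sketch of the pad-reshape-reduce pattern for the per-block amax case,
assuming a 2D input and square blocks; the PR's partial-amax/partial-cast
functions handle more layouts than this:

```python
import torch
import torch.nn.functional as F

def block_amax_sketch(x, block_len=128):
    h, w = x.shape
    pad_h = (-h) % block_len
    pad_w = (-w) % block_len
    x = F.pad(x, (0, pad_w, 0, pad_h))                 # pad to block-aligned shape
    nh, nw = x.shape[0] // block_len, x.shape[1] // block_len
    blocks = x.reshape(nh, block_len, nw, block_len)   # (num_blocks_h, blk, num_blocks_w, blk)
    return blocks.abs().amax(dim=(1, 3))               # one amax per block, no Python loop
```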
Write 2D-tiled Triton kernels for fp8_block_scaling_compute_partial_amax
and fp8_block_scaling_partial_cast. Each program processes one
(BLOCK_LEN x BLOCK_LEN) block, loading TILE_ROWS x BLOCK_LEN elements
per iteration for full intra-block parallelism. Autotuned over TILE_ROWS
(4/8/16/32) and num_warps (4/8) to match the fused C++/CUDA kernel's
single-launch, fully-parallel design. Falls back to vectorized PyTorch
when Triton is unavailable.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
True single-pass fusion would require merging bgrad accumulation into
the cast kernel. The separate sum + quantize path is already efficient
since both dispatch to optimized CUDA/Triton kernels individually.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Wire AITER as an optional pip dependency (amd-aiter) for CK/Triton
kernel dispatch in lite mode:

- gemm.py: Multi-backend GEMM dispatch controlled by NVTE_LITE_GEMM_BACKEND
  env var ("ck", "triton", "pytorch"). Supports all precisions:
  - FP8 per-tensor: CK gemm_a8w8_CK / Triton gemm_a8w8
  - FP8 block-scale: CK gemm_a8w8_blockscale / Triton gemm_a8w8_blockscale
  - FP8 2D-blockwise (Float8BlockwiseQTensorStorage): routes to block-scale
    kernels with proper rowwise/columnwise data extraction
  - FP4 (MXFP4): CK gemm_a4w4 / Triton gemm_afp4wfp4
  - BF16/FP16: Triton gemm_a16w16 (no CK individual GEMM)
  - FP32: torch.matmul (preserves exact precision)
  Auto-detects per-tensor vs block-scale from scale tensor shape.
  Each backend falls through to the next on failure.
- activations.py: AITER fused gated activations (swiglu -> silu_and_mul,
  geglu -> gelu_tanh_and_mul). Non-gated activations stay PyTorch.
- rope.py: Refactored to use shared aiter_utils instead of per-file flags.
- aiter_utils.py: New shared AITER availability detection with lru_cache.

Tested with amd-aiter 0.1.7 on all backends (ck, triton, pytorch):
all 37 tests pass + 1 xfail. Also passes without AITER installed.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Build a pure-Python wheel with no C++ compilation:
  NVTE_LITE_ONLY=1 python setup.py bdist_wheel

Builds in <1 second (vs 5-10 minutes for full build). The wheel:
- Contains only Python files + Triton kernels (no .so files)
- Platform tag: py3-none-any (architecture-independent)
- Package name: tealite
- Writes LITE_BUILD marker file that forces NVTE_LITE=1 at import

Build changes:
- setup.py: Skip C++ extensions, CMake, submodule checks, hipify when
  NVTE_LITE_ONLY=1. Write LITE_BUILD marker into package.
- common/__init__.py: Detect LITE_BUILD marker at module level to
  auto-activate lite mode (skip core library loading).
- .gitignore: Exclude LITE_BUILD marker from version control.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… stub, and SDPA fallback

Replace Phase 3 TODO stubs with full multi-backend attention dispatch:
- AITER CK/ASM kernels via raw _flash_attn_forward/_backward (priority backend)
- Flash-attention stubbed for future integration
- PyTorch SDPA fallback with autograd-based backward
- Pure PyTorch helpers: fa_prepare_fwd/bwd, copy_to_kv_cache, THD<->BSHD converters
- C++ binding-compatible signatures so cpp_extensions/fused_attn.py works unmodified

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
19 attention tests: backend selection, fwd/bwd shapes for bshd/sbhd/thd,
aux_ctx_tensors format, mask types, GQA, variable-length, AITER-vs-SDPA
numerical comparison, helper functions, DotProductAttention and
MultiheadAttention end-to-end.

14 GEMM tests: TN layout, all transpose combos, bias addition, bias grad,
GELU epilogue, accumulate, alpha scaling, output-into-D, return format, FP32.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…in _lite

Fix _lite/permutation.py signatures to match the tex.* C++ interface so the
index-map path in permutation.py works correctly in lite mode. Wire up the
existing Triton sort_chunks_by_map kernel for the forward gather operation
with PyTorch fallback. Fix _lite/padding.py to match the 4-argument
tex.fused_multi_row_padding/unpadding interface with proper zero-padding.
Add 22 tests covering both MoE permutation and padding operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…lysis

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Integrate AMD's MORI (Modular RDMA Interface) library to provide
high-performance distributed expert parallelism for the _lite module.
This bridges the most significant distributed gap in tealite by enabling
MoE token dispatch/combine across GPUs via XGMI (intra-node) and RDMA
(inter-node) without requiring C++ extensions.

Key components:
- MoriExpertParallel: high-level wrapper with dispatch/combine for
  both flat and standard MoE (per-expert grouped) layouts
- MoriEPDispatch/MoriEPCombine: autograd functions enabling gradient
  flow through distributed dispatch/combine for training
- MoriEPDispatchStdMoE/MoriEPCombineStdMoE: autograd functions for
  the per-expert layout path used with grouped GEMM
- mask_to_index/index_to_mask: routing map format converters between
  TE's mask-map and MORI's index-map formats
- Layout converters between flat and per-expert token arrangements

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
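
The mask-map/index-map converters mentioned above can be illustrated roughly as
below. The layouts are assumptions (a boolean [tokens, experts] mask with
exactly top_k experts selected per token); MORI's actual index-map format may
differ.

```python
import torch

def mask_to_index_sketch(routing_mask, top_k):
    # nonzero() walks row-major, so expert indices come out grouped per token.
    return routing_mask.nonzero(as_tuple=True)[1].reshape(-1, top_k)

def index_to_mask_sketch(indices, num_experts):
    mask = torch.zeros(indices.shape[0], num_experts,
                       dtype=torch.bool, device=indices.device)
    return mask.scatter_(1, indices, True)
```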
Document the new MORI EP integration including feature table, supported
kernel types, and specific gaps vs the full build (no MoE module
integration, no comm-overlap with expert GEMM, no pipeline-parallel EP,
no heterogeneous expert placement, limited standard MoE kernel types).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ace gaps

Implements a fused Triton kernel for the MoE router that combines
score-function (sigmoid/softmax) + top-k + group-topk + normalization +
scaling into a single kernel launch, replacing the previous 3-5 unfused
PyTorch ops. Adds sigmoid score function support that was previously
ignored in _lite, and fixes return signature mismatches between _lite
and the C++ extension interface.

Key changes:
- New Triton JIT kernels (fwd/bwd) for fused router in
  common/triton/fused_router.py with PyTorch wrappers in
  pytorch/triton/fused_router.py
- _lite/router.py: sigmoid scoring, group top-k, correct 3-tuple
  returns for fused_topk_with_score_function_fwd and
  fused_score_for_moe_aux_loss_fwd, and (loss, Const_buf) return
  for fused_moe_aux_loss_fwd
- Comprehensive test coverage for all MoE permutation and router paths
  including mask-map, chunk-sort, numerical gradient verification, and
  Triton-vs-PyTorch cross-checks

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
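
For reference, the unfused PyTorch operations the fused router kernel collapses
into one launch look roughly like this (group top-k and the aux-loss outputs
are omitted; argument names are illustrative):

```python
import torch

def route_tokens_sketch(logits, topk, score_function="softmax",
                        renormalize=True, scaling_factor=1.0):
    scores = logits.softmax(dim=-1) if score_function == "softmax" else logits.sigmoid()
    probs, expert_idx = scores.topk(topk, dim=-1)      # top-k experts per token
    if renormalize:
        probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-20)
    return probs * scaling_factor, expert_idx
```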
…ests

- Rewrite _lite/rope.py to match full C++ tex.fused_rope_forward/backward
  signatures with cp_size, cp_rank, start_positions, qkv_format, interleaved,
  and cu_seqlens parameters
- Implement DualChunkSwap frequency slicing (_get_freqs_on_this_cp_rank) for
  CP-aware position embedding in lite mode
- Fix AITER wiring: import from aiter.ops.rope instead of non-existent
  aiter.rope, with adapter functions translating TE conventions (interleaved,
  qkv_format) to AITER conventions (rotate_style, nope_first)
- Fix fused QKV RoPE head-dimension reshape to match C++ kernel behavior:
  Q is reshaped from [s,b,h,q_split] to [s,b,h*q_split/head_dim,head_dim]
  using K's dimension as the reference head size
- Add start_positions support via per-batch freq stacking
- Compute RoPE rotation in float32 for precision parity with C++ fused kernel
- Add interleaved rotation support (_rotate_half_interleaved)
- Add multi-GPU CP attention tests (test_lite_cp.py) covering P2P, AllGather,
  and A2A comm types with BSHD/SBHD formats
- Add xfail skips in test_fused_rope.py for known lite gaps (non-contiguous
  tensors, THD+CP)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add AITER Triton as highest-priority backend for LayerNorm/RMSNorm
  (AITER Triton > TE Triton > PyTorch fallback)
- Import aiter.ops.triton.rmsnorm._rmsnorm_forward/backward and
  aiter.ops.triton.norm._layernorm_forward/backward with lazy loading
- Add adapter functions that translate between TE's norm API (N-D input,
  quantizer interface) and AITER's raw 2D Triton kernel interface
- Refactor internal helpers: extract _ensure_2d/_restore_nd for clean
  N-D reshape handling, separate quantizer application from norm compute
- Note: AITER's fused norm+quantize kernels (dynamicquant/smoothquant)
  use per-row scaling which is incompatible with TE's per-tensor FP8
  scaling, so quantization remains a separate step via the quantizer
  interface
- Update test_lite.py to match refactored _layernorm_fwd_pytorch and
  _rmsnorm_fwd_pytorch signatures (removed ln_out/quantizer/otype/
  sm_margin params from internal PyTorch fallback functions)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- test_aiter_norms_active: confirms all 4 AITER norm functions (LayerNorm
  and RMSNorm, forward and backward) are loaded as the active backend
- test_aiter_rmsnorm_fwd_bwd: validates RMSNorm forward + backward output
  from the AITER path against PyTorch reference
- test_aiter_layernorm_fwd_bwd: same for LayerNorm

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Integrate AITER's fused_rms_fp8_per_tensor_static_quant Triton kernel
into the lite RMSNorm forward path. When a Float8Quantizer (delayed
scaling) is provided, the norm and FP8 quantization now execute in a
single kernel launch instead of two separate passes over memory.

The fused kernel:
- Takes a pre-computed per-tensor scale from Float8Quantizer.scale
- Computes RMSNorm and FP8 cast in one pass
- Writes FP8 output directly (no intermediate BF16 materialization)

Also adds detection scaffolding for MXFP8Quantizer with AITER's
fused_rms_fp8_group_quant (block scaling), currently falling back to
separate quantize until Float8Tensor wrapping is validated.

Float8CurrentScalingQuantizer (JIT scaling) cannot use the static fused
kernel since the scale is unknown before the forward pass -- it
continues to use the separate norm -> quantizer.quantize() path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- test_fused_rmsnorm_fp8_quant_active: verifies the AITER fused kernel
  is loaded and produces a Float8Tensor with correct scale_inv and amax
- test_fused_rmsnorm_fp8_quant_vs_separate: validates fused output
  against separate norm->quantize path (dequantized comparison)
- test_fused_rmsnorm_fp8_quant_3d_input: verifies N-D shape preservation
  through the fused path

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Per-tensor CurrentScaling requires 3 kernel launches (norm → HBM → amax
scan → HBM → quantize). Per-row fuses norm+quantize into a single kernel
where each row computes its own scale in registers, eliminating the global
amax data dependency and the BF16 intermediate write to HBM.

Forward: rmsnorm2d_fwd_with_dynamicquant produces FP8 + yscale(M,).
GEMM: gemm_a8w8_per_token_scale consumes per-row-scaled FP8 natively.
Backward: dynamic_per_token_quant_fp8_i8 quantizes dY per-row for dgrad.

Changes:
- _lite/norms.py: Fused RMSNorm+FP8 per-row quant for CurrentScaling
- _lite/gemm.py: Per-row scale detection and per-token GEMM dispatch
- _lite/quantize.py: Per-row dynamic quantize path for CurrentScaling
- tests/test_lite.py: 9 new tests covering all three paths

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
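
The per-row dynamic quantization those AITER kernels fuse is, semantically, the
following PyTorch reference (a sketch assuming E4M3 with max 448.0; the fused
kernels keep the per-row scale in registers instead of materializing it):

```python
import torch

def per_row_quant_sketch(x, fp8_dtype=torch.float8_e4m3fn, fp8_max=448.0):
    row_amax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scale = fp8_max / row_amax              # per-row scale, no global amax dependency
    x_fp8 = (x * scale).to(fp8_dtype)
    yscale = (1.0 / scale).squeeze(-1)      # per-row dequant scale, shape (M,)
    return x_fp8, yscale
```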
Add new FP8 Training section documenting per-row dynamic scaling for
CurrentScaling recipe -- a lite-only optimization that fuses norm+quantize
into a single kernel, eliminating 2 HBM round-trips vs per-tensor scaling.

Also updates feature tables for: AITER as primary norm backend, fused
norm+quantize variants, per-row GEMM dispatch, per-row dynamic quantize,
context parallelism for RoPE, and fused Triton MoE router.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The recipe infrastructure (fp8_autocast, DelayedScaling, Float8CurrentScaling,
MXFP8BlockScaling, RecipeState, make_quantizers) is pure Python that lives
above _lite and works identically in both modes. No actual gap exists.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ntize fallback

MXFP8 tensors were misdetected as FP4 (shared _rowwise_data attribute),
silently dequantized in GEMM, and the fused norm+quant output was discarded.

GEMM detection (gemm.py):
- Add _is_mxfp8() using _fp8_dtype discriminator vs _fp4_dtype for FP4
- Fix _is_fp4() to require _fp4_dtype attribute
- Fix _is_quantized() and _get_raw_data() to handle MXFP8
- Add explicit MXFP8 early-return in CK/Triton dispatch with TODO(MI350)
  hooks for future native MXFP8 GEMM kernels

Norms fusion (norms.py):
- Complete fused_rms_fp8_group_quant wrapping (was returning None)
- Add E8M0 scale conversion: AITER linear float32 → uint8 biased exponent
- Produces proper MXFP8Tensor from single fused kernel

Quantize fallback (quantize.py):
- Add _linear_scale_to_e8m0() shared helper
- Add _quantize_mxfp8_pytorch() pure PyTorch fallback for MXFP8
- Wire fallback into quantize() and _quantize_pytorch_fallback()

Tests: 8 new tests covering detection, E8M0 conversion, quantize roundtrip,
PyTorch fallback, GEMM dequant path, and fused RMSNorm+MXFP8.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
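
The E8M0 scale conversion added here maps linear float32 scales to
biased-exponent-only uint8 values. A hedged illustration, assuming bias 127 and
ceiling rounding; the exact rounding convention AITER/TE uses may differ:

```python
import torch

def linear_scale_to_e8m0_sketch(scale_fp32):
    # Round the exponent up so the decoded power-of-two scale never underflows.
    exp = torch.ceil(torch.log2(scale_fp32.clamp_min(2.0 ** -127)))
    return (exp + 127).clamp(0, 255).to(torch.uint8)

def e8m0_to_linear_scale_sketch(e8m0_u8):
    return torch.pow(2.0, e8m0_u8.to(torch.float32) - 127)
```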
…amax history

Implements approach 2 for compound fused modules: clean autograd Functions
composing existing _lite ops (norm, GEMM, activations) instead of routing
through the full-build's 1500+ line distributed-heavy modules.

LayerNormLinear: norm → quantize → GEMM, with full backward.
LayerNormMLP: norm → FC1+bias → activation → FC2+bias, supporting all 11
activation variants (gelu, swiglu, geglu, silu, relu, etc.) with fused
dbias_dact backward when available.

Both inherit TransformerEngineBaseModule for fp8_autocast integration and
accept the full-build constructor/forward kwargs for API compatibility
(TP/SP/FSDP params are accepted but ignored).

Also fixes fused_amax_and_scale_update_after_reduction which was ignoring
the amax history window — now rolls history and supports "max",
"most_recent", and custom callable algorithms matching the C++ kernel.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
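
The amax-history fix described in the last paragraph boils down to rolling the
window and then reducing it with the configured algorithm. A sketch of the
semantics (tensor layout and slot convention are assumptions, not the PR's
exact code):

```python
import torch

def update_amax_history_sketch(amax_history, new_amax, algo="max"):
    # amax_history: (history_len, num_tensors); slot 0 holds the newest step.
    amax_history = torch.roll(amax_history, shifts=-1, dims=0)
    amax_history[0] = new_amax
    if algo == "max":
        reduced = amax_history.amax(dim=0)      # reduce over the whole window
    elif algo == "most_recent":
        reduced = amax_history[0]
    else:                                       # custom callable
        reduced = algo(amax_history)
    return amax_history, reduced
```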
The torch.matmul fallback in generic_gemm failed when operands had more
than 2 dimensions (e.g. [batch, seq, hidden]) because it passed them
directly to matmul without flattening. The C++ GEMM and AITER kernels
handle this implicitly but torch.matmul does not.

Flatten operands to 2D before matmul and restore B's leading dimensions
in the output, matching cuBLAS convention. This fixes TransformerLayer
backward in lite mode — removes the xfail marker from that test.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
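
A minimal sketch of the flatten-then-restore fallback this commit adds
(illustrative; the real code also handles the transpose flags and cuBLAS
argument order):

```python
import torch

def matmul_nd_sketch(a, b):
    lead = a.shape[:-1]                               # e.g. (batch, seq)
    out = torch.matmul(a.reshape(-1, a.shape[-1]), b) # 2D GEMM
    return out.reshape(*lead, b.shape[-1])            # restore leading dims
```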
Previously the backward pass in LayerNormLinear and LayerNormMLP ran
all GEMMs in bf16 even during FP8 training. Now:

- dgrad GEMMs pass grad_input_quantizer to quantize output on the fly
- wgrad GEMMs pass grad_weight_quantizer for FP8 weight gradients
- Saved inputs (ln_out, act_out) are re-quantized with columnwise
  usage before NT wgrad GEMMs (AITER CK needs column-wise scaling)
- grad_output gets columnwise usage enabled for wgrad GEMMs
- MLP backward quantizes dact (activation backward output) via
  fc1_grad_output_quantizer before FC1 GEMMs

This matches the full-build backward quantization strategy, enabling
FP8 throughput benefits in the backward pass when AITER kernels are
available.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
For gated activations (swiglu, geglu, reglu) with Float8BlockQuantizer,
dispatch to AITER's act_mul_and_fp8_group_quant Triton kernel which
fuses activation + gate multiply + FP8 cast in a single kernel pass.
This eliminates the intermediate bf16 round-trip between activation
and quantization.

The kernel accepts group_size as a parameter, so we pass the quantizer's
block_len (128 for Float8Block). AITER returns fp8 data + float32
dequant scales (scale_inv), which we wrap directly into a
Float8BlockwiseQTensorStorage.

Dispatch priority for gated activations is now:
1. AITER fused act+quant (when quantizer is Float8BlockQuantizer)
2. AITER fused gated act (silu_and_mul, gelu_tanh_and_mul) + separate quantize
3. PyTorch fallback + separate quantize

MXFP8 is not supported by this fused path because AITER produces float32
scales while MXFP8 requires E8M0-encoded uint8 scales.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
jayfurmanek and others added 26 commits April 23, 2026 23:28
hipBLASLt FP8 kernels require every mat1 dim divisible by 16. The LLaMA-3
config hits M=8184 (2046×4 tokens) on every forward and dgrad call, which
trips either the explicit "trailing dimension must be divisible by 16"
check (wgrad with K=8184) or the implicit "could not find valid hipblaslt
solution" (fwd/dgrad with M=8184). Pad M with zero rows up to the next
multiple of 16 (+8 rows for 8184, a 0.1% overhead), pad the per-row scale
with 1.0 in lockstep (scale × 0 = 0 anyway), call _scaled_mm, slice the
result back. Padded rows contribute zero and don't affect the GEMM result.

Unlike the reverted pow2 padding (which inflated 28672 → 32768 at 12.5%
overhead), div-by-16 padding only pads the misaligned tokens dim; weight
dims (4096, 14336, 28672, 6144) are all already aligned.

K misalignment (mat1 trailing dim = tokens = 8184 after the wgrad
transpose) is not padded here — those calls also hit per-row-on-reduction
issues and belong on AITER. Logged as "k_not_div16" for traceability.

Counter: k_not_div16 added to the [LITE-SCALED-MM-FAIL] diag reasons.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
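
The pad-and-slice workaround can be sketched as below (illustrative helper;
torch.cat is used for the FP8 rows since F.pad support for FP8 dtypes varies by
PyTorch version):

```python
import torch
import torch.nn.functional as F

def pad_m_to_16_sketch(a_fp8, row_scale=None):
    pad = (-a_fp8.shape[0]) % 16
    if pad == 0:
        return a_fp8, row_scale, 0
    zero_rows = torch.zeros(pad, a_fp8.shape[1],
                            dtype=a_fp8.dtype, device=a_fp8.device)
    a_fp8 = torch.cat([a_fp8, zero_rows], dim=0)           # padded rows contribute zero
    if row_scale is not None and row_scale.numel() > 1:    # 0-dim scalar scales need no padding
        row_scale = F.pad(row_scale, (0, pad), value=1.0)  # scale x 0 = 0 either way
    return a_fp8, row_scale, pad

# usage sketch: a, s, pad = pad_m_to_16_sketch(a, s); run the GEMM; then
# slice the last `pad` rows off the result.
```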
Megatron uses TEDelayedScaling (per-tensor) recipes, so most _scale_inv
tensors arrive as numel=1 scalars. The old _reshape_scale_for_scaled_mm
broadcast these to (M, 1) / (1, N), forcing hipBLASLt's rowwise kernel
family. That family isn't tuned for mixed-dtype (E4M3 × E5M2) FP8 on
ROCm, hence "could not find valid hipblaslt solution" on every dgrad.

Pass per-tensor scalars as 0-dim tensors so hipBLASLt selects the
per-tensor kernel family — the same F8NBS/F8B8NBS Tensile kernels full
TE uses for both same-dtype forward and mixed-dtype backward.

Per-row (numel==dim) scales still reshape to (dim, 1) / (1, dim) for
the rowwise kernel family; that path is for CurrentScaling recipes
where the full rowwise kernel set does cover the shapes we hit.

Padding logic updated: don't F.pad a 0-dim scalar scale (it applies
uniformly to padded rows automatically).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The pytorch path (torch._scaled_mm, hipBLASLt-backed on ROCm) is now the
fastest of the three backends at 1.79 s/iter (443 TFLOP/s) on LLaMA-3-8B,
edging ahead of the full build. The triton and ck backends land within 6% at ~2.01 s/iter
and remain available via explicit override.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
TestGemmBackendMatrix covers gaps that opened once three backends
(pytorch/triton/ck) became production paths:
- Parity across all three backends for BF16 and per-tensor FP8
  (DelayedScaling layout, the Megatron default)
- Counter assertion that per-tensor FP8 under backend=pytorch lands on
  torch._scaled_mm and not dequant+matmul — catches silent "scalar
  scale accidentally broadcast to rowwise" regressions
- M=100 (not div-by-16) to exercise the pad-and-slice path for
  hipBLASLt FP8 alignment

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Newer aiter releases insert positional args (sink_size after
window_size_right, q_descale/k_descale/v_descale after alibi_slopes) in
_flash_attn_forward, which shifts the tail args and makes the existing
positional call fail. Switching to keyword arguments makes the call
resilient to future drift and unblocks TestRecipeIntegration's
transformer_layer DelayedScaling/Float8CurrentScaling cases.

The varlen path already uses keyword args so it was already drift-safe
(and that's the path Megatron thd training exercises).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Optimization on the norm backward path: when the only downstream consumer
of an FP8 tensor is going to dequantize it immediately, we can skip the
cast entirely and still preserve DelayedScaling amax history by running
a standalone reduction on the BF16 source. Gated behind the
NVTE_LITE_SKIP_FP8_DGRAD_FOR_NORM env var.
Pull te_general_grouped_gemm out of _lite/gemm.py into its own module
with explicit FP8-operand detection. AITER's generic GMM kernels
(gmm/ptgmm/nptgmm) are BF16/FP16 only — the p/np prefix is
persistent vs non-persistent kernel, not per-tensor scaling — so FP8
operands now raise NotImplementedError pointing at the fused-MoE path
(aiter.fused_moe / moe_op_gemm_a8w8_blockscale) that Phase 2 will wire.

Public API unchanged: te_general_grouped_gemm is still exported from
transformer_engine.pytorch._lite. BF16/FP16 path continues to delegate
to general_grouped_gemm_triton, so existing TestGroupedLinear coverage
is the regression check.

Tests:
- TestImport.test_key_symbols_exist: assert te_general_grouped_gemm is
  on the exported tex surface.
- TestGroupedGemmDispatch.test_fp8_operands_raise_not_implemented (new):
  FP8 operands must fail loudly so they don't silently misroute through
  the BF16 kernel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
MoE token routing can leave a rank with zero local tokens for its
expert(s) — common in early training before the auxiliary load-balancing
loss equalizes routing. AITER's gmm asserts M > 0, but Megatron treats
this as a legal MoE state, so the lite wrapper must handle it.

Hit on iter 3 of a single-node Mixtral 8x7B BF16 run (EP=8, MBS=1,
seq=2048, mock data): one rank's TEColumnParallelGroupedLinear got
m_splits summing to 0, the call traversed _lite/grouped_gemm.py ->
general_grouped_gemm_triton -> gmm_common.get_gmm_shape, and tripped
"AssertionError: M must be positive, it's 0."

Fix: when sum(m_splits) == 0, short-circuit before invoking AITER.
- Forward / dgrad: outputs already shape (0, ...) — nothing to do.
- Wgrad: output is (G, K, N); zero-fill iff accumulate=False, leave
  alone iff accumulate=True (zero contribution = no-op).
- Bias return matches the kernel-call return shape.

Tests:
- TestGroupedGemmDispatch.test_empty_tokens_short_circuit_forward:
  forward path returns cleanly with M=0 and out.shape=(0, N).
- TestGroupedGemmDispatch.test_empty_tokens_short_circuit_wgrad_zeros_out:
  wgrad with accumulate=False zeros the (G, K, N) output buffer.

Iters 1-2 of the same Mixtral run completed before the failure (loss
10.58 -> 10.53, grad norm 16 -> 53), so the BF16 MoE path through lite
is otherwise sound at full Mixtral scale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
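
A sketch of the guard this commit adds, called before handing the group sizes
to AITER (names are illustrative):

```python
def maybe_short_circuit_empty_groups(m_splits, out, accumulate):
    """Return True (and fix up `out`) when there are no tokens to process."""
    if sum(m_splits) != 0:
        return False
    if not accumulate:
        # accumulate=True: leaving `out` alone equals adding a zero contribution.
        out.zero_()
    return True
```

When the guard returns True, the wrapper returns `out` directly instead of
invoking the AITER kernel, which would otherwise assert on M == 0.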
The bshd/sbhd branch in `_aiter_attn_fwd` was forwarding the TE-supplied
`cu_seqlens_q/kv` to `aiter._flash_attn_forward`. Those values are just
fixed-length batch boundaries here (no packed sequences), but their
non-None presence trips aiter's `can_impl_fmha_v3_fwd` gate and routes
the call to the slower JIT `mha_fwd` path (`ck_tile::FmhaFwdKernel`)
instead of the AOT `aiter::fmha_fwd_hd128_bf16_causal_*` kernel that
full TE uses. Pass None explicitly on this branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Diagnostic only: under NVTE_LITE_DIAG=1, print q/k/v dtype + shapes,
bias, dropout, causal/window, and seqlen on the first 1-2 calls into
_aiter_attn_fwd's bshd/sbhd path. Gated behind _LITE_DIAG so default
runs are unaffected. Will be reverted once we identify which
can_impl_fmha_v3_fwd gate is still blocking the AOT v3 dispatch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the temporary [LITE-ATTN-FWD-PROBE] used to diagnose v3
dispatch with a cleaner [LITE-ATTN-FWD] one-shot helper, gated behind
NVTE_LITE_DIAG=1 like the rest of the lite diagnostics. Prints once
per process from whichever branch (thd/bshd/sbhd) of _aiter_attn_fwd
runs first, with the fields that map directly to aiter's
can_impl_fmha_v3_fwd gate. Useful for catching future regressions
where attention silently routes to the slower JIT ck_tile path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
For Megatron's typical SBHD layout, _to_bshd was materializing a fresh
BSHD-contiguous copy of q/k/v (and on bwd, o/dout) before each call to
aiter._flash_attn_*. Aiter's maybe_contiguous only triggers a copy when
stride(-1) != 1, and a plain transpose(0, 1) of an SBHD-contiguous
tensor preserves head-dim stride 1, so the kernel reads the strided
BSHD view directly with no internal materialization. Verified the fwd
kernel produces bit-identical output (max abs diff 0).

Cascade benefit on bwd: torch.empty_like with default preserve_format
inherits the source view's strides, so dq/dk/dv allocations become
strided BSHD views as well; the kernel writes into them through the
strided indexing, and _from_bshd's later transpose+.contiguous() is a
no-op (the transpose recovers SBHD-contiguous strides). The fwd output
still copies because aiter allocates it BSHD-contiguous internally;
fixing that one would need an aiter-side change to accept a
preallocated out tensor.

Trace had 384 launches of direct_copy_kernel_cuda variants (~63 ms)
attributable to these copies; eliminated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
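
The stride argument above can be checked in isolation: a transposed view of an
SBHD-contiguous tensor keeps stride(-1) == 1, which is all aiter's
maybe_contiguous requires (small standalone demonstration, not PR code):

```python
import torch

s, b, h, d = 2048, 2, 32, 128
q_sbhd = torch.randn(s, b, h, d)      # SBHD-contiguous
q_bshd = q_sbhd.transpose(0, 1)       # strided BSHD view, no data movement
assert q_bshd.stride(-1) == 1         # the condition maybe_contiguous checks
assert not q_bshd.is_contiguous()     # not BSHD-contiguous, yet no copy was made
```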
prepare_forward's defensive .contiguous() copy only triggers when the
incoming activation is non-contiguous. In full TE that branch is dead;
in lite it fires hundreds of times per step, materializing copies that
account for ~50 ms of the elementwise gap (3 of the top direct_copy
launches in the trace). Need to know which lite producer (Triton fused
linear, SwiGLU, RMSNorm, cast-transpose, etc.) is emitting the
non-standard strides so we can fix it at the source.

Probe is gated behind NVTE_LITE_DIAG=1 (existing lite-mode env var)
and capped at 20 unique (module, shape, stride, caller) signatures so
output stays bounded. Walks the stack past the base.py frame to the
caller for actionable identification. Zero overhead when the env var
is off.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
prepare_forward is a @contextmanager; first frame above base.py is
contextlib.__enter__. Skip both base.py and contextlib.py, then
capture three user-code frames so the producer of the non-contig
input is visible (innermost first).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The TE Linear/LayerNormLinear forward methods, torch._dynamo's
eval_frame, and torch.nn.Module._call_impl all sit between the actual
producer and prepare_forward. Skip them and bump the captured-frame
count to 8 so the layer that emitted the non-contig view (e.g., a
transpose without .contiguous() in Megatron's attention forward)
becomes visible in the [LITE-NONCONTIG] caller chain.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Experimental gate to A/B test whether the defensive .contiguous()
materialize in prepare_forward is actually necessary, or whether the
downstream GEMM can consume strided activations directly. Off by
default; setting NVTE_LITE_SKIP_NONCONTIG=1 skips the materialize
when an input is non-contiguous.

Context: with --attention-backend fused + apply_rope_fusion=1, lite
sees ~384 direct_copy launches per step from this materialize alone
(~50 ms). The producer is Megatron's TEDotProductAttention output
transpose at extensions/transformer_engine.py:811, which returns a
non-contig BSH-shape view of SBH-contig memory. If hipBLASLt accepts
that strided view (or only re-materializes once internally instead of
the whole shape), we save the per-call copy.

If the GEMM crashes or produces wrong output, revert by unsetting the
env var. The diagnostic [LITE-NONCONTIG] log still fires regardless,
so we keep visibility.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
New _contig_diag module counts and times every prepare_forward
.contiguous() materialize, keyed on (module, shape, stride, caller).
Hooks tick_step() into FP8GlobalStateManager.autocast_exit so the step
counter advances once per training step under DelayedScaling without
patching Megatron.

Activation:
  NVTE_CONTIG_DIAG=1                  enable instrumentation
  NVTE_CONTIG_DIAG_DUMP_STEP=N        auto-dump after step N

Timing is perf_counter_ns around the .contiguous() call (CPU launch
cost, no cuda.synchronize), to avoid distorting the very gap we are
trying to measure. Use rocprof for device-side cost; the counter
answers where and how often.

Intended to side-by-side full and lite under apples-apples Megatron
runs and diff the [CONTIG-DIAG] blocks to identify lite-only or
higher-count materialize sites.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Document NVTE_LITE_AMAX_FUSED, NVTE_LITE_SKIP_FP8_DGRAD_FOR_NORM,
  NVTE_LITE_DIAG; correct NVTE_LITE_GEMM_BACKEND default (ck -> pytorch)
  and rewrite its description to reflect the torch._scaled_mm tier.
- Add grouped_gemm.py / amax_utils.py / fused_layernorm_{linear,mlp}.py
  to the module-structure listing.
- Add MoE-section rows for AITER Triton grouped GEMM (BF16/FP16) and
  call out FP8 grouped GEMM as NYI; cross-link the existing
  TestGroupedLinear::test_fp8_forward xfail.
- Refresh GEMM gaps + Summary so the default pytorch backend's
  _scaled_mm-first dispatch is reflected and FP8 grouped GEMM is listed
  as a primary gap.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Exercises Float8CurrentScaling through te.LayerNormLinear and asserts
both that the AITER per-row kernels actually fire (catching silent
fallback to per-tensor — invisible to cosine-only checks) and that
fwd/bwd numerics stay within FP8-appropriate tiered tolerances vs a
BF16 reference.

The fixture monkeypatches the kernel module-attrs via sys.modules
because _lite/__init__.py re-exports a `quantize` function that
shadows the `quantize` submodule under attribute lookup.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Captures invariants, perf baselines, dispatch hazards, and dead ends
accumulated across the lite work. README documents what tealite supports;
SKILLS documents what to know when working on it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…nication section

The LayerNorm/RMSNorm feature table listed "Tensor / sequence parallelism"
and "FSDP2 integration" as gaps, which read as if the norm op itself was
limited. The constraints actually live in the fused compound modules and
the comm layer, so move both rows into the Communication / Distributed
section. Also correct the FSDP2 row: lite supports FSDP2 with a 1D mesh
(weights wrap in FSDPAGTensor via the inherited base-class path); only
HSDP / 2D-mesh plumbing is missing. Expand TP/SP rows with the actual
reason (kwargs accepted for API compat but ignored; Megatron SP requires
TP) and upgrade the CP row to reflect that RoPE + attention CP is wired,
not just THD/BSHD helpers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Tealight-candle ASCII art alongside a figlet-style "tealite" wordmark,
with the tagline "TransformerEngine, by candlelight". Wrapped in a
fenced code block so monospace alignment survives all markdown
renderers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jayfurmanek jayfurmanek marked this pull request as ready for review May 8, 2026 12:58