Add Tealite: pure-Python TransformerEngine for ROCm/AMD GPUs#581
Open
jayfurmanek wants to merge 103 commits into dev from
Conversation
Introduces transformer_engine/pytorch/_lite/ package that provides a drop-in replacement for the compiled transformer_engine_torch C++ extension module. When NVTE_LITE=1 is set, the lite module is registered via sys.modules, transparently replacing all tex.* calls with Triton/AITER/PyTorch-native implementations. This eliminates the need for C++ compilation and reduces ROCm/HIP dependencies while retaining functional correctness. Phase 0 scaffold: 18 files covering enums, activations, norms, GEMM, softmax, attention (stubbed), RoPE, dropout, transpose, quantization, permutation, multi-tensor ops, MOE router, comm stubs, and padding. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
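The `sys.modules` substitution described above can be sketched as follows. This is an illustrative stand-in only: the module contents here are placeholders, not tealite's real exports, and `install_lite_module` is a hypothetical name.

```python
import os
import sys
import types

# Name of the compiled C++ extension the rest of the codebase imports.
EXT_NAME = "transformer_engine_torch"

def install_lite_module():
    """Build a pure-Python stand-in module and register it under the
    extension's import name, so `import transformer_engine_torch as tex`
    resolves to the lite implementation instead of the compiled .so."""
    lite = types.ModuleType(EXT_NAME)
    # In the real package these would be Triton/AITER/PyTorch implementations.
    lite.gelu = lambda x: x          # placeholder stand-ins
    lite.rmsnorm_fwd = lambda x: x
    sys.modules[EXT_NAME] = lite
    return lite

if os.environ.get("NVTE_LITE") == "1":
    install_lite_module()
```

Once the entry is in `sys.modules`, every later `import transformer_engine_torch` (including `tex.*` call sites) transparently picks up the lite module, since Python's import machinery consults `sys.modules` before searching for a compiled extension.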
…ication

- Fix GEMM transpose logic for cuBLAS column-major convention (B @ A order)
- Fix GEMM return signature to 4-tuple (output, bias_grad, gelu_input, extra_output)
- Fix GEMM bias handling: forward adds bias, backward computes bias_grad from grad_output
- Fix rmsnorm_fwd to return 3 values matching C++ signature
- Fix layernorm_bwd/rmsnorm_bwd signatures to include sm_margin parameter
- Route get_fused_attn_backend to No_Backend (unfused SDPA) until Phase 3

Verified on MI300X: forward+backward pass for Linear, LayerNormLinear, LayerNormMLP, LayerNorm, RMSNorm. TransformerLayer forward works; backward needs Phase 1 fix for autograd Variable issue.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
20 tests covering import verification, forward pass, forward+backward, and numerical correctness for Linear, LayerNormLinear, LayerNormMLP, LayerNorm, RMSNorm, and TransformerLayer under NVTE_LITE=1. TransformerLayer backward is marked xfail pending Phase 1 fix.

Run with: NVTE_LITE=1 pytest tests/pytorch/test_lite.py -v

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace PyTorch-native norm placeholders with calls to existing Triton kernels (triton_kernels/norms_common.py) for LayerNorm and RMSNorm fwd/bwd. Lazy-imports Triton at first call with automatic fallback to PyTorch if Triton is unavailable. Handles >2D input via reshape. Also stubs out AITER dispatch paths in _lite/gemm.py for generic_gemm and te_general_grouped_gemm (wiring deferred to AITER integration phase). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Implement Triton dispatch for quantize/dequantize in lite mode by importing low-level Triton cast kernels directly (te_cast_transpose_noop, mxfp8, mxfp4) instead of going through cast.py, which has a tex.quantize fallback that would recurse infinitely in lite mode.

Key fixes:
- Break quantize recursion: old code did quantizer.quantize() -> Float8Quantizer.quantize_impl() -> tex.quantize() -> infinite loop
- FP8 dequantize: properly reinterpret uint8 data as FP8 bits via .view(fp8_dtype) before casting to target dtype
- Plain tensor dequantize: check isinstance(torch.Tensor) first to avoid the no-op .dequantize() trap that ignores otype

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace nested Python loops with vectorized PyTorch ops for fp8_block_scaling_compute_partial_amax and fp8_block_scaling_partial_cast. Uses pad-reshape-reduce pattern: pad to block-aligned shape, reshape into (num_blocks_h, block_len, num_blocks_w, block_len), then amax or broadcast-multiply over block dims. Eliminates O(blocks) kernel launches. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
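The pad-reshape-reduce pattern above can be illustrated with a minimal NumPy sketch of the amax half (the real code operates on PyTorch tensors; `block_amax` is an illustrative name, not the tealite function):

```python
import numpy as np

def block_amax(x: np.ndarray, block_len: int = 128) -> np.ndarray:
    """Per-(block_len x block_len)-block absolute max via pad-reshape-reduce:
    zero-pad to a block-aligned shape, reshape into
    (num_blocks_h, block_len, num_blocks_w, block_len), then reduce over
    the two intra-block axes in a single vectorized op."""
    h, w = x.shape
    ph = (-h) % block_len  # rows of zero padding to reach block alignment
    pw = (-w) % block_len  # cols of zero padding
    xp = np.pad(np.abs(x), ((0, ph), (0, pw)))  # pads with zeros
    nh, nw = xp.shape[0] // block_len, xp.shape[1] // block_len
    return xp.reshape(nh, block_len, nw, block_len).max(axis=(1, 3))
```

Zero padding is safe for amax because |0| never exceeds any real element's magnitude, and the single reshape+reduce replaces the O(blocks) per-block kernel launches of the loop version.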
Write 2D-tiled Triton kernels for fp8_block_scaling_compute_partial_amax and fp8_block_scaling_partial_cast. Each program processes one (BLOCK_LEN x BLOCK_LEN) block, loading TILE_ROWS x BLOCK_LEN elements per iteration for full intra-block parallelism. Autotuned over TILE_ROWS (4/8/16/32) and num_warps (4/8) to match the fused C++/CUDA kernel's single-launch, fully-parallel design. Falls back to vectorized PyTorch when Triton is unavailable. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
True single-pass fusion would require merging bgrad accumulation into the cast kernel. The separate sum + quantize path is already efficient since both dispatch to optimized CUDA/Triton kernels individually. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Wire AITER as an optional pip dependency (amd-aiter) for CK/Triton
kernel dispatch in lite mode:
- gemm.py: Multi-backend GEMM dispatch controlled by NVTE_LITE_GEMM_BACKEND
env var ("ck", "triton", "pytorch"). Supports all precisions:
- FP8 per-tensor: CK gemm_a8w8_CK / Triton gemm_a8w8
- FP8 block-scale: CK gemm_a8w8_blockscale / Triton gemm_a8w8_blockscale
- FP8 2D-blockwise (Float8BlockwiseQTensorStorage): routes to block-scale
kernels with proper rowwise/columnwise data extraction
- FP4 (MXFP4): CK gemm_a4w4 / Triton gemm_afp4wfp4
- BF16/FP16: Triton gemm_a16w16 (no CK individual GEMM)
- FP32: torch.matmul (preserves exact precision)
Auto-detects per-tensor vs block-scale from scale tensor shape.
Each backend falls through to the next on failure.
- activations.py: AITER fused gated activations (swiglu -> silu_and_mul,
geglu -> gelu_tanh_and_mul). Non-gated activations stay PyTorch.
- rope.py: Refactored to use shared aiter_utils instead of per-file flags.
- aiter_utils.py: New shared AITER availability detection with lru_cache.
Tested with amd-aiter 0.1.7 on all backends (ck, triton, pytorch):
all 37 tests pass + 1 xfail. Also passes without AITER installed.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
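The NVTE_LITE_GEMM_BACKEND fall-through dispatch above can be sketched as a small chain-of-responsibility. The backend names match the commit ("ck", "triton", "pytorch"); the kernel callables are placeholders, not the real AITER/Triton entry points:

```python
import os

# Default order mirrors the commit's backend list; each backend falls
# through to the next on failure.
_ORDER = ["ck", "triton", "pytorch"]

def dispatch_gemm(a, b, kernels):
    """Try the env-selected backend first, then fall through in order.
    `kernels` maps backend name -> callable; a backend may be missing
    or may raise for shapes/dtypes it does not support."""
    first = os.environ.get("NVTE_LITE_GEMM_BACKEND", "ck")
    order = [first] + [k for k in _ORDER if k != first]
    last_err = None
    for name in order:
        fn = kernels.get(name)
        if fn is None:
            continue
        try:
            return name, fn(a, b)
        except Exception as err:  # unsupported shape/dtype on this backend
            last_err = err
    raise RuntimeError("no GEMM backend succeeded") from last_err
```

The same structure generalizes to the per-precision routing in the commit: each precision tier just supplies a different `kernels` mapping.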
Build a pure-Python wheel with no C++ compilation:

NVTE_LITE_ONLY=1 python setup.py bdist_wheel

Builds in <1 second (vs 5-10 minutes for full build). The wheel:
- Contains only Python files + Triton kernels (no .so files)
- Platform tag: py3-none-any (architecture-independent)
- Package name: tealite
- Writes LITE_BUILD marker file that forces NVTE_LITE=1 at import

Build changes:
- setup.py: Skip C++ extensions, CMake, submodule checks, hipify when NVTE_LITE_ONLY=1. Write LITE_BUILD marker into package.
- common/__init__.py: Detect LITE_BUILD marker at module level to auto-activate lite mode (skip core library loading).
- .gitignore: Exclude LITE_BUILD marker from version control.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
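The marker-file detection described above can be sketched like this. The marker file name follows the commit message; the function name and exact check are illustrative, not the real common/__init__.py logic:

```python
import os

def lite_build_active(package_dir: str) -> bool:
    """If the installed package ships a LITE_BUILD marker file (written by
    the NVTE_LITE_ONLY=1 wheel build), force lite mode at import time by
    setting NVTE_LITE=1; otherwise honor the env var as usual."""
    if os.path.exists(os.path.join(package_dir, "LITE_BUILD")):
        os.environ["NVTE_LITE"] = "1"  # auto-activate lite mode
        return True
    return os.environ.get("NVTE_LITE") == "1"
```

Doing the check at module import level means a tealite wheel never attempts to load the core C++ library, even if the user forgets to set NVTE_LITE.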
… stub, and SDPA fallback

Replace Phase 3 TODO stubs with full multi-backend attention dispatch:
- AITER CK/ASM kernels via raw _flash_attn_forward/_backward (priority backend)
- Flash-attention stubbed for future integration
- PyTorch SDPA fallback with autograd-based backward
- Pure PyTorch helpers: fa_prepare_fwd/bwd, copy_to_kv_cache, THD<->BSHD converters
- C++ binding-compatible signatures so cpp_extensions/fused_attn.py works unmodified

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
19 attention tests: backend selection, fwd/bwd shapes for bshd/sbhd/thd, aux_ctx_tensors format, mask types, GQA, variable-length, AITER-vs-SDPA numerical comparison, helper functions, DotProductAttention and MultiheadAttention end-to-end.

14 GEMM tests: TN layout, all transpose combos, bias addition, bias grad, GELU epilogue, accumulate, alpha scaling, output-into-D, return format, FP32.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…in _lite

Fix _lite/permutation.py signatures to match the tex.* C++ interface so the index-map path in permutation.py works correctly in lite mode. Wire up the existing Triton sort_chunks_by_map kernel for the forward gather operation with a PyTorch fallback. Fix _lite/padding.py to match the 4-argument tex.fused_multi_row_padding/unpadding interface with proper zero-padding. Add 22 tests covering both MoE permutation and padding operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…lysis

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Integrate AMD's MORI (Modular RDMA Interface) library to provide high-performance distributed expert parallelism for the _lite module. This bridges the most significant distributed gap in tealite by enabling MoE token dispatch/combine across GPUs via XGMI (intra-node) and RDMA (inter-node) without requiring C++ extensions.

Key components:
- MoriExpertParallel: high-level wrapper with dispatch/combine for both flat and standard MoE (per-expert grouped) layouts
- MoriEPDispatch/MoriEPCombine: autograd functions enabling gradient flow through distributed dispatch/combine for training
- MoriEPDispatchStdMoE/MoriEPCombineStdMoE: autograd functions for the per-expert layout path used with grouped GEMM
- mask_to_index/index_to_mask: routing map format converters between TE's mask-map and MORI's index-map formats
- Layout converters between flat and per-expert token arrangements

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
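The routing-map converters named above can be sketched with NumPy. Here a mask map is a (tokens, experts) boolean matrix with top_k ones per row and an index map is a (tokens, top_k) matrix of expert ids; the function names follow the commit, but the exact tealite tensor layouts may differ:

```python
import numpy as np

def mask_to_index(mask: np.ndarray, top_k: int) -> np.ndarray:
    """TE mask-map (tokens, experts) bool -> MORI-style index map
    (tokens, top_k) of expert ids, experts sorted per token."""
    tokens, _ = mask.shape
    rows, cols = np.nonzero(mask)  # row-major order keeps per-token ids sorted
    return cols.reshape(tokens, top_k)

def index_to_mask(index: np.ndarray, num_experts: int) -> np.ndarray:
    """Inverse conversion: scatter each token's expert ids back into a
    boolean (tokens, num_experts) mask map."""
    tokens = index.shape[0]
    mask = np.zeros((tokens, num_experts), dtype=bool)
    mask[np.arange(tokens)[:, None], index] = True
    return mask
```

The two formats carry the same routing information; the conversion exists only because TE's permutation ops consume mask maps while MORI's dispatch/combine kernels consume index maps.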
Document the new MORI EP integration including feature table, supported kernel types, and specific gaps vs the full build (no MoE module integration, no comm-overlap with expert GEMM, no pipeline-parallel EP, no heterogeneous expert placement, limited standard MoE kernel types). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ace gaps

Implements a fused Triton kernel for the MoE router that combines score function (sigmoid/softmax) + top-k + group-topk + normalization + scaling into a single kernel launch, replacing the previous 3-5 unfused PyTorch ops. Adds sigmoid score-function support that was previously ignored in _lite, and fixes return-signature mismatches between _lite and the C++ extension interface.

Key changes:
- New Triton JIT kernels (fwd/bwd) for the fused router in common/triton/fused_router.py with PyTorch wrappers in pytorch/triton/fused_router.py
- _lite/router.py: sigmoid scoring, group top-k, correct 3-tuple returns for fused_topk_with_score_function_fwd and fused_score_for_moe_aux_loss_fwd, and (loss, Const_buf) return for fused_moe_aux_loss_fwd
- Comprehensive test coverage for all MoE permutation and router paths including mask-map, chunk-sort, numerical gradient verification, and Triton-vs-PyTorch cross-checks

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ests

- Rewrite _lite/rope.py to match the full C++ tex.fused_rope_forward/backward signatures with cp_size, cp_rank, start_positions, qkv_format, interleaved, and cu_seqlens parameters
- Implement DualChunkSwap frequency slicing (_get_freqs_on_this_cp_rank) for CP-aware position embedding in lite mode
- Fix AITER wiring: import from aiter.ops.rope instead of the non-existent aiter.rope, with adapter functions translating TE conventions (interleaved, qkv_format) to AITER conventions (rotate_style, nope_first)
- Fix fused QKV RoPE head-dimension reshape to match C++ kernel behavior: Q is reshaped from [s,b,h,q_split] to [s,b,h*q_split/head_dim,head_dim] using K's dimension as the reference head size
- Add start_positions support via per-batch freq stacking
- Compute RoPE rotation in float32 for precision parity with the C++ fused kernel
- Add interleaved rotation support (_rotate_half_interleaved)
- Add multi-GPU CP attention tests (test_lite_cp.py) covering P2P, AllGather, and A2A comm types with BSHD/SBHD formats
- Add xfail skips in test_fused_rope.py for known lite gaps (non-contiguous tensors, THD+CP)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add AITER Triton as the highest-priority backend for LayerNorm/RMSNorm (AITER Triton > TE Triton > PyTorch fallback)
- Import aiter.ops.triton.rmsnorm._rmsnorm_forward/backward and aiter.ops.triton.norm._layernorm_forward/backward with lazy loading
- Add adapter functions that translate between TE's norm API (N-D input, quantizer interface) and AITER's raw 2D Triton kernel interface
- Refactor internal helpers: extract _ensure_2d/_restore_nd for clean N-D reshape handling, separate quantizer application from norm compute
- Note: AITER's fused norm+quantize kernels (dynamicquant/smoothquant) use per-row scaling, which is incompatible with TE's per-tensor FP8 scaling, so quantization remains a separate step via the quantizer interface
- Update test_lite.py to match the refactored _layernorm_fwd_pytorch and _rmsnorm_fwd_pytorch signatures (removed ln_out/quantizer/otype/sm_margin params from the internal PyTorch fallback functions)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- test_aiter_norms_active: confirms all 4 AITER norm functions (LayerNorm and RMSNorm, forward and backward) are loaded as the active backend
- test_aiter_rmsnorm_fwd_bwd: validates RMSNorm forward + backward output from the AITER path against a PyTorch reference
- test_aiter_layernorm_fwd_bwd: same for LayerNorm

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Integrate AITER's fused_rms_fp8_per_tensor_static_quant Triton kernel into the lite RMSNorm forward path. When a Float8Quantizer (delayed scaling) is provided, the norm and FP8 quantization now execute in a single kernel launch instead of two separate passes over memory.

The fused kernel:
- Takes a pre-computed per-tensor scale from Float8Quantizer.scale
- Computes RMSNorm and FP8 cast in one pass
- Writes FP8 output directly (no intermediate BF16 materialization)

Also adds detection scaffolding for MXFP8Quantizer with AITER's fused_rms_fp8_group_quant (block scaling), currently falling back to separate quantize until Float8Tensor wrapping is validated. Float8CurrentScalingQuantizer (JIT scaling) cannot use the static fused kernel since the scale is unknown before the forward pass; it continues to use the separate norm -> quantizer.quantize() path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- test_fused_rmsnorm_fp8_quant_active: verifies the AITER fused kernel is loaded and produces a Float8Tensor with correct scale_inv and amax
- test_fused_rmsnorm_fp8_quant_vs_separate: validates fused output against the separate norm->quantize path (dequantized comparison)
- test_fused_rmsnorm_fp8_quant_3d_input: verifies N-D shape preservation through the fused path

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Per-tensor CurrentScaling requires 3 kernel launches (norm → HBM → amax scan → HBM → quantize). Per-row fuses norm+quantize into a single kernel where each row computes its own scale in registers, eliminating the global amax data dependency and the BF16 intermediate write to HBM.

Forward: rmsnorm2d_fwd_with_dynamicquant produces FP8 + yscale(M,).
GEMM: gemm_a8w8_per_token_scale consumes per-row-scaled FP8 natively.
Backward: dynamic_per_token_quant_fp8_i8 quantizes dY per-row for dgrad.

Changes:
- _lite/norms.py: Fused RMSNorm+FP8 per-row quant for CurrentScaling
- _lite/gemm.py: Per-row scale detection and per-token GEMM dispatch
- _lite/quantize.py: Per-row dynamic quantize path for CurrentScaling
- tests/test_lite.py: 9 new tests covering all three paths

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add new FP8 Training section documenting per-row dynamic scaling for CurrentScaling recipe -- a lite-only optimization that fuses norm+quantize into a single kernel, eliminating 2 HBM round-trips vs per-tensor scaling. Also updates feature tables for: AITER as primary norm backend, fused norm+quantize variants, per-row GEMM dispatch, per-row dynamic quantize, context parallelism for RoPE, and fused Triton MoE router. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The recipe infrastructure (fp8_autocast, DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, RecipeState, make_quantizers) is pure Python that lives above _lite and works identically in both modes. No actual gap exists. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ntize fallback

MXFP8 tensors were misdetected as FP4 (shared _rowwise_data attribute), silently dequantized in GEMM, and the fused norm+quant output was discarded.

GEMM detection (gemm.py):
- Add _is_mxfp8() using the _fp8_dtype discriminator vs _fp4_dtype for FP4
- Fix _is_fp4() to require the _fp4_dtype attribute
- Fix _is_quantized() and _get_raw_data() to handle MXFP8
- Add explicit MXFP8 early-return in CK/Triton dispatch with TODO(MI350) hooks for future native MXFP8 GEMM kernels

Norms fusion (norms.py):
- Complete fused_rms_fp8_group_quant wrapping (was returning None)
- Add E8M0 scale conversion: AITER linear float32 → uint8 biased exponent
- Produce a proper MXFP8Tensor from a single fused kernel

Quantize fallback (quantize.py):
- Add _linear_scale_to_e8m0() shared helper
- Add _quantize_mxfp8_pytorch() pure PyTorch fallback for MXFP8
- Wire the fallback into quantize() and _quantize_pytorch_fallback()

Tests: 8 new tests covering detection, E8M0 conversion, quantize roundtrip, PyTorch fallback, GEMM dequant path, and fused RMSNorm+MXFP8.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
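The linear-float to E8M0 conversion mentioned above can be sketched with NumPy. E8M0 stores only a biased power-of-two exponent in a uint8 (bias 127, matching the float32 exponent bias); the rounding direction and clamping here are illustrative assumptions, and the tealite helper may differ at the edges:

```python
import numpy as np

def linear_scale_to_e8m0(scale: np.ndarray) -> np.ndarray:
    """float32 linear scales -> E8M0 uint8 biased exponents.
    Rounds up to the nearest power of two (assumption), then biases by 127
    and clamps into the uint8 range."""
    exp = np.ceil(np.log2(scale))
    return np.clip(exp + 127, 0, 255).astype(np.uint8)

def e8m0_to_linear(e8m0: np.ndarray) -> np.ndarray:
    """Inverse: recover the power-of-two linear scale from E8M0."""
    return np.exp2(e8m0.astype(np.float32) - 127.0)
```

For scales that are already exact powers of two the round trip is lossless, which is why the conversion is safe to apply to AITER's float32 scale output before wrapping it into an MXFP8 tensor.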
…amax history

Implements approach 2 for compound fused modules: clean autograd Functions composing existing _lite ops (norm, GEMM, activations) instead of routing through the full build's 1500+ line distributed-heavy modules.

LayerNormLinear: norm → quantize → GEMM, with full backward. LayerNormMLP: norm → FC1+bias → activation → FC2+bias, supporting all 11 activation variants (gelu, swiglu, geglu, silu, relu, etc.) with fused dbias_dact backward when available. Both inherit TransformerEngineBaseModule for fp8_autocast integration and accept the full-build constructor/forward kwargs for API compatibility (TP/SP/FSDP params are accepted but ignored).

Also fixes fused_amax_and_scale_update_after_reduction, which was ignoring the amax history window; it now rolls history and supports "max", "most_recent", and custom callable algorithms matching the C++ kernel.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The torch.matmul fallback in generic_gemm failed when operands had more than 2 dimensions (e.g. [batch, seq, hidden]) because it passed them directly to matmul without flattening. The C++ GEMM and AITER kernels handle this implicitly but torch.matmul does not. Flatten operands to 2D before matmul and restore B's leading dimensions in the output, matching cuBLAS convention. This fixes TransformerLayer backward in lite mode — removes the xfail marker from that test. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
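The flatten/restore step described above can be shown with a NumPy sketch (the real path uses torch.matmul and the cuBLAS B @ A ordering, which is elided here; `matmul_flat` is an illustrative name):

```python
import numpy as np

def matmul_flat(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Flatten x's leading dims (e.g. [batch, seq, hidden]) to 2-D,
    matmul against a 2-D weight, then restore the leading dims on the
    output. Plain matmul cannot consume the >2-D operand directly the
    way the C++ GEMM and AITER kernels implicitly do."""
    lead = x.shape[:-1]                          # e.g. (batch, seq)
    out = x.reshape(-1, x.shape[-1]) @ w         # (batch*seq, out_features)
    return out.reshape(*lead, w.shape[-1])       # back to (batch, seq, out)
```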
Previously the backward pass in LayerNormLinear and LayerNormMLP ran all GEMMs in bf16 even during FP8 training. Now:
- dgrad GEMMs pass grad_input_quantizer to quantize output on the fly
- wgrad GEMMs pass grad_weight_quantizer for FP8 weight gradients
- Saved inputs (ln_out, act_out) are re-quantized with columnwise usage before NT wgrad GEMMs (AITER CK needs column-wise scaling)
- grad_output gets columnwise usage enabled for wgrad GEMMs
- MLP backward quantizes dact (activation backward output) via fc1_grad_output_quantizer before the FC1 GEMMs

This matches the full-build backward quantization strategy, enabling FP8 throughput benefits in the backward pass when AITER kernels are available.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
For gated activations (swiglu, geglu, reglu) with Float8BlockQuantizer, dispatch to AITER's act_mul_and_fp8_group_quant Triton kernel, which fuses activation + gate multiply + FP8 cast in a single kernel pass. This eliminates the intermediate bf16 round-trip between activation and quantization.

The kernel accepts group_size as a parameter, so we pass the quantizer's block_len (128 for Float8Block). AITER returns fp8 data + float32 dequant scales (scale_inv), which we wrap directly into a Float8BlockwiseQTensorStorage.

Dispatch priority for gated activations is now:
1. AITER fused act+quant (when the quantizer is Float8BlockQuantizer)
2. AITER fused gated act (silu_and_mul, gelu_tanh_and_mul) + separate quantize
3. PyTorch fallback + separate quantize

MXFP8 is not supported by this fused path because AITER produces float32 scales while MXFP8 requires E8M0-encoded uint8 scales.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
hipBLASLt FP8 kernels require every mat1 dim to be divisible by 16. The LLaMA-3 config hits M=8184 (2046×4 tokens) on every forward and dgrad call, which trips either the explicit "trailing dimension must be divisible by 16" check (wgrad with K=8184) or the implicit "could not find valid hipblaslt solution" (fwd/dgrad with M=8184).

Pad M with zero rows up to the next multiple of 16 (+8 rows for 8184, a 0.1% overhead), pad the per-row scale with 1.0 in lockstep (scale × 0 = 0 anyway), call _scaled_mm, and slice the result back. Padded rows contribute zero and don't affect the GEMM result. Unlike the reverted pow2 padding (which inflated 28672 → 32768 at 12.5% overhead), div-by-16 padding only pads the misaligned tokens dim; the weight dims (4096, 14336, 28672, 6144) are all already aligned.

K misalignment (mat1 trailing dim = tokens = 8184 after the wgrad transpose) is not padded here; those calls also hit per-row-on-reduction issues and belong on AITER. Logged as "k_not_div16" for traceability. Counter: k_not_div16 added to the [LITE-SCALED-MM-FAIL] diag reasons.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
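The pad-and-slice scheme above can be sketched with NumPy; a plain matmul stands in for torch._scaled_mm, and the function names are illustrative:

```python
import numpy as np

def pad_m_div16(mat1: np.ndarray, scale: np.ndarray):
    """Pad M up to the next multiple of 16 with zero rows, padding the
    per-row scale with 1.0 in lockstep (scale * 0 = 0, so padded rows
    stay zero). Returns the padded operands and the original M."""
    m = mat1.shape[0]
    pad = (-m) % 16                                       # e.g. 8184 -> +8 rows
    if pad:
        mat1 = np.pad(mat1, ((0, pad), (0, 0)))           # zero rows
        scale = np.pad(scale, (0, pad), constant_values=1.0)
    return mat1, scale, m

def gemm_padded(mat1, scale, mat2):
    """Run the aligned GEMM, then slice the result back to the caller's M.
    Padded rows contribute zero and never reach the caller."""
    p1, ps, m = pad_m_div16(mat1, scale)
    out = (p1 * ps[:, None]) @ mat2       # stand-in for torch._scaled_mm
    return out[:m]
```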
Megatron uses TEDelayedScaling (per-tensor) recipes, so most _scale_inv tensors arrive as numel=1 scalars. The old _reshape_scale_for_scaled_mm broadcast these to (M, 1) / (1, N), forcing hipBLASLt's rowwise kernel family. That family isn't tuned for mixed-dtype (E4M3 × E5M2) FP8 on ROCm, hence "could not find valid hipblaslt solution" on every dgrad.

Pass per-tensor scalars as 0-dim tensors so hipBLASLt selects the per-tensor kernel family: the same F8NBS/F8B8NBS Tensile kernels full TE uses for both same-dtype forward and mixed-dtype backward. Per-row (numel==dim) scales still reshape to (dim, 1) / (1, dim) for the rowwise kernel family; that path is for CurrentScaling recipes, where the full rowwise kernel set does cover the shapes we hit.

Padding logic updated: don't F.pad a 0-dim scalar scale (it applies uniformly to padded rows automatically).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The pytorch path (torch._scaled_mm, hipBLASLt-backed on ROCm) is now the fastest of the three backends at 1.79 s/iter (443 TFLOP/s) on LLaMA-3-8B, edging ahead of the full build. triton and ck land within 6% at ~2.01 s/iter and remain available via explicit override. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
TestGemmBackendMatrix covers gaps that opened once three backends (pytorch/triton/ck) became production paths:
- Parity across all three backends for BF16 and per-tensor FP8 (DelayedScaling layout, the Megatron default)
- Counter assertion that per-tensor FP8 under backend=pytorch lands on torch._scaled_mm and not dequant+matmul, catching silent "scalar scale accidentally broadcast to rowwise" regressions
- M=100 (not div-by-16) to exercise the pad-and-slice path for hipBLASLt FP8 alignment

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Newer aiter releases insert positional args (sink_size after window_size_right, q_descale/k_descale/v_descale after alibi_slopes) in _flash_attn_forward, which shifts the tail args and makes the existing positional call fail. Switching to keyword arguments makes the call resilient to future drift and unblocks TestRecipeIntegration's transformer_layer DelayedScaling/Float8CurrentScaling cases. The varlen path already uses keyword args so it was already drift-safe (and that's the path Megatron thd training exercises). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Optimization on the norm backward path: when the only downstream consumer of an FP8 tensor is going to dequantize it immediately, we can skip the cast entirely and still preserve DelayedScaling amax history by running a standalone reduction on the BF16 source. Gated with the NVTE_LITE_SKIP_FP8_DGRAD_FOR_NORM env var.
Pull te_general_grouped_gemm out of _lite/gemm.py into its own module with explicit FP8-operand detection. AITER's generic GMM kernels (gmm/ptgmm/nptgmm) are BF16/FP16 only (the p/np prefix means persistent vs non-persistent kernel, not per-tensor scaling), so FP8 operands now raise NotImplementedError pointing at the fused-MoE path (aiter.fused_moe / moe_op_gemm_a8w8_blockscale) that Phase 2 will wire.

Public API unchanged: te_general_grouped_gemm is still exported from transformer_engine.pytorch._lite. The BF16/FP16 path continues to delegate to general_grouped_gemm_triton, so existing TestGroupedLinear coverage is the regression check.

Tests:
- TestImport.test_key_symbols_exist: assert te_general_grouped_gemm is on the exported tex surface.
- TestGroupedGemmDispatch.test_fp8_operands_raise_not_implemented (new): FP8 operands must fail loudly so they don't silently misroute through the BF16 kernel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
MoE token routing can leave a rank with zero local tokens for its expert(s), which is common in early training before the auxiliary load-balancing loss equalizes routing. AITER's gmm asserts M > 0, but Megatron treats this as a legal MoE state, so the lite wrapper must handle it.

Hit on iter 3 of a single-node Mixtral 8x7B BF16 run (EP=8, MBS=1, seq=2048, mock data): one rank's TEColumnParallelGroupedLinear got m_splits summing to 0, the call traversed _lite/grouped_gemm.py -> general_grouped_gemm_triton -> gmm_common.get_gmm_shape, and tripped "AssertionError: M must be positive, it's 0."

Fix: when sum(m_splits) == 0, short-circuit before invoking AITER.
- Forward / dgrad: outputs are already shape (0, ...); nothing to do.
- Wgrad: output is (G, K, N); zero-fill iff accumulate=False, leave alone iff accumulate=True (zero contribution = no-op).
- Bias return matches the kernel-call return shape.

Tests:
- TestGroupedGemmDispatch.test_empty_tokens_short_circuit_forward: forward path returns cleanly with M=0 and out.shape=(0, N).
- TestGroupedGemmDispatch.test_empty_tokens_short_circuit_wgrad_zeros_out: wgrad with accumulate=False zeros the (G, K, N) output buffer.

Iters 1-2 of the same Mixtral run completed before the failure (loss 10.58 -> 10.53, grad norm 16 -> 53), so the BF16 MoE path through lite is otherwise sound at full Mixtral scale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
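The M=0 short-circuit can be sketched as below. The wrapper signature and `kernel_gmm` callable are illustrative placeholders, not the real _lite/grouped_gemm.py interface:

```python
import numpy as np

def grouped_gemm_wgrad(x, grads, m_splits, out, accumulate, kernel_gmm):
    """Short-circuit before invoking the AITER gmm kernel (which asserts
    M > 0) when this rank received zero tokens for all of its experts."""
    if sum(m_splits) == 0:
        if not accumulate:
            out[...] = 0.0  # zero contribution; accumulate=True is a no-op
        return out
    return kernel_gmm(x, grads, m_splits, out, accumulate)
```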
The bshd/sbhd branch in `_aiter_attn_fwd` was forwarding the TE-supplied `cu_seqlens_q/kv` to `aiter._flash_attn_forward`. Those values are just fixed-length batch boundaries here (no packed sequences), but their non-None presence trips aiter's `can_impl_fmha_v3_fwd` gate and routes the call to the slower JIT `mha_fwd` path (`ck_tile::FmhaFwdKernel`) instead of the AOT `aiter::fmha_fwd_hd128_bf16_causal_*` kernel that full TE uses. Pass None explicitly on this branch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Diagnostic only: under NVTE_LITE_DIAG=1, print q/k/v dtype + shapes, bias, dropout, causal/window, and seqlen on the first 1-2 calls into _aiter_attn_fwd's bshd/sbhd path. Gated behind _LITE_DIAG so default runs are unaffected. Will be reverted once we identify which can_impl_fmha_v3_fwd gate is still blocking the AOT v3 dispatch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the temporary [LITE-ATTN-FWD-PROBE] used to diagnose v3 dispatch with a cleaner [LITE-ATTN-FWD] one-shot helper, gated behind NVTE_LITE_DIAG=1 like the rest of the lite diagnostics. Prints once per process from whichever branch (thd/bshd/sbhd) of _aiter_attn_fwd runs first, with the fields that map directly to aiter's can_impl_fmha_v3_fwd gate. Useful for catching future regressions where attention silently routes to the slower JIT ck_tile path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
For Megatron's typical SBHD layout, _to_bshd was materializing a fresh BSHD-contiguous copy of q/k/v (and on bwd, o/dout) before each call to aiter._flash_attn_*. Aiter's maybe_contiguous only triggers a copy when stride(-1) != 1, and a plain transpose(0, 1) of an SBHD-contiguous tensor preserves head-dim stride 1, so the kernel reads the strided BSHD view directly with no internal materialization. Verified the fwd kernel produces bit-identical output (max abs diff 0). Cascade benefit on bwd: torch.empty_like with default preserve_format inherits the source view's strides, so dq/dk/dv allocations become strided BSHD views as well; the kernel writes into them through the strided indexing, and _from_bshd's later transpose+.contiguous() is a no-op (the transpose recovers SBHD-contiguous strides). The fwd output still copies because aiter allocates it BSHD-contiguous internally; fixing that one would need an aiter-side change to accept a preallocated out tensor. Trace had 384 launches of direct_copy_kernel_cuda variants (~63 ms) attributable to these copies; eliminated. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
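The stride argument above can be demonstrated with NumPy (PyTorch view semantics are analogous): swapping the first two axes of an SBHD-contiguous array yields a BSHD view whose last-axis stride is still one element, which is the only condition a stride(-1) == 1 contiguity check like aiter's maybe_contiguous looks at.

```python
import numpy as np

s, b, h, d = 4, 2, 8, 64
q_sbhd = np.zeros((s, b, h, d), dtype=np.float32)  # SBHD-contiguous buffer
q_bshd = q_sbhd.swapaxes(0, 1)                     # BSHD strided view, no copy

# Last-axis stride in elements: still 1, so the head dim is densely packed
# even though the view as a whole is not C-contiguous.
last_stride_elems = q_bshd.strides[-1] // q_bshd.itemsize
```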
This reverts commit e4a05c5.
prepare_forward's defensive .contiguous() copy only triggers when the incoming activation is non-contiguous. In full TE that branch is dead; in lite it fires hundreds of times per step, materializing copies that account for ~50 ms of the elementwise gap (3 of the top direct_copy launches in the trace). Need to know which lite producer (Triton fused linear, SwiGLU, RMSNorm, cast-transpose, etc.) is emitting the non-standard strides so we can fix it at the source. Probe is gated behind NVTE_LITE_DIAG=1 (existing lite-mode env var) and capped at 20 unique (module, shape, stride, caller) signatures so output stays bounded. Walks the stack past the base.py frame to the caller for actionable identification. Zero overhead when the env var is off. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
prepare_forward is a @contextmanager; first frame above base.py is contextlib.__enter__. Skip both base.py and contextlib.py, then capture three user-code frames so the producer of the non-contig input is visible (innermost first). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The TE Linear/LayerNormLinear forward methods, torch._dynamo's eval_frame, and torch.nn.Module._call_impl all sit between the actual producer and prepare_forward. Skip them and bump the captured-frame count to 8 so the layer that emitted the non-contig view (e.g., a transpose without .contiguous() in Megatron's attention forward) becomes visible in the [LITE-NONCONTIG] caller chain. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Experimental gate to A/B test whether the defensive .contiguous() materialize in prepare_forward is actually necessary, or whether the downstream GEMM can consume strided activations directly. Off by default; setting NVTE_LITE_SKIP_NONCONTIG=1 skips the materialize when an input is non-contiguous. Context: with --attention-backend fused + apply_rope_fusion=1, lite sees ~384 direct_copy launches per step from this materialize alone (~50 ms). The producer is Megatron's TEDotProductAttention output transpose at extensions/transformer_engine.py:811, which returns a non-contig BSH-shape view of SBH-contig memory. If hipBLASLt accepts that strided view (or only re-materializes once internally instead of the whole shape), we save the per-call copy. If the GEMM crashes or produces wrong output, revert by unsetting the env var. The diagnostic [LITE-NONCONTIG] log still fires regardless, so we keep visibility. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This reverts commit 993dcd3.
New _contig_diag module counts and times every prepare_forward .contiguous() materialize, keyed on (module, shape, stride, caller). Hooks tick_step() into FP8GlobalStateManager.autocast_exit so the step counter advances once per training step under DelayedScaling without patching Megatron.

Activation:
  NVTE_CONTIG_DIAG=1           enable instrumentation
  NVTE_CONTIG_DIAG_DUMP_STEP=N auto-dump after step N

Timing is perf_counter_ns around the .contiguous() call (CPU launch cost, no cuda.synchronize), to avoid distorting the very gap we are trying to measure. Use rocprof for device-side cost; the counter answers where and how often.

Intended for running full and lite side by side under apples-to-apples Megatron runs and diffing the [CONTIG-DIAG] blocks to identify lite-only or higher-count materialize sites.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Document NVTE_LITE_AMAX_FUSED, NVTE_LITE_SKIP_FP8_DGRAD_FOR_NORM,
NVTE_LITE_DIAG; correct NVTE_LITE_GEMM_BACKEND default (ck -> pytorch)
and rewrite its description to reflect the torch._scaled_mm tier.
- Add grouped_gemm.py / amax_utils.py / fused_layernorm_{linear,mlp}.py
to the module-structure listing.
- Add MoE-section rows for AITER Triton grouped GEMM (BF16/FP16) and
call out FP8 grouped GEMM as NYI; cross-link the existing
TestGroupedLinear::test_fp8_forward xfail.
- Refresh GEMM gaps + Summary so the default pytorch backend's
_scaled_mm-first dispatch is reflected and FP8 grouped GEMM is listed
as a primary gap.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Exercises Float8CurrentScaling through te.LayerNormLinear and asserts both that the AITER per-row kernels actually fire (catching silent fallback to per-tensor — invisible to cosine-only checks) and that fwd/bwd numerics stay within FP8-appropriate tiered tolerances vs a BF16 reference. The fixture monkeypatches the kernel module-attrs via sys.modules because _lite/__init__.py re-exports a `quantize` function that shadows the `quantize` submodule under attribute lookup. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
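The shadowing problem the fixture works around can be reproduced with toy modules. Everything below is a stand-in (package name `pkg`, attribute `kernel`), not the `_lite` layout; it only demonstrates why the monkeypatch must go through `sys.modules`.

```python
import sys
import types

# A package re-exports a function named like one of its submodules, so
# attribute lookup returns the function, while sys.modules still holds the
# real submodule. Patching must therefore target sys.modules, not pkg.quantize.
pkg = types.ModuleType("pkg")
sub = types.ModuleType("pkg.quantize")
sub.kernel = lambda x: x * 2
sys.modules["pkg"] = pkg
sys.modules["pkg.quantize"] = sub
pkg.quantize = lambda x: x  # re-exported function shadows the submodule

assert not isinstance(pkg.quantize, types.ModuleType)  # lookup: the function
real = sys.modules["pkg.quantize"]                     # sys.modules: the module
real.kernel = lambda x: x * 3  # monkeypatch lands on the actual submodule
```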
Captures invariants, perf baselines, dispatch hazards, and dead ends accumulated across the lite work. README documents what tealite supports; SKILLS documents what to know when working on it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…nication section The LayerNorm/RMSNorm feature table listed "Tensor / sequence parallelism" and "FSDP2 integration" as gaps, which read as if the norm op itself was limited. The constraints actually live in the fused compound modules and the comm layer, so move both rows into the Communication / Distributed section. Also correct the FSDP2 row: lite supports FSDP2 with a 1D mesh (weights wrap in FSDPAGTensor via the inherited base-class path); only HSDP / 2D-mesh plumbing is missing. Expand TP/SP rows with the actual reason (kwargs accepted for API compat but ignored; Megatron SP requires TP) and upgrade the CP row to reflect that RoPE + attention CP is wired, not just THD/BSHD helpers. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Tealight-candle ASCII art alongside a figlet-style "tealite" wordmark, with the tagline "TransformerEngine, by candlelight". Wrapped in a fenced code block so monospace alignment survives all markdown renderers. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Description
Introduces `tealite`, a pure-Python, drop-in replacement for the `transformer_engine_torch` C++ extension, targeting ROCm / AMD GPUs and PyTorch only. Activated by setting `NVTE_LITE=1`; when unset, none of the lite code is imported and full TE behavior is unchanged.
Why: the full build compiles hundreds of C++ / CUDA / HIP sources via CMake, couples tightly to toolchain versions, and slows down rapid kernel iteration on ROCm. Tealite sidesteps all of that by dispatching to AITER kernels (CK + Triton, installed via `pip install amd-aiter`; no submodule), standalone Triton kernels, `torch._scaled_mm` (hipBLASLt-backed) for FP8 GEMMs, and PyTorch-native fallbacks. Build a wheel in seconds, not minutes.
Note: I will squash commits upon review as needed.
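The commit history describes the activation mechanism as registering the lite module via `sys.modules` so that `tex.*` calls resolve to the pure-Python implementations. A toy sketch of that env-gated swap, with placeholder shim contents (the real lite API surface is much larger):

```python
import os
import sys
import types

# Toy sketch of the env-gated sys.modules swap: when NVTE_LITE=1, a
# pure-Python shim is registered under the compiled extension's name, so
# later imports of `transformer_engine_torch` resolve to the shim instead
# of requiring the C++ build. Shim contents are placeholders.
def register_lite_shim():
    if os.getenv("NVTE_LITE", "0") != "1":
        return False  # lite disabled: full TE behavior unchanged
    shim = types.ModuleType("transformer_engine_torch")
    shim.gelu = lambda x: x  # placeholder for a tex.* entry point
    sys.modules["transformer_engine_torch"] = shim
    return True
```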
Type of change
Changes
New `_lite/` package (27 files, ~14k LOC) at `transformer_engine/pytorch/_lite/`:
- GEMM (`gemm.py`, `grouped_gemm.py`): pluggable backend dispatcher (`NVTE_LITE_GEMM_BACKEND=pytorch/triton/ck`); FP8 routes through `torch._scaled_mm` with AITER fallback for shapes hipBLASLt can't serve (per-row reduction-axis, K not divisible by 16, unsupported dtype combos).
- Norms (`norms.py`): RMSNorm / LayerNorm via AITER Triton with TE-Triton and PyTorch fallbacks. Per-row dynamic-quant fusion (`rmsnorm2d_fwd_with_dynamicquant`) is a lite-only CurrentScaling optimization not available in the full build.
- Quantization (`quantize.py`): FP8 / MXFP8 / MXFP4 cast, transpose, block-scaling. Fused amax/scale update via a single Triton multi-tensor-apply kernel mirroring `delayed_scaling.cu::kernel_bulk` (replaces ~530k tiny elementwise launches per iteration).
- Attention (`attention.py`): SDPA, AITER CK (`fmha_v3_fwd`/`fmha_v3_bwd` AOT), and flash-attn package backends.
- Activations (`activations.py`): AITER fused gated activations (swiglu, geglu, reglu, qgeglu) with PyTorch fallbacks.
- RoPE (`rope.py`): AITER Triton kernels with CP support.
- Fused modules (`fused_layernorm_linear.py`, `fused_layernorm_mlp.py`): pure-Python `torch.autograd.Function` subclasses with full DelayedScaling / CurrentScaling / MXFP8 support and FSDP2 integration via the inherited `FSDPAGTensor` wrap.
- Communication (`comm.py`, `mori_ep.py` for MORI-based expert parallelism in MoE, `context_parallel.py` for THD/BSHD helpers; CP is wired in attention + RoPE).
- MoE: multi-tensor ops, router (fused TopK + softmax), permutation, grouped GEMM.
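The FP8 GEMM dispatch described for `gemm.py` can be sketched roughly as follows. The shape checks and the fallback are illustrative stand-ins: the real fallback is an AITER Triton kernel, not the dequantized matmul used here, and `torch._scaled_mm` is a private PyTorch API whose exact acceptance rules vary by version.

```python
import torch

def fp8_gemm(a, a_scale, b, b_scale, out_dtype=torch.bfloat16):
    # Tier 1: hipBLASLt via torch._scaled_mm, only for shapes it accepts
    # (illustrative check: K divisible by 16, tensors on the GPU).
    if a.is_cuda and a.shape[1] % 16 == 0:
        try:
            return torch._scaled_mm(a, b, scale_a=a_scale, scale_b=b_scale,
                                    out_dtype=out_dtype)
        except RuntimeError:
            pass  # unsupported dtype/shape combo: fall through
    # Tier 2 stand-in: dequantize and use a plain matmul (the real
    # fallback tier dispatches to an AITER kernel instead).
    return (a.to(out_dtype) * a_scale) @ (b.to(out_dtype) * b_scale)
```

The try/except shape of the dispatcher is the key design point: probe the fast hipBLASLt tier first and degrade per-call rather than maintaining a full compatibility table.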
Minimal full-build touchpoints (5 files, ~265 LOC), all env-gated or inert when `NVTE_LITE` is unset:
- `module/__init__.py`: swaps `LayerNormLinear` / `LayerNormMLP` for the lite versions when `NVTE_LITE=1`
- `module/base.py`: `FSDPAGTensor` weight-wrap plumbing (gated on `IS_HIP_EXTENSION`)
- `quantization.py`: `_contig_diag.tick_step()` hook for the optional materialize-attribution harness (no-op when the env var is unset)
- `triton/fused_router.py`
- `triton_kernels/grouped_gemm.py`
Documentation: in-tree `_lite/README.md` (full feature matrix, env vars, status notes per subsystem, known gaps) and `_lite/SKILLS.md` (operational tips).
Performance
End-to-end FP8 training on 8×MI300X / LLaMA-3-8B / seq=2048 at the same TE commit (2026-05-01): at parity. Lite uses the same hipBLASLt FP8 kernels via `torch._scaled_mm` for fwd + dgrad, AITER Triton for the wgrad shapes hipBLASLt can't serve, and the fused Triton amax/scale update.
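The per-tensor logic that the fused amax/scale update batches into one kernel can be sketched in plain PyTorch. The constant, history layout (slot 0 holds the current iteration's amax), and function name are assumptions for illustration, not the exact `kernel_bulk` contract.

```python
import torch

FP8_MAX = 448.0  # float8_e4m3 representable max (assumed format)

def bulk_amax_scale_update(amax_histories, scales, margin=0):
    # Loop stands in for the single fused multi-tensor kernel; the fused
    # version does all tensors in one launch instead of one per tensor.
    for hist, scale in zip(amax_histories, scales):
        amax = hist.max().item()       # reduce over the rolling history
        hist[1:] = hist[:-1].clone()   # shift the window by one step
        hist[0] = 0.0                  # fresh slot for the next iteration
        if amax > 0:
            scale.fill_(FP8_MAX / amax / (2.0 ** margin))
```

Run once per step, this touches every FP8 tensor's small buffers; done naively per tensor it is exactly the swarm of tiny elementwise launches the fused kernel eliminates.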
Known limitations (documented in `_lite/README.md`):
- TP / SP: compound modules accept `tp_size` / `tp_group` / `parallel_mode` / `sequence_parallel` for API compat but hardcode `tp_size=1`. The multi-node story for tealite is FSDP/HSDP-shaped, not TP-shaped.
- FSDP2: sharding with a 1D mesh works and is tested; HSDP / 2D-mesh plumbing is missing.
- Grouped GEMM: BF16/FP16 works via the `tex` hot-swap into `_lite/grouped_gemm.py`; FP8 grouped GEMM is blocked on a Triton GMM dtype mismatch and `aiter.fused_moe` wiring.
- LayerNorm CurrentScaling: the per-row dynamic-quant fusion exists only for RMSNorm (AITER does not ship a `layernorm2d_fwd_with_dynamicquant`).
Activation: `NVTE_LITE=1` (when unset, none of the lite code is imported).
Checklist:
- Documentation (`_lite/README.md`, `_lite/SKILLS.md`)
- Tests (`tests/pytorch/test_lite.py`, 4967 lines): GEMM backend matrix, FP8 dispatch, per-row CurrentScaling end-to-end, MoE smoke, FSDP2 wrap verification, attention backends, norms + quant, activations, multi-tensor ops
- Tests run with `NVTE_LITE=1`