103 commits
7dd732d
Add TE Lite: pure-Python replacement for C++ extensions (NVTE_LITE=1)
jayfurmanek Apr 6, 2026
6a7291a
Fix TE Lite Phase 0: GEMM, norms, and attention backend for GPU verif…
jayfurmanek Apr 6, 2026
86b27bc
Add Phase 0 test suite for TE Lite mode
jayfurmanek Apr 6, 2026
ba755a8
TE Lite Phase 2: Wire Triton norm kernels into _lite backend
jayfurmanek Apr 7, 2026
76a8db8
Wire Triton cast kernels into _lite/quantize.py, fix FP8 recursion bug
jayfurmanek Apr 7, 2026
523d299
Vectorize fp8_block_scaling functions in _lite/quantize.py
jayfurmanek Apr 7, 2026
a3eb181
Add Triton kernels for fp8_block_scaling in _lite/quantize.py
jayfurmanek Apr 7, 2026
83a4653
Document bgrad_quantize fusion trade-off in _lite/quantize.py
jayfurmanek Apr 7, 2026
07f2c13
Add AITER integration for GEMM, activations, and RoPE in _lite
jayfurmanek Apr 7, 2026
51c56bf
Add lite-only wheel build mode (NVTE_LITE_ONLY=1)
jayfurmanek Apr 7, 2026
95cb6ba
Implement optimized attention kernels in _lite with AITER, flash-attn…
jayfurmanek Apr 8, 2026
960c150
Add attention and GEMM tests for _lite module
jayfurmanek Apr 8, 2026
f65c1da
Wire up Triton kernels for MoE permutation and fix padding interface …
jayfurmanek Apr 8, 2026
f0166ba
Add README for tealite (_lite) module with feature status and gap ana…
jayfurmanek Apr 9, 2026
6405392
Add MORI expert parallelism integration for tealite distributed MoE
jayfurmanek Apr 9, 2026
83923a5
Update _lite README with MORI expert parallelism documentation
jayfurmanek Apr 10, 2026
88149d1
Add fused Triton MoE router with sigmoid support and fix _lite interf…
jayfurmanek Apr 10, 2026
88f3bbe
Add context parallelism support for lite mode RoPE and CP attention t…
jayfurmanek Apr 13, 2026
fe52154
Wire up AITER Triton norm kernels as primary backend for lite mode
jayfurmanek Apr 13, 2026
8193ba2
Add unit tests verifying AITER Triton backend is active for norm kernels
jayfurmanek Apr 13, 2026
0692a1c
Wire up AITER fused RMSNorm+FP8 quantize kernel for delayed scaling
jayfurmanek Apr 13, 2026
d6c56bf
Add tests for fused RMSNorm+FP8 quantize path
jayfurmanek Apr 13, 2026
817e49f
Wire up AITER per-row dynamic FP8 scaling for CurrentScaling recipe
jayfurmanek Apr 14, 2026
45c2d9e
Update _lite README with FP8 training section and feature status refresh
jayfurmanek Apr 14, 2026
9bf4e54
Fix misleading FP8 recipe management gap in _lite README
jayfurmanek Apr 14, 2026
05b1d9f
Fix MXFP8 BlockScaling support in _lite: detection, norms fusion, qua…
jayfurmanek Apr 14, 2026
57c57cc
Add lite-native LayerNormLinear, LayerNormMLP and fix DelayedScaling …
jayfurmanek Apr 15, 2026
10bfb93
Fix N-D tensor handling in _lite GEMM PyTorch fallback
jayfurmanek Apr 15, 2026
cd4cde4
Wire FP8 quantizers through backward GEMMs in _lite fused modules
jayfurmanek Apr 15, 2026
ea8e5aa
Wire AITER fused gated activation + block FP8 quantize into _lite
jayfurmanek Apr 15, 2026
a120df7
Add tests for AITER fused gated activation + block FP8 quantize
jayfurmanek Apr 15, 2026
6356c05
Wire AITER fused gated activation + per-row FP8 quantize for CurrentS…
jayfurmanek Apr 16, 2026
2c7125d
Fix CurrentScaling FP8 backward bugs and add recipe integration tests
jayfurmanek Apr 16, 2026
3ad2b70
Complete CurrentScaling FP8 backward path for LayerNormLinear and Lay…
jayfurmanek Apr 16, 2026
24e0020
Fix DelayedScaling end-to-end and wgrad transpose handling
jayfurmanek Apr 16, 2026
b604ebd
Use gemm_a8w8 correctly for per-tensor FP8 with expanded (M,)/(N,) sc…
jayfurmanek Apr 16, 2026
eeeb0a4
Wire TransformerLayer + FP8 end-to-end (N-D tensors, return_bias, N-D…
jayfurmanek Apr 16, 2026
cbc6200
Add API contract tests and FP8-vs-bf16 correlation tests for fused mo…
jayfurmanek Apr 16, 2026
f282c6a
Add FP8 training-loop tests (optimizer.step drives weight updates)
jayfurmanek Apr 16, 2026
a5058fa
Reject FP8 attention flags cleanly (fp8_dpa/fp8_mha) in lite mode
jayfurmanek Apr 16, 2026
8144ee5
Fix GroupedLinear bf16 in lite mode and add coverage
jayfurmanek Apr 16, 2026
928d29d
Update _lite README — correct LayerNormLinear/MLP, fused act+quant, F…
jayfurmanek Apr 16, 2026
0738081
Wire FSDPAGTensor emission into lite LayerNormLinear/LayerNormMLP
jayfurmanek Apr 16, 2026
9445691
Fix IS_HIP_EXTENSION detection in get_frameworks
jayfurmanek Apr 17, 2026
2c2637d
Skip ROCm framework validation in lite-only mode
jayfurmanek Apr 17, 2026
2895bfc
Report real version for lite-only installs
jayfurmanek Apr 17, 2026
78d15d5
Always return a 2-tuple from lite multi_tensor_l2norm
jayfurmanek Apr 17, 2026
b3c5f85
Fix lite multi_tensor_adam list order, master weights, L2 path
jayfurmanek Apr 17, 2026
3c878e7
Fix lite multi_tensor_scale and multi_tensor_sgd semantics
jayfurmanek Apr 17, 2026
c31b1bb
Add TestMultiTensor coverage in test_lite.py
jayfurmanek Apr 17, 2026
d68f81d
Honor C++ truncation semantics in lite multi_tensor_adam
jayfurmanek Apr 17, 2026
2f985d8
Document lite-specific env vars in _lite README
jayfurmanek Apr 17, 2026
3aee5b2
Honor output_dtype in lite generic_gemm PyTorch fallback
jayfurmanek Apr 17, 2026
1c5279b
Remove CPU-GPU syncs from lite FP8 amax/scale updates
jayfurmanek Apr 20, 2026
c102e13
Add lite dispatch probes and AITER fused-quant path fallback
jayfurmanek Apr 21, 2026
78ad2d7
Take Triton cast path for lite Float8 rowwise-only quantize
jayfurmanek Apr 21, 2026
8b31e82
Avoid float8_copy_kernel in lite GEMM operand transpose
jayfurmanek Apr 21, 2026
816472f
Mark transpose stale after lite fused RMSNorm+FP8 quant
jayfurmanek Apr 21, 2026
396a452
Revert "Mark transpose stale after lite fused RMSNorm+FP8 quant"
jayfurmanek Apr 21, 2026
53858c8
Revert "Avoid float8_copy_kernel in lite GEMM operand transpose"
jayfurmanek Apr 21, 2026
2b562c3
Revert "Revert "Avoid float8_copy_kernel in lite GEMM operand transpo…
jayfurmanek Apr 21, 2026
b5f90ce
Revert "Revert "Mark transpose stale after lite fused RMSNorm+FP8 qua…
jayfurmanek Apr 21, 2026
7284fe9
Track post-RMSNorm amax in lite fused FP8 delayed scaling
jayfurmanek Apr 21, 2026
bb0b152
Route per-row FP8 GEMMs to CK in lite dispatcher
jayfurmanek Apr 22, 2026
43c39d7
Add LITE-GEMM dispatch counter probe
jayfurmanek Apr 22, 2026
2c34e72
Instrument every CK dispatcher exit in LITE-GEMM probe
jayfurmanek Apr 22, 2026
8fbdc6e
Log shapes and message on first 5 CK GEMM RuntimeErrors
jayfurmanek Apr 22, 2026
73692b4
Route mixed-dtype FP8 GEMMs to torch._scaled_mm
jayfurmanek Apr 22, 2026
d2c785e
Revert "Route mixed-dtype FP8 GEMMs to torch._scaled_mm"
jayfurmanek Apr 22, 2026
eac04dd
Pad M to next power of 2 in lite AITER Triton FP8 GEMM
jayfurmanek Apr 22, 2026
f289c21
Fuse lite delayed-scaling amax/scale update into one Triton kernel
jayfurmanek Apr 23, 2026
c36b0ad
Assert K-innermost on FP8 operands before AITER Triton a8w8 dispatch
jayfurmanek Apr 23, 2026
ccb1f30
Revert "Pad M to next power of 2 in lite AITER Triton FP8 GEMM"
jayfurmanek Apr 23, 2026
3019e7c
Gate LITE dispatch counters and one-shot diags behind NVTE_LITE_DIAG
jayfurmanek Apr 23, 2026
6da4812
Route FP8 GEMMs through torch._scaled_mm in the PyTorch fallback path
jayfurmanek Apr 23, 2026
e827280
Fall back to AITER when _scaled_mm rejects under NVTE_LITE_GEMM_BACKE…
jayfurmanek Apr 23, 2026
4968168
Log first 5 torch._scaled_mm rejections with shape/dtype/scale context
jayfurmanek Apr 23, 2026
3ed9d8a
Pad mat1 M to div-by-16 for torch._scaled_mm hipBLASLt alignment
jayfurmanek Apr 23, 2026
5a660e9
Pass per-tensor FP8 scales as 0-dim scalars to torch._scaled_mm
jayfurmanek Apr 23, 2026
e8f0c5f
Switch default NVTE_LITE_GEMM_BACKEND from ck to pytorch
jayfurmanek Apr 24, 2026
fa14e3f
Add GEMM backend-matrix and dispatch-path tests
jayfurmanek Apr 24, 2026
03a08fd
Use keyword args for aiter _flash_attn_forward bshd call
jayfurmanek Apr 24, 2026
3f5d44b
skip FP8 dgrad round-trip
jayfurmanek Apr 28, 2026
cb34efb
Move grouped GEMM dispatcher into _lite/grouped_gemm.py
jayfurmanek Apr 28, 2026
eaf2ec0
Short-circuit empty-token grouped GEMM in lite dispatcher
jayfurmanek Apr 28, 2026
50bfb3f
Stop passing cu_seqlens for bshd/sbhd aiter fwd
jayfurmanek Apr 29, 2026
e80bdfd
Add one-shot fwd-args probe to lite aiter attention
jayfurmanek Apr 29, 2026
6dfbb17
Promote lite attn fwd probe to a permanent one-shot diag
jayfurmanek Apr 29, 2026
e4a05c5
Drop .contiguous() in lite _to_bshd sbhd->bshd path
jayfurmanek Apr 29, 2026
c62e977
Revert "Drop .contiguous() in lite _to_bshd sbhd->bshd path"
jayfurmanek Apr 29, 2026
055dada
Add NVTE_LITE_DIAG probe to identify non-contig input producers
jayfurmanek Apr 29, 2026
66dd440
LITE_DIAG noncontig probe: skip contextlib, capture 3 frames
jayfurmanek Apr 29, 2026
8f5e8c0
LITE_DIAG noncontig probe: skip wrapper frames, deepen stack
jayfurmanek Apr 29, 2026
993dcd3
Add NVTE_LITE_SKIP_NONCONTIG bypass in prepare_forward
jayfurmanek Apr 29, 2026
1bc68c3
Revert "Add NVTE_LITE_SKIP_NONCONTIG bypass in prepare_forward"
jayfurmanek Apr 29, 2026
dce41ed
Add NVTE_CONTIG_DIAG harness for full vs lite materialize attribution
jayfurmanek Apr 30, 2026
b74f420
Update tealite README: new env vars, grouped GEMM, scaled_mm default
jayfurmanek May 5, 2026
43d8efb
Add TestLitePerRowFP8: end-to-end coverage for per-row FP8 path
jayfurmanek May 5, 2026
28a4391
Add tealite SKILLS.md: operational notes complementing the README
jayfurmanek May 6, 2026
93a1b4a
Update tealite README: relocate distributed-parallelism rows to Commu…
jayfurmanek May 7, 2026
52f1b93
Add tealite ASCII logo and tagline to README header
jayfurmanek May 7, 2026
030b792
Merge remote-tracking branch 'origin/dev' into furmanek/dev-lite
jayfurmanek May 8, 2026
5f68f5c
Remove contiguous/non-continguous diag harness
jayfurmanek May 8, 2026
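Several of the GEMM-dispatch commits above revolve around torch._scaled_mm (routing FP8 GEMMs through it, passing per-tensor scales as 0-dim scalars, padding M to a multiple of 16 for hipBLASLt). As a rough sketch of the calling convention they work against -- this is not code from the PR, the exact signature and return type vary across PyTorch releases, and ROCm builds may need torch.float8_e4m3fnuz on MI300-class GPUs:

import torch

# Operand dims commonly need to be multiples of 16 for the cuBLASLt/hipBLASLt backends.
M, K, N = 32, 64, 48
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)      # row-major (M, K)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major (K, N)

# Per-tensor scales are passed as 0-dim tensors, not shape-(1,) tensors.
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
# Older PyTorch releases return an (out, amax) tuple instead of a single tensor.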
1 change: 1 addition & 0 deletions .gitignore
@@ -58,4 +58,5 @@ artifacts/
 **/times.csv
 transformer_engine/build_info.txt
 transformer_engine/common/util/hip_nvml.*
+transformer_engine/LITE_BUILD
 *.DS_Store
7 changes: 4 additions & 3 deletions build_tools/utils.py
@@ -415,13 +415,14 @@ def get_frameworks() -> List[str]:
         if framework not in supported_frameworks:
             raise ValueError(f"Transformer Engine does not support framework={framework}")
 
-    if rocm_build():
+    if rocm_build() and not bool(int(os.getenv("NVTE_LITE_ONLY", "0"))):
         _unsupported_frameworks = []
         if "pytorch" in _frameworks:
             try:
-                from torch.utils.cpp_extension import IS_HIP_EXTENSION
+                import torch.utils.cpp_extension
+                IS_HIP_EXTENSION = getattr(torch.utils.cpp_extension, "IS_HIP_EXTENSION", False)
             except ImportError:
-                IS_HIP_EXTENSION=False
+                IS_HIP_EXTENSION = False
             if not IS_HIP_EXTENSION:
                 if "pytorch" in _requested_frameworks:
                     _unsupported_frameworks.append("pytorch")
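The bool(int(os.getenv("NVTE_LITE_ONLY", "0"))) gate above recurs throughout setup.py below. As a reading aid only (the PR itself inlines the expression), it is equivalent to:

import os

def lite_only_build() -> bool:
    # True when NVTE_LITE_ONLY=1 requests the pure-Python (no C++ extension) build.
    return bool(int(os.getenv("NVTE_LITE_ONLY", "0")))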
25 changes: 22 additions & 3 deletions setup.py
@@ -48,7 +48,7 @@ class HipifyMeta(egg_info):
     """Custom egg_info command to hipify source files before packaging."""
 
     def run(self):
-        if rocm_build():
+        if rocm_build() and not bool(int(os.getenv("NVTE_LITE_ONLY", "0"))):
             from build_tools.hipify.hipify import do_hipify
             print("Running hipification of installable headers for ROCm build...")
             do_hipify(current_file_path, current_file_path / "transformer_engine/common/include")
@@ -229,7 +229,8 @@ def git_check_submodules() -> None:
 if __name__ == "__main__":
     __version__ = te_version()
 
-    git_check_submodules()
+    if not bool(int(os.getenv("NVTE_LITE_ONLY", "0"))):
+        git_check_submodules()
 
     with open("README.rst", encoding="utf-8") as f:
         long_description = f.read()
@@ -256,6 +257,23 @@ def git_check_submodules() -> None:
             "rocm_pytorch": [f"transformer_engine_rocm7[pytorch]=={__version__}"],
             "rocm_jax": [f"transformer_engine_rocm7[jax]=={__version__}"],
         }
+    elif bool(int(os.getenv("NVTE_LITE_ONLY", "0"))):
+        # Lite-only build: no C++ compilation, pure Python + Triton kernels.
+        # Builds in seconds. NVTE_LITE=1 is forced at import time via marker file.
+        install_requires, test_requires = setup_requirements()
+        ext_modules = []
+        cmdclass = {"bdist_wheel": TimedBdist}
+        package_data = {
+            "": ["VERSION.txt", "LITE_BUILD"],
+            "transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"],
+        }
+        include_package_data = True
+        extras_require = {"test": test_requires}
+
+        # Write marker file so import-time code knows this is a lite-only wheel
+        marker_path = current_file_path / "transformer_engine" / "LITE_BUILD"
+        marker_path.write_text("This is a lite-only build. NVTE_LITE=1 is forced.\n")
+        PACKAGE_NAME = "tealite"
     else:
         install_requires, test_requires = setup_requirements()
         ext_modules = [setup_common_extension()]
@@ -289,7 +307,8 @@ def git_check_submodules() -> None:
             )
         )
 
-    PACKAGE_NAME="transformer_engine"
+    if not bool(int(os.getenv("NVTE_LITE_ONLY", "0"))):
+        PACKAGE_NAME="transformer_engine"
     if (rocm_build() and bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))
         and not bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))) ):
         PACKAGE_NAME=f"transformer_engine_rocm{rocm_version()[0]}"
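The lite-only branch above writes a transformer_engine/LITE_BUILD marker into the wheel so that import-time code can recognize a lite-only install. A minimal sketch of such a check, for illustration only (the actual detection logic lives inside the package and may differ):

import os
import pathlib

_pkg_dir = pathlib.Path(__file__).resolve().parent   # hypothetical: the installed package directory
if (_pkg_dir / "LITE_BUILD").exists():
    os.environ["NVTE_LITE"] = "1"   # lite-only wheels always run the pure-Python backend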
290 changes: 290 additions & 0 deletions tests/pytorch/attention/run_lite_cp_test.py
@@ -0,0 +1,290 @@
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Multi-process worker for testing context parallelism in lite mode.

This script is launched via torch.distributed.launch with >= 2 GPUs.
It runs DotProductAttention with and without CP, then compares outputs
and gradients.

Only BSHD and SBHD formats are tested (THD requires C++ thd_* helpers
that are not yet implemented in lite mode).
"""

import logging
import os
import pathlib
import sys

os.environ["NVTE_LITE"] = "1"

# Ensure repo root is on sys.path for dev-tree runs (no pip install)
_repo_root = str(pathlib.Path(__file__).resolve().parent.parent.parent.parent)
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)

import torch
import torch.distributed as dist

from transformer_engine.pytorch import DotProductAttention


logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")


# ---------------------------------------------------------------------------
# Configs
# ---------------------------------------------------------------------------

class CPTestConfig:
"""Minimal model config for CP tests."""

def __init__(
self,
batch_size,
max_seqlen,
num_heads,
head_dim,
num_gqa_groups=None,
attn_mask_type="causal",
):
self.batch_size = batch_size
self.max_seqlen = max_seqlen
self.num_heads = num_heads
self.head_dim = head_dim
self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups


TEST_CONFIGS = {
"mha_causal": CPTestConfig(2, 1024, 8, 64, attn_mask_type="causal"),
"gqa_causal": CPTestConfig(2, 1024, 8, 64, num_gqa_groups=2, attn_mask_type="causal"),
"mha_no_mask": CPTestConfig(2, 1024, 8, 64, attn_mask_type="no_mask"),
"gqa_no_mask": CPTestConfig(2, 1024, 8, 64, num_gqa_groups=2, attn_mask_type="no_mask"),
}


# ---------------------------------------------------------------------------
# DualChunkSwap partitioning for BSHD / SBHD
# ---------------------------------------------------------------------------

def partition_for_cp(tensor, qkv_format, rank, world_size):
"""Partition a tensor along the sequence dimension using DualChunkSwap.

Each rank gets 2 chunks: [rank] and [2*world_size - rank - 1].
"""
seq_dim = qkv_format.index("s")
shape = list(tensor.shape)
chunk_size = shape[seq_dim] // (2 * world_size)
new_shape = shape[:seq_dim] + [2 * world_size, chunk_size] + shape[seq_dim + 1:]
tensor = tensor.view(*new_shape)
seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=tensor.device)
tensor = tensor.index_select(seq_dim, seq_idx)
final_shape = shape[:seq_dim] + [2 * chunk_size] + shape[seq_dim + 1:]
return tensor.reshape(*final_shape).contiguous()


def partition_dout(dout, qkv_format, rank, world_size):
"""Partition dout (output gradient) for CP comparison.

dout shape from DPA is (b, s, h*d) for bshd or (s, b, h*d) for sbhd.
"""
seq_dim = 0 if qkv_format == "sbhd" else 1
shape = list(dout.shape)
chunk_size = shape[seq_dim] // (2 * world_size)
new_shape = shape[:seq_dim] + [2 * world_size, chunk_size] + shape[seq_dim + 1:]
dout = dout.view(*new_shape)
seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=dout.device)
dout = dout.index_select(seq_dim, seq_idx)
final_shape = shape[:seq_dim] + [2 * chunk_size] + shape[seq_dim + 1:]
return dout.reshape(*final_shape).contiguous()
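
# Worked example of the DualChunkSwap layout above (comment only, not executed by
# the test): with world_size=2 the sequence axis is split into 2*world_size = 4
# chunks [0, 1, 2, 3]; rank 0 keeps chunks [0, 3] and rank 1 keeps chunks [1, 2].
# Pairing an early chunk with a late chunk keeps per-rank work roughly balanced
# under a causal mask, which is the usual motivation for this partitioning.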


# ---------------------------------------------------------------------------
# Core test logic
# ---------------------------------------------------------------------------

def run_test(
config_name,
qkv_format,
cp_comm_type,
attn_mask_type,
dtype_str="bf16",
):
"""Run a single CP vs no-CP comparison test using DotProductAttention."""
# Initialize distributed process group
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")

rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device(f"cuda:{local_rank}")

config = TEST_CONFIGS[config_name]
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}[dtype_str]

b = config.batch_size
s = config.max_seqlen
h_q = config.num_heads
h_kv = config.num_gqa_groups
d = config.head_dim

assert s % (2 * world_size) == 0, (
f"seqlen ({s}) must be divisible by 2*cp_size ({2 * world_size})"
)

# Generate full inputs -- same across all ranks (seeded)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

if qkv_format == "bshd":
q_shape = (b, s, h_q, d)
k_shape = (b, s, h_kv, d)
v_shape = (b, s, h_kv, d)
elif qkv_format == "sbhd":
q_shape = (s, b, h_q, d)
k_shape = (s, b, h_kv, d)
v_shape = (s, b, h_kv, d)
else:
raise ValueError(f"Unsupported qkv_format: {qkv_format}")

q_orig = torch.randn(q_shape, dtype=dtype, device=device)
k_orig = torch.randn(k_shape, dtype=dtype, device=device)
v_orig = torch.randn(v_shape, dtype=dtype, device=device)

# DPA output shape is (b, s, h*d) for bshd, (s, b, h*d) for sbhd
if qkv_format == "bshd":
dout_shape = (b, s, h_q * d)
else:
dout_shape = (s, b, h_q * d)
dout_orig = torch.randn(dout_shape, dtype=dtype, device=device)

# ============== Run WITHOUT CP ==============
core_attn = DotProductAttention(
h_q, d, num_gqa_groups=h_kv, attention_dropout=0.0,
qkv_format=qkv_format, attn_mask_type=attn_mask_type,
).cuda()

q, k, v = [x.clone().detach().requires_grad_(True) for x in [q_orig, k_orig, v_orig]]
dout = dout_orig.clone().detach()

out = core_attn(q, k, v)
out.backward(dout)
dq, dk, dv = q.grad, k.grad, v.grad

# ============== Run WITH CP ==============
# Set up communication group
cp_comm_ranks = list(range(world_size))
cp_group = dist.new_group(cp_comm_ranks, backend="nccl")
cp_stream = torch.cuda.Stream(device=device)

# Partition inputs for this rank using DualChunkSwap
q_, k_, v_ = [
partition_for_cp(x, qkv_format, rank, world_size).clone().detach().requires_grad_(True)
for x in [q_orig, k_orig, v_orig]
]
dout_ = partition_dout(dout_orig, qkv_format, rank, world_size)

# Configure CP on the attention module
core_attn.set_context_parallel_group(cp_group, cp_comm_ranks, cp_stream, cp_comm_type)

out_ = core_attn(q_, k_, v_)
out_.backward(dout_)
dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad

# ============== Validate ==============
# Check no NaN/Inf
for name, t in [("out_cp", out_), ("dq_cp", dq_), ("dk_cp", dk_), ("dv_cp", dv_)]:
assert torch.all(torch.isfinite(t)), f"Rank {rank}: {name} contains NaN or Inf!"

# Slice reference to match this rank's CP partition
seq_dim = qkv_format.index("s")

# For Q-side tensors (out, dq): partition ref the same as Q was partitioned
# DPA output is (b, s, h*d) / (s, b, h*d) -- seq_dim is 1 / 0
out_seq_dim = 1 if qkv_format == "bshd" else 0

def slice_ref(ref_tensor, local_tensor, s_dim):
"""Slice full reference tensor to match this rank's DualChunkSwap partition."""
shape = list(ref_tensor.shape)
chunk_size = shape[s_dim] // (2 * world_size)
new_shape = shape[:s_dim] + [2 * world_size, chunk_size] + shape[s_dim + 1:]
ref_chunked = ref_tensor.view(*new_shape)
seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=ref_tensor.device)
ref_sliced = ref_chunked.index_select(s_dim, seq_idx)
local_reshaped = local_tensor.view(*ref_sliced.shape)
return ref_sliced, local_reshaped

# Tolerances
if dtype_str == "bf16":
if h_q == h_kv:
atol, rtol = 2.5e-2, 2.5e-2
else:
atol, rtol = 3.5e-2, 3.5e-2
else:
atol, rtol = 5e-3, 5e-3

# Compare output and Q-side grads (use output seq_dim since DPA reshapes)
for name, ref_full, cp_local in [("out", out, out_), ("dq", dq, dq_)]:
s_dim = out_seq_dim if name == "out" else seq_dim
ref_s, cp_s = slice_ref(ref_full, cp_local, s_dim)

for ci in range(2):
if s_dim == 1: # bshd
rc = ref_s[:, ci]
cc = cp_s[:, ci]
else: # sbhd
rc = ref_s[ci]
cc = cp_s[ci]

try:
torch.testing.assert_close(rc, cc, atol=atol, rtol=rtol)
except AssertionError:
diff = (rc.float() - cc.float()).abs()
rmse = diff.pow(2).mean().sqrt().item()
val_range = max(rc.abs().max().item(), cc.abs().max().item(), 1e-6)
assert rmse < 0.02 * val_range, (
f"Rank {rank}: {name} chunk {ci} RMSE {rmse:.6f} > "
f"tol {0.02 * val_range:.6f}"
)

# Compare K/V-side grads
for name, ref_full, cp_local in [("dk", dk, dk_), ("dv", dv, dv_)]:
ref_s, cp_s = slice_ref(ref_full, cp_local, seq_dim)

for ci in range(2):
if seq_dim == 1:
rc = ref_s[:, ci]
cc = cp_s[:, ci]
else:
rc = ref_s[ci]
cc = cp_s[ci]

try:
torch.testing.assert_close(rc, cc, atol=atol, rtol=rtol)
except AssertionError:
diff = (rc.float() - cc.float()).abs()
rmse = diff.pow(2).mean().sqrt().item()
val_range = max(rc.abs().max().item(), cc.abs().max().item(), 1e-6)
assert rmse < 0.02 * val_range, (
f"Rank {rank}: {name} chunk {ci} RMSE {rmse:.6f} > "
f"tol {0.02 * val_range:.6f}"
)

logging.info(
f"Rank {rank}: PASSED -- config={config_name} fmt={qkv_format} "
f"comm={cp_comm_type} mask={attn_mask_type} dtype={dtype_str}"
)

dist.destroy_process_group()


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
kwargs = dict(arg.split("=") for arg in sys.argv[2:])
run_test(**kwargs)
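
# Example invocation (illustrative; the argument values are assumptions, adjust
# to your setup):
#   python -m torch.distributed.launch --nproc_per_node=2 run_lite_cp_test.py \
#       config_name=mha_causal qkv_format=bshd cp_comm_type=p2p attn_mask_type=causal
# sys.argv[2:] is parsed above because torch.distributed.launch injects its own
# --local_rank argument as argv[1]; every remaining argument must be key=value.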