Changes from all commits (22 commits)
d2fd002
Changed VERSION to 2.12.0
ptrendx Jan 20, 2026
6add8c9
[Common] Enable determinism for cuDNN >= 9.18.1 on Blackwell (#2584)
cyanguwa Jan 20, 2026
cfabd83
[Common] Tuned NVFP4 cast kernel (#2412)
Oleg-Goncharov Jan 21, 2026
42e803d
Fixed the year to 2026 (#2611)
Oleg-Goncharov Jan 21, 2026
d759aa6
[pyTorch] CPU performance optimizations (#2439)
ptrendx Jan 21, 2026
bf4af7e
[JAX] Fix cb.CUDAOptions usage for Triton 3.6.0 (#2610)
jberchtold-nvidia Jan 22, 2026
f49f515
Fix bugs in permutation custom partitioning (#2617)
tdophung Jan 23, 2026
d9b7fc5
[Common] Disabled the tuned NVFP4 kernels (#2615)
Oleg-Goncharov Jan 23, 2026
07f7750
[PyT] Update THD sink attention logic for cudnn >=9.18.0 (#2568)
cuichenx Jan 22, 2026
fdc0168
Add support for SWA (left, right) with FusedAttention (#2477)
sudhakarsingh27 Jan 22, 2026
3da26cd
[JAX] Use "nyu-mll/glue" instead of "glue" for encoder datasets to fi…
jberchtold-nvidia Jan 27, 2026
cad802f
[PyTorch] ONNX test fix + export for FP8 attention (#2598)
pggPL Jan 28, 2026
9bb9d22
[common] Add support for cuBLASLt GEMM for GroupedTensor (#2502)
pggPL Jan 28, 2026
5671fd3
Revert "[common] Add support for cuBLASLt GEMM for GroupedTensor (#25…
KshitijLakhani Jan 28, 2026
6ab6dfe
Merge remote-tracking branch 'upstream/release_v2.12' into zain/rel-2…
Micky774 Apr 13, 2026
80187b2
Add guards to new functions
Micky774 Apr 14, 2026
6ec90f8
Updated signatures
Micky774 Apr 14, 2026
7911721
Adjusted call sites for deterministic kwd
Micky774 Apr 15, 2026
45d50df
Build corrections and hardening for ptx
Micky774 Apr 16, 2026
b5318e1
Added back rounding error mitigation in comparison
Micky774 Apr 21, 2026
6a51c42
PR feedback
Micky774 Apr 24, 2026
4d27c38
Fix build on Pytorch 2.11 (#16505) (#575)
ipanfilo May 5, 2026
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 102 files
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
-2.12.0.dev0
+2.12.0
1 change: 0 additions & 1 deletion build_tools/hipify/custom_map.json
@@ -11,7 +11,6 @@
"__nv_fp8_e5m2" : "te_hip_fp8_e5m2",
"__nv_fp8_e4m3" : "te_hip_fp8_e4m3",
"cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA",
"at::cuda::CUDAGuard" : "at::hip::HIPGuardMasqueradingAsCUDA",
"__nv_fp4_e2m1" : "__hip_fp4_e2m1",
"__nv_fp4x2_e2m1" : "__hip_fp4x2_e2m1",
"__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1",
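For context, custom_map.json is a plain CUDA-to-HIP name substitution table consumed when hipifying the sources. A minimal illustrative sketch of how such a table could be applied, assuming a flat name-to-name JSON object (the actual hipify tooling is more involved):

import json

# Load the CUDA -> HIP substitution table (assumed here to be a flat JSON object).
with open("build_tools/hipify/custom_map.json") as f:
    mapping = json.load(f)

def hipify_source(source: str) -> str:
    # Apply each CUDA -> HIP name substitution to a source string.
    for cuda_name, hip_name in mapping.items():
        source = source.replace(cuda_name, hip_name)
    return source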
3 changes: 3 additions & 0 deletions examples/jax/datasets.txt
@@ -0,0 +1,3 @@
+# Datasets used by TE encoder tests. Pull these to pre-emptively cache datasets
+ylecun/mnist
+nyu-mll/glue
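The comment above suggests pre-emptively caching these datasets; a minimal pre-caching sketch, assuming the Hugging Face datasets package is installed (the GLUE subset name "cola" is illustrative only, since nyu-mll/glue requires choosing a configuration):

from datasets import load_dataset

# Each call downloads the dataset once and populates the local Hugging Face cache.
load_dataset("ylecun/mnist")
load_dataset("nyu-mll/glue", "cola")  # repeat per GLUE subset the tests need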
1 change: 1 addition & 0 deletions qa/L0_jax_unittest/test.sh
@@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
+NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py"

pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
@@ -45,6 +45,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
+NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
3 changes: 2 additions & 1 deletion qa/L1_pytorch_onnx_unittest/test.sh
@@ -6,4 +6,5 @@
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

-python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
+# NVTE_UnfusedDPA_Emulate_FP8=1 enables FP8 attention emulation when no native backend is available
+NVTE_UnfusedDPA_Emulate_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
125 changes: 78 additions & 47 deletions tests/cpp/operator/test_cast_nvfp4_transpose.cu

Large diffs are not rendered by default.

213 changes: 203 additions & 10 deletions tests/jax/test_fused_attn.py
@@ -4,6 +4,7 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""
+import os
from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
@@ -52,6 +53,9 @@
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats

+# Determinism is requested by setting NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
+_deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))


@pytest.fixture(autouse=True, scope="module")
def init():
@@ -417,16 +421,24 @@ def _check_configs(self):
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
-        # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
-        if (
-            get_device_compute_capability(0) >= 100
-            and self.dropout_prob == 0.1
-            and self.attn_bias_type is not AttnBiasType.NO_BIAS
-            and not is_hip_extension()
-        ):
-            pytest.skip(
-                "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
-            )
+        if not is_hip_extension() and get_device_compute_capability(0) >= 100 and self.is_training:
+            if FusedAttnHelper.is_non_deterministic_allowed() and (
+                (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
+                or get_cudnn_version() < 90700
+            ):
+                pytest.skip(
+                    "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with"
+                    " dropout"
+                )
+            if not FusedAttnHelper.is_non_deterministic_allowed() and (
+                self.dropout_prob != 0.0
+                or self.attn_bias_type != AttnBiasType.NO_BIAS
+                or get_cudnn_version() < 91801
+            ):
+                pytest.skip(
+                    "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or"
+                    " dropout"
+                )
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
@@ -1346,6 +1358,7 @@ def check_dqkv(primitive, reference, pad, idx):
pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
],
)
+@pytest.mark.skipif(_deterministic, reason="Test non-determinism only")
class TestFusedAttn:
"""
Fused attention tester
@@ -1507,3 +1520,183 @@ def test_jax_new_rng():
)
runner = FusedAttnRunner(**kwargs)
runner.test_forward()



@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
pytest.param(
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
),
],
)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
[
# large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.bfloat16,
QKVLayout.BSHD_BSHD_BSHD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-SEPARATE",
),
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.bfloat16,
QKVLayout.THD_THD_THD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE",
),
],
)
@pytest.mark.parametrize(
"dropout_prob",
[
pytest.param(0.0, id="DROP_0.0"),
],
)
@pytest.mark.parametrize(
"swa",
[
pytest.param(False, id="NO_SWA"),
],
)
@pytest.mark.parametrize(
"seq_desc_format",
[
pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
],
)
@pytest.mark.skipif(not _deterministic, reason="Test determinism only")
class TestFusedAttnWithDeterminism:
"""
Fused attention tester with determinism
"""

@staticmethod
@pytest.mark.parametrize(
"is_training",
[
pytest.param(True, id="TRAINING"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
def _test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test forward with parameterized configs
This test is not intended to run automatically during CI as it is time-consuming
It is kept for development and debugging
"""
TestFusedAttn._test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
)

@staticmethod
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
def test_backward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test backward with parameterized configs
"""
TestFusedAttn.test_backward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
)
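
Taken together, the two classes are mutually exclusive at collection time: the same module-level flag drives skipif in opposite directions, so exactly one suite runs per NVTE_ALLOW_NONDETERMINISTIC_ALGO setting. A condensed sketch of the pattern, using only names from the diff above:

import os
import pytest

_deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

@pytest.mark.skipif(_deterministic, reason="Test non-determinism only")
class TestFusedAttn:
    ...  # full parameter sweep, non-deterministic algorithms allowed

@pytest.mark.skipif(not _deterministic, reason="Test determinism only")
class TestFusedAttnWithDeterminism:
    ...  # reduced sweep that must be reproducible run-to-run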
4 changes: 2 additions & 2 deletions tests/jax/test_permutation.py
@@ -23,7 +23,7 @@
(128, 5, 128, 3),
(1024, 8, 128, 8),
(4096, 32, 1280, 2),
-(4096, 256, 4096, 6),
+(4096, 64, 4096, 6),
]
DISPATCH_COMBINE_CASES = {
"L0": ALL_DISPATCH_COMBINE_CASES[0:2],
@@ -44,7 +44,7 @@
(128, 5, 128, 3, 8),
(1024, 8, 128, 8, 16),
(4096, 32, 1280, 2, 128),
-(4096, 256, 4096, 6, 16),
+(4096, 64, 4096, 6, 16),
]
DISPATCH_COMBINE_PADDING_CASES = {
"L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2],