53 commits
1ececdc
[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD
sudhakarsingh27 Mar 10, 2026
e338049
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 10, 2026
fb27e0c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2026
50839e1
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 12, 2026
5c10658
[PyTorch] Add non-CP pad_between_seqs test support for FlashAttention
sudhakarsingh27 Mar 12, 2026
66e3352
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 12, 2026
d8abce2
fixes from feedback
sudhakarsingh27 Mar 20, 2026
41e431a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2026
9efa48f
remove redundant condition
sudhakarsingh27 Mar 20, 2026
c8a84bf
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 Mar 20, 2026
8652dba
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 20, 2026
232b78d
remove unnecessary zeroing logic, fixes from other feedback
sudhakarsingh27 Mar 21, 2026
73f989c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2026
0228d08
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 23, 2026
7b8bc13
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 Mar 23, 2026
355252c
add the flag to skip flash attn3 for head_dim_qk>128
sudhakarsingh27 Mar 24, 2026
e02ab58
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 24, 2026
d596db0
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 25, 2026
417b318
fix kv cache block size issue for FA2
sudhakarsingh27 Mar 25, 2026
0530153
add a skip when trying to run FA3 on SM100+
sudhakarsingh27 Mar 25, 2026
bba807a
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 6, 2026
81697e1
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 7, 2026
6104ede
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 Apr 8, 2026
c50975b
add CP tests and deterministic runs to L3; also make it on par with e…
sudhakarsingh27 Apr 8, 2026
872a2ad
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 8, 2026
fd24692
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 Apr 8, 2026
2f3528a
Fix zero input shape for bgrad_group_quantize (#2854)
vthumbe1503 Apr 8, 2026
3192ce2
fix test skips for FA3 pad_between_seqs and deterministic CP tests
sudhakarsingh27 Apr 9, 2026
6035007
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 9, 2026
fbe58ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2026
ff9bf4e
fix flash_attn_supported override for cross-attention causal mask
sudhakarsingh27 Apr 9, 2026
b0cdb4b
Merge branch 'flash_attn_pad_bw_seqs' of https://github.com/sudhakars…
sudhakarsingh27 Apr 9, 2026
3d71c35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2026
dc6ccd5
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in…
sudhakarsingh27 Apr 9, 2026
5fceae2
Merge branch 'flash_attn_pad_bw_seqs' of https://github.com/sudhakars…
sudhakarsingh27 Apr 9, 2026
7fa790d
[Common] Fix: IMA in `register_user_buffer_collective` on non-SM90 GP…
phu0ngng Apr 9, 2026
2ccc8ef
Simplify FA3 discovery (#2849)
vcherepanov-nv Apr 9, 2026
285c1eb
[PyTorch] Support scaled + clamped SwiGLU in `te.ops` and enable fuse…
ksivaman Apr 9, 2026
e9dccd8
[JAX] Fix BF16 tolerance for CGEMM + RS + BF16 test (#2860)
phu0ngng Apr 9, 2026
f278ce4
add high precision init weights to fully_shard example (#2785)
pstjohn Apr 9, 2026
18c802a
fix flash_attn_supported override for large head_dim configs
sudhakarsingh27 Apr 9, 2026
0f48ebc
fix merge conflicts
sudhakarsingh27 Apr 9, 2026
1c9325c
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 10, 2026
8ddce7a
remove the non-determinism tests from L3 for now
sudhakarsingh27 Apr 15, 2026
d397380
run tests in parallel if multiple GPUs are available
sudhakarsingh27 Apr 15, 2026
ae8884b
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 15, 2026
0177653
resolve merge conflicts with main
sudhakarsingh27 Apr 20, 2026
10f736c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 20, 2026
d3e310a
[PyTorch] Fix pad_between_seqs batch boundary alignment in non-CP FA3…
sudhakarsingh27 Apr 22, 2026
c514a97
[QA] Support TE_PATH positional arg and fix GPU threshold in FA test
sudhakarsingh27 Apr 22, 2026
5120631
[PyTorch] Disable UnfusedDotProductAttention for pad_between_seqs
sudhakarsingh27 Apr 22, 2026
2da11dc
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 Apr 23, 2026
9652abb
[PyTorch] Zero FA3 padding garbage in CP forward path
sudhakarsingh27 Apr 23, 2026
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
@@ -30,6 +30,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py"
47 changes: 44 additions & 3 deletions qa/L3_pytorch_FA_versions_test/test.sh
@@ -2,13 +2,25 @@
#
# See LICENSE for license information.

set -e
function error_exit() {
echo "Error: $1"
exit 1
}

function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}

RET=0
FAILED_CASES=""

: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.2.1
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

# Limit parallel build jobs to avoid overwhelming system resources
export MAX_JOBS=32
@@ -41,6 +53,35 @@ do
fi

# Run tests
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py
NUM_GPUS=$(nvidia-smi -L | wc -l)
echo "Detected $NUM_GPUS GPU(s)"
if [ "$NUM_GPUS" -ge 4 ]; then
CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 ))
CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS)
echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)"

CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \
--junitxml=$XML_LOG_DIR/pytest.xml \
$TE_PATH/tests/pytorch/attention/test_attention.py &
PID_ATTN=$!

CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \
--junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml \
$TE_PATH/tests/pytorch/attention/test_attention_with_cp.py &
Review comment (Collaborator): The a2a+p2p tests need 4 GPUs, right? So you might need to budget a bit more than `"$NUM_GPUS" -ge 3`?

PID_CP=$!

wait $PID_ATTN || test_fail "test_attention.py"
wait $PID_CP || test_fail "test_attention_with_cp.py"
else
echo "Running tests sequentially: need >=4 GPUs for parallel execution (1 for test_attention + up to 3 for test_attention_with_cp)"
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
fi
done

if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
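The GPU-budgeting arithmetic in the script above (reserve GPU 0 for test_attention.py, give up to 4 of the remaining GPUs to the CP suite) can be restated as a small Python helper. `plan_attention_tests` is a hypothetical name used only for this sketch; it is not part of the PR:

```python
def plan_attention_tests(num_gpus: int):
    """Mirror the shell logic: parallel plan if >= 4 GPUs, else run sequentially.

    Returns (attn_gpus, cp_gpus) as CUDA_VISIBLE_DEVICES strings for the two
    pytest invocations, or None when the suites must run back to back.
    """
    if num_gpus >= 4:
        # CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 ))
        cp_num = min(num_gpus - 1, 4)
        # CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS)
        cp_gpus = ",".join(str(i) for i in range(1, cp_num + 1))
        return "0", cp_gpus
    return None

print(plan_attention_tests(8))  # ('0', '1,2,3,4')
print(plan_attention_tests(2))  # None
```

With exactly 4 GPUs the CP suite gets 3 of them, which is why the review comment above questions whether the a2a+p2p configurations (4 CP ranks) fit under this budget.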
86 changes: 46 additions & 40 deletions tests/pytorch/attention/run_attention_with_cp.py
@@ -31,6 +31,7 @@ def generate_input_shapes(
config: ModelConfig,
world_size: int,
kernel_backend: str,
pad_between_seqs: str = "False",
):
if qkv_format == "bshd":
q_input_shape = (
@@ -99,9 +100,9 @@
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)

# Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if kernel_backend == "FusedAttention":
# When pad_between_seqs is True, or for FusedAttention, cu_seqlens_q reflects
# non-padded (actual) lengths. FA3 handles this via seqused_q/seqused_k.
if kernel_backend == "FusedAttention" or pad_between_seqs == "True":
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
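A plain-Python sketch of the two offset arrays this hunk manipulates: `cu_seqlens` holds cumulative actual lengths (the FA3 seqused / FusedAttention path), while `cu_seqlens_padded` holds the physical THD layout including inter-sequence padding. The alignment-based padding and the `build_cu_seqlens` helper below are illustrative assumptions; the test generates padding randomly:

```python
from itertools import accumulate

def build_cu_seqlens(seqlens, pad_to):
    """Cumulative offsets with and without inter-sequence padding (sketch)."""
    padded = [-(-s // pad_to) * pad_to for s in seqlens]  # round each length up to pad_to
    cu_seqlens = [0] + list(accumulate(seqlens))          # actual lengths
    cu_seqlens_padded = [0] + list(accumulate(padded))    # layout incl. padding
    return cu_seqlens, cu_seqlens_padded

cu, cu_pad = build_cu_seqlens([5, 7, 3], pad_to=4)
print(cu)      # [0, 5, 12, 15]
print(cu_pad)  # [0, 8, 16, 20]
```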

# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
Expand Down Expand Up @@ -180,6 +181,7 @@ def run_dpa_with_cp(
scaling_mode="delayed",
f16_O="False",
is_training="True",
pad_between_seqs="False",
log_level=logging.WARNING,
):
"""Test DotProductAttention module with context parallelism"""
@@ -275,7 +277,7 @@ def run_dpa_with_cp(
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, pad_between_seqs)
q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
@@ -351,6 +353,7 @@ def run_dpa_with_cp(
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
pad_between_seqs=pad_between_seqs,
fp8_output=fp8_mha,
)
if config.return_max_logit:
@@ -494,6 +497,7 @@ def run_dpa_with_cp(

# get outputs
tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_]
tensor_names = ["out", "dq", "dk", "dv", "dbias", "out_", "dq_", "dk_", "dv_", "dbias_"]
if fp8_mha:
tensors_to_deq = [out, out_] if not fp8_bwd else tensors
for i, tensor in enumerate(tensors_to_deq):
@@ -502,11 +506,11 @@ def run_dpa_with_cp(
tensors_to_deq[i] = tensor.dequantize()
if not fp8_bwd:
tensors[0], tensors[5] = tensors_to_deq
for tensor in tensors:
for tensor, name in zip(tensors, tensor_names):
# dbias/dbias_ could be None, so skip check for it
if tensor is not None:
assert torch.all(~torch.isnan(tensor))
assert torch.all(~torch.isinf(tensor))
assert torch.all(~torch.isnan(tensor)), f"{name} has nan values"
assert torch.all(~torch.isinf(tensor)), f"{name} has inf values"
out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors

############ compare results between CP and no-CP ############
@@ -559,49 +563,51 @@ def run_dpa_with_cp(
if is_training:
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_]
cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
)
cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
for x in [dq, out, dq_, out_]:
assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_q[b] == 0
or torch.count_nonzero(
x[
(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[
b + 1
]
]
).item()
== 0
)
num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - (
cu_seqlens_q_padded - cu_seqlens_q
)[:-1]
cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
)
cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv
num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1]
for x in [dk, dv, dk_, dv_]:
assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_kv[b] == 0
or torch.count_nonzero(
x[
(
cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]
) : cu_seqlens_kv_padded[b + 1]
]
).item()
== 0
num_pads_kv = (cu_seqlens_kv_padded - cu_seqlens_kv)[1:] - (
cu_seqlens_kv_padded - cu_seqlens_kv
)[:-1]
# FA3 leaves garbage at padding despite seqused_q/k (tile spillover).
# Zero non-CP tensors for comparison; CP tensors are zeroed in context_parallel.py.
if pad_between_seqs == "True":
for x in [out, dq]:
for b in range(config.batch_size):
x[cu_seqlens_q_padded[b + 1] - num_pads_q[b] : cu_seqlens_q_padded[b + 1]] = 0.0
x[cu_seqlens_q_padded[-1] :] = 0.0
for x in [dk, dv]:
for b in range(config.batch_size):
x[cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] : cu_seqlens_kv_padded[b + 1]] = 0.0
x[cu_seqlens_kv_padded[-1] :] = 0.0
# Verify CP tensors have clean padding (zeroed in context_parallel.py).
for xname, x, cu, np_ in [
("out_", out_, cu_seqlens_q_padded, num_pads_q),
("dq_", dq_, cu_seqlens_q_padded, num_pads_q),
("dk_", dk_, cu_seqlens_kv_padded, num_pads_kv),
("dv_", dv_, cu_seqlens_kv_padded, num_pads_kv),
]:
nnz = torch.count_nonzero(x[cu[-1] :]).item()
assert nnz == 0, (
f"{xname} has {nnz} nonzero values in tail padding — "
f"context_parallel.py should zero padding positions"
)
for b in range(config.batch_size):
if np_[b] > 0:
nnz = torch.count_nonzero(x[cu[b + 1] - np_[b] : cu[b + 1]]).item()
assert nnz == 0, (
f"{xname} has {nnz} nonzero values in batch {b} padding — "
f"context_parallel.py should zero padding positions"
)
else:
# Forward-only: reshape only out/out_ for comparison
out = out.index_select(0, seq_idx_q).contiguous()
out_ = out_

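The `num_pads_q`/`num_pads_kv` expressions above take a first difference of the pad prefix sums `cu_seqlens_padded - cu_seqlens`; that yields the per-sequence pad count. A plain-list sketch (hypothetical helper name) of the tensor arithmetic:

```python
def per_seq_num_pads(cu_seqlens, cu_seqlens_padded):
    """num_pads[b] = padding tokens after sequence b (sketch of the tensor math)."""
    pads_prefix = [p - a for p, a in zip(cu_seqlens_padded, cu_seqlens)]
    return [pads_prefix[i + 1] - pads_prefix[i] for i in range(len(pads_prefix) - 1)]

# Sequences of length 5, 7, 3 laid out as 8, 8, 4 slots:
print(per_seq_num_pads([0, 5, 12, 15], [0, 8, 16, 20]))  # [3, 1, 1]
```

Equivalently, it is the per-sequence padded length minus the actual length, which is why the test can use it to locate the padding span at the tail of each sequence.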
39 changes: 23 additions & 16 deletions tests/pytorch/attention/test_attention.py
@@ -124,7 +124,7 @@ def reset_global_fp8_state():
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
@pytest.mark.parametrize("pad_between_seqs", [False, True])
def test_dot_product_attention(
dtype,
model_configs,
@@ -157,6 +157,8 @@

config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if pad_between_seqs and qkv_format != "thd":
pytest.skip("pad_between_seqs only applies to THD format!")
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
config.attn_mask_type = (
"padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
@@ -195,18 +197,18 @@
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
if (
pad_between_seqs
and FlashAttentionUtils.is_installed
and not (
# FA3 natively supports pad_between_seqs via seqused_q/seqused_k (SM90 only).
# Override flash_attn_supported only for pad_between_seqs=True because
# get_available_attention_backends doesn't know about FA3's seqused support yet.
# For pad_between_seqs=False, trust the backend checker's result as-is.
if pad_between_seqs:
cross_attn_causal = (
config.max_seqlen_q != config.max_seqlen_kv
and config.attn_mask_type in ["causal", "padding_causal"]
)
and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
):
flash_attn_supported = True
sm = get_device_compute_capability()
if not cross_attn_causal and FlashAttentionUtils.v3_is_installed and sm == (9, 0):
flash_attn_supported = True
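The override above can be factored into a single predicate. This sketch uses hypothetical names and restates only the conditions visible in the diff (FA3 installed, SM90, and not causal cross-attention):

```python
def fa3_supports_pad_between_seqs(max_seqlen_q, max_seqlen_kv, attn_mask_type,
                                  sm, fa3_installed):
    """Sketch of the pad_between_seqs override for FlashAttention 3 (hypothetical helper)."""
    cross_attn_causal = (
        max_seqlen_q != max_seqlen_kv
        and attn_mask_type in ("causal", "padding_causal")
    )
    # seqused_q/seqused_k handling is FA3-only and SM90-only per the diff.
    return fa3_installed and sm == (9, 0) and not cross_attn_causal
```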
Review comment (Collaborator): "get_available_attention_backends doesn't know about FA3's seqused support yet." It probably knows it by now, as you've finished the PR? Should the `not cross_attn_causal` logic also go to `get_available_attention_backends` in utils.py?

Reply (Collaborator, author): I agree, and this is one of the instances where the flags to switch FA on/off are scattered across utils.py and the actual tests. Should we move the entire block (for pad_bw_seqs) to utils.py?

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
@@ -1330,12 +1332,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
block.softmax_offset.requires_grad = True

# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
q = inp_orig[0]
k = inp_orig[1]
v = inp_orig[2]
d_out = out_grad_orig
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
q = inp[0]
k = inp[1]
v = inp[2]
@@ -1351,14 +1353,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
cu_seqlens_q_padded=(
cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
cu_seqlens_kv_padded=(
cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
pad_between_seqs=pad_between_seqs,
# Only pass num_splits when exercising the FlashAttention path
num_splits=config.num_splits if backend == "FlashAttention" else 1,
)
@@ -1372,12 +1379,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad

if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
if is_training:
return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, max_logit, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if is_training:
23 changes: 22 additions & 1 deletion tests/pytorch/attention/test_attention_with_cp.py
@@ -85,11 +85,20 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", qkv_formats)
@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("pad_between_seqs", [False, True])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type, pad_between_seqs):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")

if pad_between_seqs:
if qkv_format != "thd":
pytest.skip("pad_between_seqs only applies to THD format!")
if not FlashAttentionUtils.v3_is_installed:
pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
if cp_comm_type == "a2a+p2p":
pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")
Review comment (Collaborator): What about AG?


config = model_configs_flash_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
@@ -133,6 +142,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
pad_between_seqs=pad_between_seqs,
log_level=pytest_logging_level,
),
)
@@ -364,9 +374,20 @@ def test_cp_with_fused_attention(
is_training=is_training,
)
_, fused_attn_supported, _ = available_backends

# Skip any tests if not supported by the configs
if not fused_attn_supported:
pytest.skip("No attention backend available.")

deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
if deterministic:
if config.softmax_type != "vanilla":
pytest.skip(
"Deterministic mode does not support non-vanilla softmax with FusedAttention"
)
if config.attn_bias_type == "post_scale_bias" and is_training:
pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad")

run_distributed(
get_bash_arguments(
num_gpus_per_node=num_gpus,