[PyTorch] Add pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen) #2596
sudhakarsingh27 wants to merge 23 commits into NVIDIA:main from flash_attn_pad_bw_seqs

Conversation
Greptile Summary
This PR adds …

Confidence Score: 4/5
The core FA3 + THD + pad_between_seqs logic is correct for all intended code paths. The main residual risk is a crash when FA3 is chosen for bshd+padding+pad_between_seqs=True because cu_seqlens_q_padded is None in that branch (already flagged in earlier review comments). The previously identified issues with FA4 not being disabled and wrong cu_seqlens being passed to the FA3 memory-layout arg have both been fixed in this PR. The A2A and P2P CP paths correctly guard seqused derivation with qkv_format == "thd" checks, and the backward gradient init with zeros_like is correct. The one unresolved concern is in the non-CP FA3 path: when pad_between_seqs=True is combined with bshd+padding mask and FA3, cu_seqlens_q_padded is None, which would be passed as FA3's cu_seqlens arg — a confirmed crash. This scenario is already discussed in earlier review comments and no current caller triggers it, but the code lacks a defensive guard.

Important Files Changed:
- transformer_engine/pytorch/attention/dot_product_attention/backends.py (non-CP FA3 seqused block lacks qkv_format guard)
- tests/pytorch/attention/test_attention_with_cp.py (new batch dispatch architecture)
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["FlashAttention.forward\n(backends.py)"] --> B{context_parallel?}
    B -- Yes --> C["attn_forward_func_with_cp\npasses cu_seqlens_q_padded\n+ pad_between_seqs"]
    B -- No --> D{use_flash_attn_3?}
    D -- Yes --> E["fa_optional_forward_args_thd:\ncu_seqlens_q_padded (layout)\n+\nseqused_q = cu_seqlens_q diff\n(actual counts)"]
    D -- No/FA2 --> F["fa_optional_forward_args_thd:\ncu_seqlens_q (no seqused)"]
    C --> G{cp_comm_type}
    G -- p2p --> H["cp_p2p_fwd_flash_attn\nDerive seqused from cu_seqlens_per_step\nOverride cu_seqlens to padded\n(with section half-padding)"]
    G -- a2a --> I["AttnFuncWithCPAndQKVOA2A\nDerive seqused from cu_seqlens\nOverride cu_seqlens to padded\nGuard: qkv_format==thd"]
    H --> J["get_fa_args(seqused_q, seqused_k)\nflash_attn_varlen_func_v3"]
    I --> J
    E --> J
    J --> K["FA3 kernel\nUses padded layout for memory\nUses seqused to skip padding tokens"]
    K --> L["Backward:\nzeros_like for dq/dk/dv\nFA3 only writes to valid positions"]
```
Reviews (38). Last reviewed commit: "Fix parallel CP test port conflicts and ..."
```python
# if `pad_between_seqs` is True, provide flash_attn_3 with `seqused_q` and `seqused_k`
# in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` to avoid affecting the
# padding positions.
if pad_between_seqs:
    fa_3_optional_forward_kwargs["seqused_q"] = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    fa_3_optional_forward_kwargs["seqused_k"] = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
```
style: verify that flash_attn_3 with seqused_q/seqused_k truly avoids writing to padding positions - the related issue #2391 mentions "we need to manually set the output of the padded positions to zero" (similar to how FusedAttention zeroes output in C++ for THD format). if flash_attn_3 doesn't zero these internally, output may have garbage values in padded positions. have you verified that flash_attn_3 correctly handles padding internally with these parameters?
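For reference, a minimal standalone sketch (toy shapes, not TE or FA3 code) of the manual zeroing that issue #2391 describes for FusedAttention: build a validity mask from the padded offsets and the real per-sequence counts, then zero everything outside it.

```python
import torch

# Padded THD layout: 2 sequences, 4 slots each; 3 and 2 valid tokens respectively.
cu_seqlens_padded = torch.tensor([0, 4, 8])
seqused = torch.tensor([3, 2])
out = torch.randn(8, 2, 16)  # [total_tokens, heads, head_dim]; padding rows may hold garbage

# Mark tokens that fall inside the valid prefix of their sequence.
token_idx = torch.arange(out.shape[0])
valid = torch.zeros(out.shape[0], dtype=torch.bool)
for b in range(seqused.numel()):
    start = int(cu_seqlens_padded[b])
    valid |= (token_idx >= start) & (token_idx < start + int(seqused[b]))

out = out * valid.view(-1, 1, 1)  # explicitly zero the padded positions
```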
Force-pushed ea51821 to e338049
/te-ci pytorch L2
```python
pad_between_seqs = False
if qkv_format == "thd" and cu_seqlens_q_padded is not None:
    pad_between_seqs = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
```
Can pad_between_seqs be decided ahead of time, passed by the user or something? This wouldn't be CUDA Graph-compatible right?
This pattern exists in dpa.py as well. But yes, it's definitely redundant here
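A small sketch of the reviewer's suggestion (hypothetical call site, not current TE code): since `torch.equal` returns a Python bool and forces a host sync, the flag could be computed once up front and passed in, rather than re-derived inside every forward (which also bakes the branch into any CUDA graph capture).

```python
import torch

# Decided once, ahead of time (e.g. by the caller that builds cu_seqlens):
cu_seqlens_q        = torch.tensor([0, 3, 8],  dtype=torch.int32)
cu_seqlens_q_padded = torch.tensor([0, 6, 12], dtype=torch.int32)

pad_between_seqs = cu_seqlens_q_padded is not None and not torch.equal(
    cu_seqlens_q_padded, cu_seqlens_q
)
# ... then pass pad_between_seqs explicitly into the attention call instead of
# recomputing the comparison on every step.
```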
/te-ci pytorch L1

/te-ci pytorch L3
Force-pushed b0a3c64 to 057f406
/te-ci pytorch L3

1 similar comment

/te-ci pytorch L3
Force-pushed 00bdc92 to 0f48ebc
```python
if not FlashAttentionUtils.v3_is_installed:
    pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
if cp_comm_type == "a2a+p2p":
    pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")
```
```python
if pad_between_seqs:
    dq, dk, dv = [torch.zeros_like(x) for x in [q_part, k_part, v_part]]
else:
    dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]]
```
Just to confirm, we can't do this for fwd, right? Because fwd output is not allocated by us.
It's a limitation in Flash Attention code - forward never mutates out (so pre-zeroing is overwritten), backward treats dq/dk/dv as in-place mutable (so pre-zeroing sticks). Also this zeroing out works only for CP code where we can provide the args.
None of the zeroing works for non-CP path because we only have the forward call in TE.
FA3 / Hopper (hopper/flash_attn_interface.py):
- Forward: mutates_args=(), namespace flash_attn_3::_flash_attn_forward
- Backward: mutates_args=("dq", "dk", "dv"), namespace flash_attn_3::_flash_attn_backward
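A self-contained sketch (toy stand-in kernel, assumed shapes) of why `zeros_like` matters for the backward buffers: an op declared with mutates_args writes in place and only touches the valid tokens, so padding rows keep whatever the buffer was initialized with.

```python
import torch

seqused = torch.tensor([3, 2])               # valid tokens per sequence
cu_seqlens_padded = torch.tensor([0, 4, 8])  # padded layout: 4 slots per sequence

def fake_backward(dq):
    # Stand-in for a backward with mutates_args=("dq", ...): writes valid rows in place.
    for b in range(seqused.numel()):
        start = int(cu_seqlens_padded[b])
        dq[start : start + int(seqused[b])] = 1.0

dq = torch.zeros(8, 2, 4)      # zeros_like-style init: padding rows stay exactly zero
fake_backward(dq)

dq_bad = torch.empty(8, 2, 4)  # empty_like-style init: padding rows keep uninitialized garbage
fake_backward(dq_bad)
```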
/te-ci pytorch L3
Add support for padding between sequences (pad_between_seqs) in the FlashAttention 3 backend when used with context parallelism (CP). Key changes:
- backends.py: Pass fa_pad_between_seqs through to FA3 forward/backward
- context_parallel.py: Handle pad_between_seqs in A2A and P2P CP paths, zero FA3 padding garbage in CP forward, fix a2a backward alignment
- dot_product_attention.py: Auto-detect pad_between_seqs from cu_seqlens
- utils.py: Gate FA3 deterministic backward for hdim>=256, fix flash_attn_supported override for cross-attention and large head_dim, disable UnfusedDotProductAttention for pad_between_seqs, add SM100+ FA3 skip

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add test parametrization for pad_between_seqs in flash attention tests. Update run_attention_with_cp.py to support the new parameter and fix batch boundary alignment in the non-CP FA3 path. Run tests in parallel when multiple GPUs are available. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add deterministic CP test runs to L3 FA versions test. Support TE_PATH positional arg and fix GPU threshold for parallel test execution. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…raint
The previous check disabled FA3 for deterministic mode whenever head_dim_qk > 128, which was overly conservative — FA3 forward supports deterministic execution at any head dim. The actual constraint from flash_api.cpp is that the backward pass does not support deterministic mode when max(head_size, head_size_v) >= 256. Narrow the gate to only disable FA3 during training (backward) and raise the threshold to >= 256, checking both head_dim_qk and head_dim_v to handle MLA configs with asymmetric head dimensions. Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370 Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
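A minimal sketch (assumed function and variable names) of the narrowed gate this commit describes:

```python
def fa3_deterministic_ok(is_training: bool, head_dim_qk: int, head_dim_v: int) -> bool:
    # FA3 forward is deterministic at any head dim; only the backward pass rejects
    # deterministic mode when max(head_size, head_size_v) >= 256.
    if is_training and max(head_dim_qk, head_dim_v) >= 256:
        return False
    return True

assert fa3_deterministic_ok(is_training=False, head_dim_qk=192, head_dim_v=256)     # inference: allowed
assert not fa3_deterministic_ok(is_training=True, head_dim_qk=192, head_dim_v=256)  # backward: gated
```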
Force-pushed 9c01601 to 4745f98
The pad_between_seqs gate in get_attention_backend only disabled FlashAttention 2, letting FA4 leak through to the test-time fused-vs-flash comparison. On B200 runners that install flash-attn-4, this caused test_dpa_qkv_layout_thd to compare FusedAttention against an FA4 output whose padded positions contain garbage, producing 48 numerics failures in L3_pytorch_FA_versions_test--B200_1GPU. The log message already claimed FA4 would be disabled — this change makes the code match the message: set use_flash_attention_4 = False alongside use_flash_attention_2 when pad_between_seqs is True. FA3 continues to support pad_between_seqs via seqused_q/seqused_k. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
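Sketched with the flag names quoted in the commit message (not a verbatim excerpt of get_attention_backend):

```python
pad_between_seqs = True
use_flash_attention_2 = True
use_flash_attention_4 = True

# With padding between sequences, only FA3 can skip the padded positions (via
# seqused_q/seqused_k), so both FA2 and FA4 have to be ruled out; previously only
# FA2 was, letting FA4 leak through on B200 runners.
if pad_between_seqs:
    use_flash_attention_2 = False
    use_flash_attention_4 = False
```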
pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen)
…_attn_pad_bw_seqs
/te-ci pytorch L3
FA4 install brings in nvidia-cutlass-dsl, whose `import cutlass` adds cutlass/base_dsl/ to sys.path. That directory contains a utils/ package that shadows tests/pytorch/utils.py, breaking collection of test_attention_with_cp.py with: ImportError: cannot import name 'ModelConfig' from 'utils' Prepend $TE_PATH/tests/pytorch to PYTHONPATH so the local utils.py is always resolved first, regardless of what FA4 dependencies install. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…_attn_pad_bw_seqs
…s it's a known cudnn issue Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…_attn_pad_bw_seqs
for more information, see https://pre-commit.ci
…ransformerEngine into flash_attn_pad_bw_seqs
/te-ci pytorch L3
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci pytorch L3

/te-ci pytorch L3
…_attn_pad_bw_seqs
…ransformerEngine into flash_attn_pad_bw_seqs
/te-ci pytorch L3
PR 2596 added deterministic CP runs to the L3 FA-versions matrix, multiplying
CP wall time across every FA version and causing CI timeouts (pipeline
50243000). Run CP tests once per arch instead, picking the FA version each
arch's CP code path actually supports:
- sm90 (H100): FA3 3.0.0b1 - context_parallel.py is FA3-only on Hopper
(use_flash_attn_3 threaded throughout, FA4
not wired in; pad_between_seqs gated on
use_flash_attn_3 at lines 1038, 1366)
- sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90
Non-CP test_attention.py still runs for every FA version in the array.
Also drop FA 2.7.3 from the sm90 list (no longer maintained as a target)
and bump the FA4 pin from 4.0.0b8 to 4.0.0b11. b8 has an SM90 backward
kernel bug fixed by upstream PR NVIDIA#2513 in b11
(get_smem_store_C() got multiple values for argument 'transpose').
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Three follow-ups on top of 13ba004 (L3 per-arch CP gating):

1. Skip the inline FA3 source build when flash_attn_interface is already importable. This makes the script a no-op on FA3 install when the base image has FA3 baked in (companion to TE !573 on te_ci, which auto-sets INSTALL_FA3=${RUN_L3_TESTS} so FA3 is preinstalled for L3 pipelines). Saves ~20 min of L3 H100 wall time once both land. Falls back to the existing inline build when FA3 is not pre-installed.
2. Suffix junit XMLs with the FA version (pytest_test_attention_fa2_8_3.xml etc.) so per-iteration results are preserved instead of overwritten. Pipeline 50348672 had no per-FA timing visibility because pytest.xml was clobbered by each loop iteration.
3. Include FA version in test_fail messages so CI dashboards show which FA iteration caused a failure (was "test_attention.py", now "test_attention.py (FA 2.8.3)").

Also fold the CP_FA_VERSION assignment into the same if-block as FA_versions (was a separate if-block immediately after) since the two are arch-keyed in lockstep.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…_attn_pad_bw_seqs
/te-ci pytorch L3
…_attn_pad_bw_seqs
Port CP test batching from sudhakars/cp_test_batching_pr (PR NVIDIA#2965). Groups parametrized configs into batches of CP_TEST_BATCH_SIZE (default 16) and runs each batch in a single torchrun invocation, amortizing the ~9s NCCL init overhead across configs instead of paying it per test. This is a temporary commit to validate batching under CI on the flash_attn_pad_bw_seqs branch — intended to be reverted after the run. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci pytorch L3
Two fixes on top of the batching port (ca8b383):

1. test.sh: assign distinct MASTER_PORT (29500 / 29501) to the two parallel pytest sessions so their torchrun batches don't collide. Without this, both sessions inherit the same MASTER_PORT and the second one fails with EADDRINUSE on every batch.
2. Restore the deterministic THD OOM skip that the batching PR dropped when it flattened the `if deterministic:` block. Without it, 5 fused-attention THD configs OOM on sm90 under NVTE_ALLOW_NONDETERMINISTIC_ALGO=0.

Validated: 8×H100, parallel non-det (38 passed) + det (26 passed, 5 THD OOM correctly skipped), zero EADDRINUSE.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci pytorch L3
Description
TLDR
Enable `pad_between_seqs=True` for FlashAttention 3 with THD format — both for context parallelism (A2A and P2P comm types) and non-CP paths. Previously `pad_between_seqs` was only supported with FusedAttention.

Problem

When using THD format with variable-length sequences, sequences are padded for divisibility across CP ranks. With `pad_between_seqs=True`, the attention kernel needs to know actual (unpadded) token counts so it doesn't compute attention over padding tokens. FusedAttention already handled this via `cu_seqlens_q_padded`, but FlashAttention (both FA2 and FA3) had `pad_between_seqs` hardcoded to `False` in the CP path, and FA2 was entirely disabled for `pad_between_seqs + thd`. FA3 can natively handle this via its `seqused_q`/`seqused_k` mechanism.

Solution

Use FA3's `seqused_q`/`seqused_k` tensors to communicate actual token counts per batch element. Pass `cu_seqlens_q_padded` for tensor memory layout while deriving `seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]` from the real `cu_seqlens`. This applies to both the CP path (A2A and P2P) and the non-CP path.

Fixes #2399
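A standalone sketch (example numbers, not the TE call path) of how the two sets of offsets are used:

```python
import torch

# Two sequences with 3 and 5 real tokens, each padded to 6 slots in the THD buffer.
cu_seqlens_q        = torch.tensor([0, 3, 8],  dtype=torch.int32)   # actual token offsets
cu_seqlens_q_padded = torch.tensor([0, 6, 12], dtype=torch.int32)   # padded memory layout

# FA3 addresses memory with the padded offsets but only attends over the first
# seqused_q[i] tokens of each sequence, skipping the padding in between.
seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]   # tensor([3, 5])
```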
Type of change
Changes
Please list the changes introduced in this PR:
context_parallel.py

- `get_fa_args()`: Add `seqused_q`/`seqused_k` parameters, pass through to FA3 forward and backward positional arg lists (replacing hardcoded `None`s).
- `cp_p2p_fwd_flash_attn()`/`cp_p2p_bwd_flash_attn()`: Accept `pad_between_seqs`, `cu_seqlens_q_padded`, `cu_seqlens_kv_padded`. When enabled, derive `seqused` tensors and override `cu_seqlens` to padded versions (with half-padding for lower-triangle/upper-triangle sections).
- `AttnFuncWithCPAndKVP2P`: Thread `pad_between_seqs` and padded cu_seqlens through all forward/backward `cp_p2p_fwd/bwd_flash_attn` call sites. Save `ctx.pad_between_seqs` for backward.
- `AttnFuncWithCPAndQKVOA2A.forward()`: Add `pad_between_seqs` parameter. When enabled with FA3+THD, derive `seqused` and swap `cu_seqlens` for padded versions before calling `get_fa_args()`.
- `AttnFuncWithCPAndQKVOA2A.backward()`: Same seqused/cu_seqlens override. Use `zeros_like` (not `empty_like`) for gradient init when `pad_between_seqs` since FA3 skips padding positions. Add extra `None` in return tuple for the new `pad_between_seqs` gradient slot.
- `attn_forward_func_with_cp()`: Pass `pad_between_seqs` in A2A args list.

backends.py

- `FlashAttention.forward()`: Accept `cu_seqlens_q_padded`/`cu_seqlens_kv_padded`. Detect `pad_between_seqs` by comparing padded vs actual cu_seqlens. Pass padded cu_seqlens to CP path. For non-CP FA3 path, derive and pass `seqused_q`/`seqused_k`.

dot_product_attention.py

- Pass `cu_seqlens_q_padded`/`cu_seqlens_kv_padded` through to `FlashAttention`.

utils.py

- No longer disable FlashAttention for `pad_between_seqs + thd`; FA3 handles this natively via `seqused`.

test_attention_with_cp.py

- Add `@pytest.mark.parametrize("pad_between_seqs", [False, True])` to flash attention CP tests.
- Skip `pad_between_seqs=True` for non-THD formats, when FA3 is not installed, and for `a2a+p2p` comm type (not yet supported).

run_attention_with_cp.py

- Pass `pad_between_seqs` through `generate_input_shapes()` and `run_dpa_with_cp()`.
- When `pad_between_seqs`, set `cu_seqlens_q` to actual lengths (not just for FusedAttention).
- `nan_to_num(nan=0.0)`.

test_attention.py

- Use padded inputs in `_run_dot_product_attention()` (previously FlashAttention used original unpadded inputs).
- Pass `cu_seqlens_q_padded`/`cu_seqlens_kv_padded` and `pad_between_seqs` to the DPA call for the FlashAttention backend.
- Add `pad_between_seqs=True` to parametrize with skip for non-THD formats.

New Tests
CP tests (`test_attention_with_cp.py`)

Added `@pytest.mark.parametrize("pad_between_seqs", [False, True])` to `test_cp_with_flash_attention`. Skip conditions: non-THD formats, FA3 not installed, `a2a+p2p` comm type.

5 new tests that run (all `pad_between_seqs=True`, thd, bf16):
- `True-p2p-thd-cp_1_0-bf16`
- `True-p2p-thd-cp_2_1-bf16`
- `True-a2a-thd-cp_1_0-bf16`
- `True-a2a-thd-cp_1_2-bf16`
- `True-a2a-thd-cp_2_1-bf16`

Non-CP tests (`test_attention.py`)

Added `True` to `@pytest.mark.parametrize("pad_between_seqs", [False, True])` on `test_dot_product_attention`, with skip for non-THD. Also changed `_run_dot_product_attention` so FlashAttention uses padded inputs/cu_seqlens and receives `pad_between_seqs=True`.

48 new test IDs collected, but all are skipped because the main parametrize uses `qkv_layout=None` (defaults to sbhd, not thd). The non-CP `pad_between_seqs` + FA3 code path is exercised indirectly when other test functions call `test_dot_product_attention` with `qkv_layout="thd_thd_thd"` (e.g., `test_dpa_softmax_thd`).

Checklist: