-
Notifications
You must be signed in to change notification settings - Fork 706
[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD #2596
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1ececdc
e338049
fb27e0c
50839e1
5c10658
66e3352
d8abce2
41e431a
9efa48f
c8a84bf
8652dba
232b78d
73f989c
0228d08
7b8bc13
355252c
e02ab58
d596db0
417b318
0530153
bba807a
81697e1
6104ede
c50975b
872a2ad
fd24692
2f3528a
3192ce2
6035007
fbe58ab
ff9bf4e
b0cdb4b
3d71c35
dc6ccd5
5fceae2
7fa790d
2ccc8ef
285c1eb
e9dccd8
f278ce4
18c802a
0f48ebc
1c9325c
8ddce7a
d397380
ae8884b
0177653
10f736c
d3e310a
c514a97
5120631
2da11dc
9652abb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -124,7 +124,7 @@ def reset_global_fp8_state(): | |
| @pytest.mark.parametrize("workspace_opt", [True, False]) | ||
| @pytest.mark.parametrize("qkv_layout", [None]) | ||
| @pytest.mark.parametrize("swa", [False]) | ||
| @pytest.mark.parametrize("pad_between_seqs", [False]) | ||
| @pytest.mark.parametrize("pad_between_seqs", [False, True]) | ||
| def test_dot_product_attention( | ||
| dtype, | ||
| model_configs, | ||
|
|
@@ -157,6 +157,8 @@ def test_dot_product_attention( | |
|
|
||
| config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) | ||
| qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] | ||
| if pad_between_seqs and qkv_format != "thd": | ||
| pytest.skip("pad_between_seqs only applies to THD format!") | ||
| if qkv_format == "thd" and "padding" not in config.attn_mask_type: | ||
| config.attn_mask_type = ( | ||
| "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding" | ||
|
|
@@ -195,18 +197,18 @@ def test_dot_product_attention( | |
| ) | ||
| flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends | ||
|
|
||
| # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention | ||
| # manually pads and unpads the input and output of FlashAttention for testing purposes | ||
| if ( | ||
| pad_between_seqs | ||
| and FlashAttentionUtils.is_installed | ||
| and not ( | ||
| # FA3 natively supports pad_between_seqs via seqused_q/seqused_k (SM90 only). | ||
| # Override flash_attn_supported only for pad_between_seqs=True because | ||
| # get_available_attention_backends doesn't know about FA3's seqused support yet. | ||
| # For pad_between_seqs=False, trust the backend checker's result as-is. | ||
| if pad_between_seqs: | ||
| cross_attn_causal = ( | ||
| config.max_seqlen_q != config.max_seqlen_kv | ||
| and config.attn_mask_type in ["causal", "padding_causal"] | ||
| ) | ||
| and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus) | ||
| ): | ||
| flash_attn_supported = True | ||
| sm = get_device_compute_capability() | ||
| if not cross_attn_causal and FlashAttentionUtils.v3_is_installed and sm == (9, 0): | ||
| flash_attn_supported = True | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Should the
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I agree and this is one of the instances where the flags to switch FA on/off are scattered across utils.py and actual tests. Should we move the entire block (for |
||
|
|
||
| # Skip if only unfused backend is supported | ||
| if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: | ||
|
|
@@ -1330,12 +1332,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: | |
| block.softmax_offset.requires_grad = True | ||
|
|
||
| # Run a forward and backward pass | ||
| if backend in ["FlashAttention", "UnfusedDotProductAttention"]: | ||
| if backend in ["UnfusedDotProductAttention"]: | ||
| q = inp_orig[0] | ||
| k = inp_orig[1] | ||
| v = inp_orig[2] | ||
| d_out = out_grad_orig | ||
| if backend == "FusedAttention": | ||
| if backend in ["FusedAttention", "FlashAttention"]: | ||
| q = inp[0] | ||
| k = inp[1] | ||
| v = inp[2] | ||
|
|
@@ -1351,14 +1353,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: | |
| max_seqlen_kv=config.max_seqlen_kv, | ||
| cu_seqlens_q=cu_seqlens_q, | ||
| cu_seqlens_kv=cu_seqlens_kv, | ||
| cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None, | ||
| cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None, | ||
| cu_seqlens_q_padded=( | ||
| cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None | ||
| ), | ||
| cu_seqlens_kv_padded=( | ||
| cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None | ||
| ), | ||
| attn_mask_type=config.attn_mask_type, | ||
| checkpoint_core_attention=ckpt_attn, | ||
| core_attention_bias_type=config.attn_bias_type, | ||
| core_attention_bias=bias, | ||
| alibi_slopes=alibi_slopes, | ||
| fast_zero_fill=True, | ||
| pad_between_seqs=pad_between_seqs, | ||
| # Only pass num_splits when exercising the FlashAttention path | ||
| num_splits=config.num_splits if backend == "FlashAttention" else 1, | ||
| ) | ||
|
|
@@ -1372,12 +1379,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: | |
| if is_training and config.softmax_type != "vanilla": | ||
| d_softmax_offset = block.softmax_offset.grad | ||
|
|
||
| if backend in ["FlashAttention", "UnfusedDotProductAttention"]: | ||
| if backend in ["UnfusedDotProductAttention"]: | ||
| if is_training: | ||
| return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset) | ||
| else: | ||
| return out, max_logit, (None, None, None, d_softmax_offset) | ||
| if backend == "FusedAttention": | ||
| if backend in ["FusedAttention", "FlashAttention"]: | ||
| if qkv_format == "thd" and pad_between_seqs: | ||
| out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) | ||
| if is_training: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -85,11 +85,20 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): | |
| @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) | ||
| @pytest.mark.parametrize("qkv_format", qkv_formats) | ||
| @pytest.mark.parametrize("cp_comm_type", cp_comm_types) | ||
| def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): | ||
| @pytest.mark.parametrize("pad_between_seqs", [False, True]) | ||
| def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type, pad_between_seqs): | ||
| num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 | ||
| if num_gpus > torch.cuda.device_count(): | ||
| pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") | ||
|
|
||
| if pad_between_seqs: | ||
| if qkv_format != "thd": | ||
| pytest.skip("pad_between_seqs only applies to THD format!") | ||
| if not FlashAttentionUtils.v3_is_installed: | ||
| pytest.skip("pad_between_seqs with CP requires Flash Attention v3!") | ||
| if cp_comm_type == "a2a+p2p": | ||
| pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What about AG? |
||
|
|
||
| config = model_configs_flash_attn[model] | ||
| config.context_parallel = True | ||
| config.cp_comm_type = cp_comm_type | ||
|
|
@@ -133,6 +142,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): | |
| qkv_format=qkv_format, | ||
| kernel_backend="FlashAttention", | ||
| cp_comm_type=cp_comm_type, | ||
| pad_between_seqs=pad_between_seqs, | ||
| log_level=pytest_logging_level, | ||
| ), | ||
| ) | ||
|
|
@@ -364,9 +374,20 @@ def test_cp_with_fused_attention( | |
| is_training=is_training, | ||
| ) | ||
| _, fused_attn_supported, _ = available_backends | ||
|
|
||
| # Skip any tests if not supported by the configs | ||
| if not fused_attn_supported: | ||
| pytest.skip("No attention backend available.") | ||
|
|
||
| deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) | ||
| if deterministic: | ||
| if config.softmax_type != "vanilla": | ||
| pytest.skip( | ||
| "Deterministic mode does not support non-vanilla softmax with FusedAttention" | ||
| ) | ||
| if config.attn_bias_type == "post_scale_bias" and is_training: | ||
| pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") | ||
|
|
||
| run_distributed( | ||
| get_bash_arguments( | ||
| num_gpus_per_node=num_gpus, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The a2a+p2p tests need 4 GPUs right? so you might need to budget a bit more than
"$NUM_GPUS" -ge 3?