
Commit ea51821

Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/TransformerEngine into flash_attn_pad_bw_seqs
2 parents: c34f6a8 + 791bca7

2 files changed: 15 additions & 5 deletions


tests/pytorch/attention/test_attention.py

Lines changed: 6 additions & 2 deletions
@@ -1176,8 +1176,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
             max_seqlen_kv=config.max_seqlen_kv,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
-            cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None,
-            cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None,
+            cu_seqlens_q_padded=(
+                cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
+            ),
+            cu_seqlens_kv_padded=(
+                cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
+            ),
             attn_mask_type=config.attn_mask_type,
             checkpoint_core_attention=ckpt_attn,
             core_attention_bias_type=config.attn_bias_type,
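
For context on what the test passes here: cu_seqlens_q/cu_seqlens_kv hold cumulative counts of actual tokens, while cu_seqlens_q_padded/cu_seqlens_kv_padded hold cumulative offsets that include the padding inserted between sequences. A minimal sketch of the relationship, using a hypothetical helper (build_cu_seqlens is illustrative, not part of TransformerEngine):

# Hypothetical illustration, not from this commit: cumulative sequence
# lengths with and without padding between sequences (THD layout).
import torch

def build_cu_seqlens(seqlens, pad_between=0):
    """Prefix sums of per-sequence lengths; optionally add padding after each."""
    padded = [s + pad_between for s in seqlens]
    cu = torch.cumsum(torch.tensor([0] + seqlens), dim=0)        # actual tokens
    cu_padded = torch.cumsum(torch.tensor([0] + padded), dim=0)  # incl. padding
    return cu, cu_padded

cu_seqlens, cu_seqlens_padded = build_cu_seqlens([3, 5, 2], pad_between=1)
# cu_seqlens:        tensor([ 0,  3,  8, 10])
# cu_seqlens_padded: tensor([ 0,  4, 10, 13])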

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 9 additions & 3 deletions
@@ -931,7 +931,9 @@ def forward(
                 cu_seqlens_q_padded if cu_seqlens_q_padded is not None else cu_seqlens_q
             )
             fa_optional_forward_args_thd.append(
-                cu_seqlens_kv_padded if cu_seqlens_kv_padded is not None else cu_seqlens_kv
+                cu_seqlens_kv_padded
+                if cu_seqlens_kv_padded is not None
+                else cu_seqlens_kv
             )
             fa_optional_forward_args_thd.append(max_seqlen_q)
             fa_optional_forward_args_thd.append(max_seqlen_kv)
@@ -973,8 +975,12 @@ def forward(
             # in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` to avoid affecting the
             # padding positions.
             if pad_between_seqs:
-                fa_3_optional_forward_kwargs["seqused_q"] = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-                fa_3_optional_forward_kwargs["seqused_k"] = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
+                fa_3_optional_forward_kwargs["seqused_q"] = (
+                    cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+                )
+                fa_3_optional_forward_kwargs["seqused_k"] = (
+                    cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
+                )
             else:
                 fa_3_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q
                 fa_3_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q
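
The pad_between_seqs branch works because adjacent differences of the cumulative lengths recover each sequence's true token count, which is what seqused_q/seqused_k carry so the padded positions between sequences are never touched. A worked example with made-up values:

# Illustration only (values are made up): per-sequence lengths from
# cumulative sequence lengths via adjacent differences.
import torch

cu_seqlens_q = torch.tensor([0, 3, 8, 10])        # actual-token prefix sums
seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]  # -> tensor([3, 5, 2])

# With padding between sequences, cu_seqlens_q_padded positions each sequence
# in the padded buffer, while seqused_q limits the kernel to the real tokens,
# leaving the padding positions untouched.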
