File tree Expand file tree Collapse file tree
transformer_engine/pytorch/attention/dot_product_attention Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1176,8 +1176,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
11761176 max_seqlen_kv = config .max_seqlen_kv ,
11771177 cu_seqlens_q = cu_seqlens_q ,
11781178 cu_seqlens_kv = cu_seqlens_kv ,
1179- cu_seqlens_q_padded = cu_seqlens_q_after_pad if backend in ["FusedAttention" , "FlashAttention" ] else None ,
1180- cu_seqlens_kv_padded = cu_seqlens_kv_after_pad if backend in ["FusedAttention" , "FlashAttention" ] else None ,
1179+ cu_seqlens_q_padded = (
1180+ cu_seqlens_q_after_pad if backend in ["FusedAttention" , "FlashAttention" ] else None
1181+ ),
1182+ cu_seqlens_kv_padded = (
1183+ cu_seqlens_kv_after_pad if backend in ["FusedAttention" , "FlashAttention" ] else None
1184+ ),
11811185 attn_mask_type = config .attn_mask_type ,
11821186 checkpoint_core_attention = ckpt_attn ,
11831187 core_attention_bias_type = config .attn_bias_type ,
Original file line number Diff line number Diff line change @@ -931,7 +931,9 @@ def forward(
931931 cu_seqlens_q_padded if cu_seqlens_q_padded is not None else cu_seqlens_q
932932 )
933933 fa_optional_forward_args_thd .append (
934- cu_seqlens_kv_padded if cu_seqlens_kv_padded is not None else cu_seqlens_kv
934+ cu_seqlens_kv_padded
935+ if cu_seqlens_kv_padded is not None
936+ else cu_seqlens_kv
935937 )
936938 fa_optional_forward_args_thd .append (max_seqlen_q )
937939 fa_optional_forward_args_thd .append (max_seqlen_kv )
@@ -973,8 +975,12 @@ def forward(
973975 # in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` to avoid affecting the
974976 # padding positions.
975977 if pad_between_seqs :
976- fa_3_optional_forward_kwargs ["seqused_q" ] = cu_seqlens_q [1 :] - cu_seqlens_q [:- 1 ]
977- fa_3_optional_forward_kwargs ["seqused_k" ] = cu_seqlens_kv [1 :] - cu_seqlens_kv [:- 1 ]
978+ fa_3_optional_forward_kwargs ["seqused_q" ] = (
979+ cu_seqlens_q [1 :] - cu_seqlens_q [:- 1 ]
980+ )
981+ fa_3_optional_forward_kwargs ["seqused_k" ] = (
982+ cu_seqlens_kv [1 :] - cu_seqlens_kv [:- 1 ]
983+ )
978984 else :
979985 fa_3_optional_forward_kwargs ["cu_seqlens_q" ] = cu_seqlens_q
980986 fa_3_optional_forward_kwargs ["max_seqlen_q" ] = max_seqlen_q
You can’t perform that action at this time.
0 commit comments