
[SP] add SP deny list instead of allow#7887

Open
kashif wants to merge 12 commits into deepspeedai:master from kashif:sp_attn_deny

Conversation

@kashif
Contributor

@kashif kashif commented Mar 5, 2026

this way one can register kernels based flash-attn as well with SP

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Collaborator

@tohtana tohtana left a comment


Hi @kashif,

Thank you for opening this PR! I think supporting HF hub kernels is a significant update.

Regarding the approach: we check whether core_attn_implementation is in ALL_ATTENTION_FUNCTIONS, but HF hub kernels like kernels-community/flash-attn2 are not in that list, so HF hub kernels still won't be available with this fix.

We probably need to do the proper registration steps:

  1. Reject known-bad impls explicitly: eager, flex_attention, and probably paged|eager.
  2. If core_attn_implementation is an HF hub kernel string, call the HF registration path first. (Using lazy_import_flash_attention(…))
  3. Then read core_attn_function = ALL_ATTENTION_FUNCTIONS[core_attn_implementation].
  4. Build uattn from that original function.
  5. Replace that key with uattn_wrapper.

Does it make sense to you?
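The registration flow above can be sketched roughly like this. Everything here is a stand-in: the real ALL_ATTENTION_FUNCTIONS registry lives in transformers and the hub-kernel loading goes through lazy_import_flash_attention, so this only illustrates the ordering of the five steps, not actual DeepSpeed or transformers code.

```python
# Hedged sketch of the five registration steps; all names below are stand-ins.
DENY_LIST = {"eager", "flex_attention", "paged|eager"}  # step 1: known-bad impls

# stand-in for transformers' ALL_ATTENTION_FUNCTIONS registry
ALL_ATTENTION_FUNCTIONS = {"flash_attention_2": lambda *a, **kw: "fa2-out"}


def lazy_register_hub_kernel(impl):
    # step 2 stand-in: pretend the HF hub kernel was fetched and registered
    ALL_ATTENTION_FUNCTIONS[impl] = lambda *a, **kw: f"{impl}-out"


def register_uattn(core_attn_implementation):
    if core_attn_implementation in DENY_LIST:  # step 1: reject explicitly
        raise ValueError(f"{core_attn_implementation} is not supported with SP")
    if core_attn_implementation not in ALL_ATTENTION_FUNCTIONS:
        lazy_register_hub_kernel(core_attn_implementation)  # step 2: HF path first
    # step 3: read the original function before replacing it
    core_attn_function = ALL_ATTENTION_FUNCTIONS[core_attn_implementation]

    def uattn_wrapper(*args, **kwargs):  # step 4: build uattn from the original
        # the Ulysses all-to-all logic would wrap the original function here
        return core_attn_function(*args, **kwargs)

    # step 5: replace that key with the wrapper
    ALL_ATTENTION_FUNCTIONS[core_attn_implementation] = uattn_wrapper
    return uattn_wrapper
```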

@kashif kashif requested a review from loadams as a code owner March 8, 2026 09:41
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
@kashif
Contributor Author

kashif commented Mar 8, 2026

thanks @tohtana, I have tried to fix all the issues raised; could you kindly check again?

@stas00
Collaborator

stas00 commented Mar 9, 2026

Reject known-bad impls explicitly: eager, flex_attention, and probably paged|eager.

We actually don't know if flex_attention is bad, we just haven't tried it out. Do you have resources to try it out, Kashif? Same for the others on the list.

That's why we started with an allow list, rather than a deny list.

The only reason eager is denied is that it requires 4D attention_mask which is a bad idea for long sequence.

BTW, SDPA is silently broken with packed samples - when there is no attn mask, it ignores pos ids and attends to the whole sequence instead. Expect bad results. Not sure how to flag that to users - probably need to inspect pos ids and see if they reset at least once and disallow sdpa then.
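A minimal sketch of the check suggested above: inspect position_ids, and if they reset at least once (packed samples) while no attention mask is provided, disallow sdpa/eager. Plain Python lists stand in for tensors, and both function names are hypothetical, not DeepSpeed or transformers API.

```python
# Hypothetical helpers illustrating the suggested packed-sample check.
def has_packed_samples(position_ids):
    """True if position_ids reset (drop) anywhere, e.g. [0, 1, 2, 0, 1] -> True."""
    return any(b < a for a, b in zip(position_ids, position_ids[1:]))


def validate_attn_impl(core_attn_implementation, position_ids, attention_mask=None):
    # With packed samples and no mask, sdpa/eager silently attend across
    # document boundaries instead of respecting the position_id resets.
    if (attention_mask is None
            and core_attn_implementation in ("sdpa", "eager")
            and has_packed_samples(position_ids)):
        raise ValueError(
            f"{core_attn_implementation} ignores packed-sample boundaries when "
            "attention_mask is None; use flash_attention_2 instead")
```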

@tohtana
Collaborator

tohtana commented Mar 10, 2026

Hi @kashif,
Thank you for addressing my comments! It looks good to me.

I also think Stas's comment makes sense. Can you try implementing such a validation?
You can refer to transformers' find_packed_sequence_indices.

@kashif
Contributor Author

kashif commented Mar 10, 2026

sure @tohtana i can check

@stas00
Collaborator

stas00 commented Mar 12, 2026

to make things more exact - it's packed samples + pos ids + 4D attention_mask=None where sdpa silently does the wrong thing. I haven't validated, but it most likely will do the right thing with a 4D attention mask that is not None; however, that can't be used with SP because it becomes too large too quickly.

@stas00
Collaborator

stas00 commented Mar 12, 2026

oh, Kashif, I'm being told eager has the exact same problem as sdpa - so both need to be fixed on the transformers side. Thank you very much!

@kashif
Contributor Author

kashif commented Mar 14, 2026

I ran some experiments comparing flash_attention_2, sdpa, and flex_attention with SP=4 on Qwen3-4B (GQA: 32 Q
heads, 8 KV heads), 8K seq length, 10 steps.

Without SP (1 GPU baseline): flash_attention_2 and sdpa produce identical losses — confirming the backends are
equivalent in the standard path.

  ┌──────┬───────┬───────┐
  │ Step │  fa2  │ sdpa  │
  ├──────┼───────┼───────┤
  │ 1    │ 0.736 │ 0.737 │
  ├──────┼───────┼───────┤
  │ 2    │ 1.841 │ 1.843 │
  ├──────┼───────┼───────┤
  │ 3    │ 0.806 │ 0.807 │
  └──────┴───────┴───────┘

With SP=4 (4 GPUs): sdpa and flex_attention match each other, but both diverge significantly from
flash_attention_2:

  ┌──────┬──────┬──────┬──────┐
  │ Step │ fa2  │ sdpa │ flex │
  ├──────┼──────┼──────┼──────┤
  │ 1    │ 2.37 │ 4.55 │ 4.55 │
  ├──────┼──────┼──────┼──────┤
  │ 5    │ 2.28 │ 3.52 │ 3.52 │
  ├──────┼──────┼──────┼──────┤
  │ 10   │ 2.29 │ 3.02 │ 3.02 │
  └──────┴──────┴──────┴──────┘

@stas00 any ideas on what flash_attention_2 might be doing differently after the all-to-all that
sdpa/flex_attention aren't? The Q/K/V shapes and attention_mask=None + is_causal=True path should be equivalent,
but something in the SP gather/scatter is exposing a difference.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
@kashif
Copy link
Contributor Author

kashif commented Mar 14, 2026

ok @stas00 I now generate position_ids if missing from the batch, build a causal BlockMask for flex_attention, and do a one-time packed-sample validation for packed samples + sdpa/eager

Now the outputs are matching:

  ┌──────┬───────────────────┬───────┬────────────────┐
  │ Step │ flash_attention_2 │ sdpa  │ flex_attention │
  ├──────┼───────────────────┼───────┼────────────────┤
  │ 1    │ 2.152             │ 2.152 │ 2.150          │
  ├──────┼───────────────────┼───────┼────────────────┤
  │ 2    │ 2.469             │ 2.468 │ 2.468          │
  ├──────┼───────────────────┼───────┼────────────────┤
  │ 3    │ 2.045             │ 2.044 │ 2.045          │
  ├──────┼───────────────────┼───────┼────────────────┤
  │ 4    │ 2.197             │ 2.197 │ 2.197          │
  ├──────┼───────────────────┼───────┼────────────────┤
  │ 5    │ 2.113             │ 2.112 │ 2.112          │
  ├──────┼───────────────────┼───────┼────────────────┤
  │ 6    │ 2.173             │ 2.173 │ 2.172          │
  ├──────┼───────────────────┼───────┼────────────────┤
  │ 7    │ 2.351             │ 2.350 │ 2.351          │
  ├──────┼───────────────────┼───────┼────────────────┤
  │ 8    │ 2.380             │ 2.380 │ 2.379          │
  ├──────┼───────────────────┼───────┼────────────────┤
  │ 9    │ 1.847             │ 1.847 │ 1.847          │
  ├──────┼───────────────────┼───────┼────────────────┤
  │ 10   │ 2.151             │ 2.151 │ 2.151          │
  └──────┴───────────────────┴───────┴────────────────┘

kashif added 2 commits March 14, 2026 17:02
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
@kashif kashif requested a review from tohtana March 14, 2026 17:22
@stas00
Copy link
Collaborator

stas00 commented Mar 15, 2026

Thank you for running those quality comparison experiments, Kashif

I'm a bit unclear about your last "success" comment - what was missing to make FA2 match? Are you saying the mismatch was from missing position_ids? But we said already that SDPA (and now most likely FlexAttention) has trouble with no-attn-mask / yes-pos-id and will ignore packed samples, while FA2 does the right thing here.

And it's great to hear Flex Attention works as well with Ulysses, so we could add it to the allow list.

Comment on lines +295 to +296
if has_packed_samples and self.core_attn_implementation in ("sdpa", "eager"):
raise ValueError(
Collaborator


heh, I thought we were discussing that it's HF Transformers that has to do that, not Ulysses SP. It affects all users regardless of whether they use Ulysses or not. Unless HF Transformers disallows not providing attn-mask with sdpa/eager, which I don't think is the case.

Contributor Author


agree, removed from DeepSpeed side

# looks like packed sequences [0,...,N, 0,...,N, ...]. flash_attention_2 handles
# this via flash_varlen_fn, but sdpa/flex_attention apply full causal masking
# across the resets, producing incorrect attention.
if "position_ids" not in batch:
Collaborator

@stas00 stas00 Mar 15, 2026


I'm not sure about this. This might lead to a user getting the wrong behavior if they packed samples but forgot to supply pos ids. Should we simply assert if pos ids aren't there and not potentially create invalid pos ids?

I agree there needs to be a check and it's not there.

Contributor Author


yes, it would need to be in the TRL trainer, for the collator to always provide position_ids when SP is enabled, so the adapter never needs to generate them. I can try to fix it there.

Collaborator


Thank you, Kashif.

And probably then add an assert on SP side if pos id isn't there?

@kashif
Contributor Author

kashif commented Mar 16, 2026

Thank you for running those quality comparison experiments, Kashif

I'm a bit unclear about your last "success" comment - what was missing to make FA2 match? Are you saying the mismatch was from missing position_ids? But we said already that SDPA (and now most likely FlexAttention) has trouble with no-attn-mask / yes-pos-id and will ignore packed samples, while FA2 does the right thing here.

And it's great to hear Flex Attention works as well with Ulysses, so we could add it to the allow list.

So, FA2 was the one producing correct results, while SDPA/flex were wrong. Here's what was happening:

When position_ids are not in the dataloader batch (common with SFTTrainer + packing=False), UlyssesSPDataLoaderAdapter doesn't generate them. The Trainer then generates position_ids = [0,...,chunk_len-1] on each rank AFTER the adapter has already sharded the sequence. After all_gather in UlyssesSPAttentionHF.forward(), the concatenated position_ids become:

 [0,...,2047, 0,...,2047, 0,...,2047, 0,...,2047]  # looks like 4 packed documents!

FA2 "accidentally" handles this correctly — _is_packed_sequence() detects the resets and switches to flash_varlen_fn, treating each shard as a separate document. This gives correct attention within each shard.

SDPA with is_causal=True applies a simple lower-triangular causal mask over the entire gathered sequence, allowing tokens to attend across the position_id resets. This produced loss=4.55 vs FA2's correct loss=2.37.

The fix: generate position_ids in UlyssesSPDataLoaderAdapter.refill() BEFORE all_gather and sharding, so each rank gets correct global positions (rank 1 gets [2048,...,4095], not [0,...,2047]). After gather they reconstruct to monotonic [0,...,8191] — no resets, all backends produce identical results.
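The fix described above can be sketched with plain lists; shard_position_ids and gather are hypothetical stand-ins for the adapter's refill() and the all_gather in UlyssesSPAttentionHF.forward(), used only to show why generating positions before sharding yields a monotonic gathered sequence.

```python
# Hedged sketch: generate global position_ids BEFORE sharding, so rank r of
# sp_world_size ranks gets positions [r*chunk, ..., (r+1)*chunk - 1].
def shard_position_ids(seq_len, sp_world_size):
    chunk = seq_len // sp_world_size
    return [list(range(r * chunk, (r + 1) * chunk)) for r in range(sp_world_size)]


def gather(shards):
    # stand-in for the all_gather that reassembles the full sequence
    return [p for shard in shards for p in shard]
```

With this ordering the gathered position_ids reconstruct to a single monotonic run, so no backend mistakes the shards for packed documents.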

With this fix, all three backends match within numerical precision:

  ┌──────┬───────┬───────┬───────┐
  │ Step │  FA2  │ SDPA  │ Flex  │
  ├──────┼───────┼───────┼───────┤
  │ 1    │ 2.152 │ 2.152 │ 2.150 │
  ├──────┼───────┼───────┼───────┤
  │ 5    │ 2.113 │ 2.113 │ 2.112 │
  ├──────┼───────┼───────┼───────┤
  │ 10   │ 2.151 │ 2.152 │ 2.151 │
  └──────┴───────┴───────┴───────┘

For flex_attention, we also needed to rebuild the BlockMask for the full gathered sequence length after the all-to-all (the wrapper discards the original one since it was built for the local shard).
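A rough illustration of that rebuild, assuming the standard causal mask_mod predicate flex_attention uses: the real code would hand this predicate to torch's create_block_mask with the full gathered length, whereas here it is evaluated directly into a dense boolean mask so the sketch runs without torch.

```python
# Hypothetical stand-in for rebuilding the causal BlockMask after the
# all-to-all; the original mask was built for the local shard only.
def causal_mask_mod(b, h, q_idx, kv_idx):
    # standard causal predicate: query may attend to itself and earlier keys
    return q_idx >= kv_idx


def rebuild_causal_mask(local_seq_len, sp_world_size):
    full_len = local_seq_len * sp_world_size  # length after the gather
    return [[causal_mask_mod(0, 0, q, k) for k in range(full_len)]
            for q in range(full_len)]
```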

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
@stas00
Collaborator

stas00 commented Mar 16, 2026

great explanations, Kashif - thank you!

  1. let's assert if pos ids aren't there, trusting that the user will set them up correctly. Generating a warning doesn't guarantee the user will see it. But an assert telling them to do it correctly is probably the safest/most resilient way forward.
  2. and as you said a special treatment needs to be added for BlockMask for flex attn - I'm not familiar with this one, so will see your implementation when you get a chance to add it.
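For point 1, a minimal sketch of the proposed fail-fast check; check_batch_for_sp is a hypothetical name, not an existing DeepSpeed function.

```python
# Hypothetical SP-side guard: fail fast rather than fabricate position_ids.
def check_batch_for_sp(batch):
    assert "position_ids" in batch, (
        "Ulysses SP requires the dataloader/collator to provide position_ids; "
        "generate them before sharding so each rank sees global positions")
```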

Thank you, Kashif

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
@kashif
Contributor Author

kashif commented Mar 17, 2026

@stas00, regarding point 2, we added BlockMask handling for flex_attention in these places:

  1. uattn_wrapper: keeps the BlockMask instead of discarding it (other mask types are set to None)
  2. UlyssesSPAttentionHF.forward(): rebuilds the BlockMask for the full gathered sequence length after the all-to-all (the original was built for the local shard)
  3. register_with_transformers(): imports BlockMask and create_block_mask once and stores them on the instance (only when core_attn_implementation == "flex_attention")

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
@stas00
Collaborator

stas00 commented Mar 17, 2026

Thank you very much, Kashif.

Do you think all this amazing tooling you added should live here and not in HF Transformers?

@kashif
Contributor Author

kashif commented Mar 17, 2026

checking
