
Pass packed boundary metadata to Qwen3.5 linear-attention fast kernels from data collator #45034

Open

sdharani91 wants to merge 15 commits into huggingface:main from sdharani91:feature_packing_qwen

Conversation

@sdharani91

What does this PR do?

This is a follow-up to #44867.

This PR fixes Qwen3.5 padding-free packed inputs on the linear-attention fast path by consuming collator-provided packed metadata. The linear-attention block now uses seq_idx for the causal convolution path and cu_seq_lens_q / cu_seq_lens_k for the FLA path, matching the repo’s existing DataCollatorWithFlattening contract. I also added a deterministic fast-path regression test comparing padded and padding-free inputs, plus a slow-path contract test that raises clearly when padding-free kwargs are passed without fast-kernel support. The slow fallback implementation itself is unchanged in this PR.

Fixes #44717
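
For context, a minimal usage sketch of the collator contract this PR consumes; the return_seq_idx / return_flash_attn_kwargs flags and the checkpoint name are illustrative assumptions based on the existing DataCollatorWithFlattening options, not something added or changed here.

from transformers import AutoTokenizer, DataCollatorWithFlattening

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
collator = DataCollatorWithFlattening(return_seq_idx=True, return_flash_attn_kwargs=True)

features = [tokenizer("first sequence"), tokenizer("second, slightly longer sequence")]
batch = collator(features)

# One flattened row, no padding: input_ids, labels, position_ids.
# Packed-boundary metadata consumed by this PR:
#   seq_idx                       -> per-token sequence index (causal-conv1d path)
#   cu_seq_lens_q / cu_seq_lens_k -> cumulative sequence lengths (FLA path)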

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • [x] I confirm that this is not a pure code agent PR.

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline,
    Pull Request section?
  • [ ] Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • [x] Did you write any new necessary tests?

Who can review?

@vasqu

Contributor

@vasqu vasqu left a comment

Thanks for sticking with this, and sorry about the confusion around the other PR / issues.

I added some comments. My main point stands: move even more into the data collator, and maybe take a look at Bamba, which has done something similar with seq_idx, for example.

@@ -214,6 +206,9 @@ def forward(
hidden_states: torch.Tensor,
cache_params: Qwen3_5DynamicCache | None = None,
attention_mask: torch.Tensor | None = None,
seq_idx: torch.IntTensor | None = None,
cu_seq_lens_q: torch.LongTensor | None = None,
cu_seq_lens_k: torch.LongTensor | None = None,
Contributor

Ok, not on you, but it is annoying because they have yet another standard on the FLA side 😭 They use cu_seqlens, see e.g. https://github.com/fla-org/flash-linear-attention/blob/2e90142c8075af0a0efe4979c22136194a307140/fla/ops/gated_delta_rule/fused_recurrent.py#L298

We had a similar thing in Bamba where we added these for typing under our kwargs, see

class BambaFlashAttentionKwargs(TypedDict, total=False):
just now adjusted for this special case in this weird FLA/conv mixup

We do need to change the data collator though, to support returning only cu_seqlens instead of the q/k versions.
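
For reference, a minimal sketch of what such a kwargs TypedDict could look like, following the Bamba pattern and adding the FLA-style cu_seqlens discussed above; the exact field set is illustrative, not the merged API.

from typing import TypedDict

import torch


class Qwen3_5FlashAttentionKwargs(TypedDict, total=False):
    cu_seq_lens_q: torch.LongTensor   # cumulative query lengths (FlashAttention naming)
    cu_seq_lens_k: torch.LongTensor   # cumulative key lengths (FlashAttention naming)
    max_length_q: int                 # longest query segment in the packed batch
    max_length_k: int                 # longest key segment in the packed batch
    seq_idx: torch.IntTensor          # per-token sequence index for causal-conv1d
    cu_seqlens: torch.LongTensor      # FLA-style cumulative lengths (== cu_seq_lens_q)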

Author

I have made the change in the data collator to pass cu_seqlens along with the q/k versions, keeping it compatible with other models like Bamba which use those.

Comment on lines +248 to +255
has_fast_path = self.causal_conv1d_fn is not None and self.chunk_gated_delta_rule.__module__.startswith(
"fla."
)
if not has_fast_path and any(x is not None for x in (seq_idx, cu_seq_lens_q, cu_seq_lens_k)):
raise NotImplementedError(
"Padding-free training kwargs require fast path support. Please install `flash-linear-attention` "
"and `causal-conv1d`."
)
Contributor

Suggested change
has_fast_path = self.causal_conv1d_fn is not None and self.chunk_gated_delta_rule.__module__.startswith(
"fla."
)
if not has_fast_path and any(x is not None for x in (seq_idx, cu_seq_lens_q, cu_seq_lens_k)):
raise NotImplementedError(
"Padding-free training kwargs require fast path support. Please install `flash-linear-attention` "
"and `causal-conv1d`."
)

We shouldn't have these checks; it should stay a power feature for people who know what they're doing.

Author

Removed this - although I see similar checks in https://github.com/huggingface/transformers/blob/main/src/transformers/models/bamba/modular_bamba.py#L679. Should we have it here to keep behavior consistent?

Comment thread src/transformers/models/qwen3_5/modular_qwen3_5.py Outdated
Comment on lines +378 to +380
seq_idx=kwargs.get("seq_idx"),
cu_seq_lens_q=kwargs.get("cu_seq_lens_q"),
cu_seq_lens_k=kwargs.get("cu_seq_lens_k"),
Contributor

Would pass kwargs directly and adjust the signature instead

Author

Updated

@@ -57,6 +59,7 @@ class Qwen3_5TextModelTester(CausalLMModelTester):

def __init__(self, parent):
super().__init__(parent=parent)
self.hidden_act = "silu"
Contributor

Actually needed?

Author

because otherwise they inherit gelu from CausalLMModelTester, which is invalid for the FLA fused path

Contributor

Ah gotcha, makes sense

Comment on lines +202 to +203
if not is_flash_linear_attention_available() or not is_causal_conv1d_available():
self.skipTest("Qwen3.5 padding-free fast path requires `flash-linear-attention` and `causal-conv1d`.")
Contributor

Let's make require decorators out of these instead

Author

Added decorators.

def test_padding_free_matches_padded_fast_path_regression(self):
if not is_flash_linear_attention_available() or not is_causal_conv1d_available():
self.skipTest("Qwen3.5 padding-free fast path requires `flash-linear-attention` and `causal-conv1d`.")
torch.manual_seed(0)
Contributor

Suggested change
torch.manual_seed(0)

Comment on lines +206 to +224
config = Qwen3_5TextConfig(
vocab_size=100,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
head_dim=16,
max_position_embeddings=64,
hidden_act="silu",
layer_types=["full_attention", "linear_attention"],
linear_conv_kernel_dim=2,
linear_key_head_dim=16,
linear_value_head_dim=16,
linear_num_key_heads=2,
linear_num_value_heads=4,
pad_token_id=0,
)
model = Qwen3_5ForCausalLM(config).to(torch_device).eval()
Contributor

Can we use the prepare configs function instead and get the text config through that?

Author

Updated

Comment on lines +225 to +228
linear_attn = model.model.layers[1].linear_attn
self.assertIsNotNone(linear_attn.causal_conv1d_fn)
self.assertTrue(linear_attn.chunk_gated_delta_rule.__module__.startswith("fla."))
self.assertTrue(linear_attn.recurrent_gated_delta_rule.__module__.startswith("fla."))
Contributor

Suggested change
linear_attn = model.model.layers[1].linear_attn
self.assertIsNotNone(linear_attn.causal_conv1d_fn)
self.assertTrue(linear_attn.chunk_gated_delta_rule.__module__.startswith("fla."))
self.assertTrue(linear_attn.recurrent_gated_delta_rule.__module__.startswith("fla."))

Not needed

self.assertTrue(linear_attn.recurrent_gated_delta_rule.__module__.startswith("fla."))

padded_input_ids = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], device=torch_device)
attention_mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]], dtype=torch.long, device=torch_device)
Contributor

I think we can sparsify this a bit more (i.e. more padding)

Author

Updated

@vasqu
Contributor

vasqu commented Mar 31, 2026

You can ping me when it's ready for a review

@sdharani91
Author

@vasqu - I have tried to address all your comments - please take a look.

@sdharani91 sdharani91 marked this pull request as ready for review March 31, 2026 17:01
Contributor

@vasqu vasqu left a comment

Some comments 🤗 Since Qwen 3.5 essentially uses the same patterns as 3.5 MoE and Next, does it make sense to adapt this for those as well? Ofc, we should first make sure to have a good version on 3.5

Comment thread src/transformers/data/data_collator.py Outdated
- uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
- no padding will be added, returns `input_ids`, `labels` and `position_ids` by default
- optionally returns the kwargs contained in FlashAttentionKwargs
- optionally returns the kwargs contained in FlashAttentionKwargs, plus `cu_seqlens` for FLA-style kernels
Contributor

We should make this a new option via kwargs that defaults to false

Author

Why do we need a new option for returning cu_seqlens? Adding more options for each specific use case could get messy. Can't we fold this into the flash attention kwargs?

Contributor

Well, if we allow everything under the same FA kwargs, they will only get worse. Believe me, I've now seen multiple messy standards. In the future I want to deprecate this so that we map to our standard / the underlying standard, and splitting here makes it easier to remove later.

Just look into modeling flash attn utils and torch made it worse again 😅

@@ -214,7 +220,10 @@ def forward(
hidden_states: torch.Tensor,
cache_params: Qwen3_5DynamicCache | None = None,
attention_mask: torch.Tensor | None = None,
**kwargs: Unpack[Qwen3_5FlashAttentionKwargs],
Contributor

This should properly be in the signature here, not within the kwargs; the parent should have the new kwargs type, if that makes sense.

Author

The parent (decoder block) already has TransformersKwargs. Then do we still need Qwen3_5FlashAttentionKwargs, or could those be added to TransformersKwargs itself and passed into the GatedDeltaNet forward as proper args seq_idx and cu_seqlens?

Contributor

Hmm, honestly good question - maybe we could refactor the bamba kwargs as well into general TransformersKwargs. Wdyt @ArthurZucker

Collaborator

yeah sounds good

Comment on lines +225 to +226
seq_idx = kwargs.get("seq_idx")
cu_seqlens = kwargs.get("cu_seqlens")
Contributor

Then we don't need this; just pass it directly.

Comment on lines +296 to +298
chunk_kwargs = {}
if getattr(self.chunk_gated_delta_rule, "__module__", "").startswith("fla."):
chunk_kwargs["cu_seqlens"] = cu_seqlens
Contributor

We should not have to check the underlying function, but directly pass it either way (we might need to change the pure torch functions to ignore it via kwargs).

Author

My understanding is that the torch fallback torch_chunk_gated_delta_rule used here actually comes from Qwen3Next, not modular_qwen3_5.py. Should the durable fix be to make the Qwen3Next fallback accept and ignore cu_seqlens/extra kwargs, so Qwen3.5 can pass it unconditionally?

Contributor

Yes, we should fix the parent tbh to allow it
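
A minimal sketch of the "accept and ignore" fix discussed here, applied to the parent's pure-torch fallback; the signature is illustrative, not the actual Qwen3Next code.

def torch_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, **kwargs):
    # Padding-free metadata (cu_seqlens, seq_idx, ...) arrives via **kwargs so callers
    # never have to branch on which kernel is installed; the sequential torch
    # implementation simply ignores it.
    del kwargs
    ...  # existing pure-torch chunked recurrence, unchanged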

@@ -367,6 +381,7 @@ def forward(
hidden_states=hidden_states,
cache_params=past_key_values,
attention_mask=attention_mask,
**kwargs,
Contributor

This is where it would be typed differently then

Comment thread src/transformers/testing_utils.py Outdated
@@ -703,6 +704,19 @@ def require_all_flash_attn(test_case):
)(test_case)


def require_flash_linear_attention_and_causal_conv1d(test_case):
Contributor

I would not fuse them tbh, and just make two separate ones (if they don't exist yet). We can always stack decorators, which is easier to compose.

Author

Makes sense.
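
A sketch of what the two separate decorators could look like, following the existing require_* pattern in testing_utils.py; the import path and availability-check helpers are assumptions based on the test code quoted earlier.

import unittest

from transformers.utils import is_causal_conv1d_available, is_flash_linear_attention_available


def require_flash_linear_attention(test_case):
    """Skip the test unless the `flash-linear-attention` (fla) package is installed."""
    return unittest.skipUnless(
        is_flash_linear_attention_available(), "test requires flash-linear-attention"
    )(test_case)


def require_causal_conv1d(test_case):
    """Skip the test unless the `causal-conv1d` package is installed."""
    return unittest.skipUnless(
        is_causal_conv1d_available(), "test requires causal-conv1d"
    )(test_case)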


model = Qwen3_5ForCausalLM(config).to(torch_device).eval()

padded_input_ids = torch.tensor([[0, 0, 0, 1, 2, 3], [0, 0, 0, 0, 4, 5]], device=torch_device)
attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 1, 1]], dtype=torch.long, device=torch_device)
Contributor

Suggested change
attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 1, 1]], dtype=torch.long, device=torch_device)
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1], [0, 0, 0, 0, 1, 1]], dtype=torch.long, device=torch_device)

would like a more "extreme" case of imbalance

Author

Okay, will add that as another case in this test.

Comment thread tests/models/qwen3_5/test_modeling_qwen3_5.py Outdated
logits_padded = res_padded.logits[attention_mask.bool()]
logits_padfree = res_padfree.logits[0]

torch.testing.assert_close(logits_padded, logits_padfree, atol=1e-5, rtol=1e-5)
Contributor

I don't mind if we have to raise the tolerance a bit then
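
For reference, a rough sketch of how the padding-free inputs in this comparison can be derived from the padded batch quoted above; the tensor values mirror the test, the construction itself is illustrative.

import torch

padded_input_ids = torch.tensor([[0, 0, 0, 1, 2, 3], [0, 0, 0, 0, 4, 5]])
attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 1, 1]])

lengths = attention_mask.sum(dim=1).tolist()                               # [3, 2]
packed_input_ids = padded_input_ids[attention_mask.bool()].unsqueeze(0)   # shape (1, 5)
position_ids = torch.cat([torch.arange(n) for n in lengths]).unsqueeze(0)
seq_idx = torch.cat(
    [torch.full((n,), i, dtype=torch.int32) for i, n in enumerate(lengths)]
).unsqueeze(0)
cu_seqlens = torch.tensor([0] + lengths).cumsum(0).to(torch.int32)        # [0, 3, 5]

# res_padfree = model(packed_input_ids, position_ids=position_ids, seq_idx=seq_idx,
#                     cu_seq_lens_q=cu_seqlens, cu_seq_lens_k=cu_seqlens)
# The padded and padding-free logits are then compared token-by-token, possibly with
# a slightly relaxed tolerance as noted above.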

@vasqu
Contributor

vasqu commented Apr 1, 2026

Also looks like a few things changed on main, let's sync

@sdharani91
Author

@vasqu - I have follow-up questions on a couple of comments - please help clarify.
Will address and rebase.

@vasqu
Contributor

vasqu commented Apr 2, 2026

Answered. For safety re kwargs, it would make more sense to follow Bamba for now - we can refactor into TransformersKwargs later on.

@sdharani91
Author

@vasqu Regarding following Bamba - Bamba does not have TransformersKwargs - it only has BambaFlashAttentionKwargs: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bamba/modular_bamba.py#L60C7-L60C32.
I am skeptical about replacing TransformersKwargs with Qwen3_5FlashAttentionKwargs, since TransformersKwargs is broader - we might impact flows which need other args not in Qwen3_5FlashAttentionKwargs.

@vasqu
Contributor

vasqu commented Apr 2, 2026

On second thought, do we need the additional kwargs on our side? Meaning we can just pass cu_seq_lens_q directly, no (i.e. no changes to the collator)? Just add a comment there that this is intentional.

The only additional thing would be seq_idx, which could be added to TransformersKwargs.

Iirc, it is "just" typing, so it would likely not affect actually passing other kwargs. If we were to create new kwargs, they would need to inherit from TransformersKwargs and extend them, iiuc.

@sdharani91
Author

sdharani91 commented Apr 3, 2026

Yeah I had something similar in my earlier commit - not creating new kwargs but I was passing both q and k versions: 33979c7

Let me summarize so I understand correctly:

  1. Add seq_idx to TransformersKwargs.
  2. The parent (decoder block), which already has TransformersKwargs, passes only cu_seqlens (== cu_seq_lens_q) and seq_idx to the GatedDeltaNet forward as proper args.

@vasqu Please confirm if this is okay.
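
For reference, a minimal sketch of step 2; the class and attribute names are stand-ins based on this thread, not the merged code.

import torch
from torch import nn


class DecoderLayerSketch(nn.Module):
    # Hypothetical stand-in for Qwen3_5DecoderLayer, only to illustrate the flow.
    def __init__(self, linear_attn: nn.Module):
        super().__init__()
        self.linear_attn = linear_attn

    def forward(self, hidden_states: torch.Tensor, attention_mask=None, **kwargs):
        # Fetch packed-boundary metadata from the (TransformersKwargs-typed) kwargs and
        # hand it to the GatedDeltaNet block as explicit arguments; FLA kernels expect
        # a single cu_seqlens (== cu_seq_lens_q).
        return self.linear_attn(
            hidden_states,
            attention_mask=attention_mask,
            seq_idx=kwargs.get("seq_idx"),
            cu_seqlens=kwargs.get("cu_seq_lens_q"),
        )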

@vasqu
Contributor

vasqu commented Apr 9, 2026

@sdharani91 Sorry, was out for Easter and the PyTorch conference

Yes, that sounds good to me. We should add a note within the docstrings there re cu_seqlens on how we standardize this, just to be clear on it.

hamishivi added a commit to allenai/open-instruct that referenced this pull request Apr 15, 2026
Qwen3.5's GatedDeltaNet (linear attention) layers ignore sequence
boundaries in packed inputs: causal conv1d leaks across sequences
(seq_idx=None) and the recurrent state carries over. This causes
incorrect logprobs during training and inflated KL divergence.

Monkey-patch based on huggingface/transformers#45034:
- Pass seq_idx to causal_conv1d_fn for packing-aware convolution
- Pass cu_seqlens to FLA chunk_gated_delta_rule kernel
- Forward **kwargs through DecoderLayer to linear_attn
- Add cu_seqlens to padding_free_collator output

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@vasqu
Contributor

vasqu commented Apr 15, 2026

Gentle ping @sdharani91 if you are still interested in this

@sdharani91
Author

Yes @vasqu - I will raise a follow-up revision this week.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen3_5, qwen3_next

@sdharani91
Author

@vasqu - have addressed comments, please take a look.

Collaborator

@ArthurZucker ArthurZucker left a comment

let's type the kwargs we are adding as well, otherwise LGTM

Comment on lines +428 to +429
seq_idx: torch.IntTensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
Collaborator

should be in kwargs: Unpack[TransformersKwargs] no?

Author

In Qwen3_5DecoderLayer we fetch this from TransformersKwargs and pass it to GatedDeltaNet as proper arguments.
