
feat: support attn_bias for efficient SDPA #4131

Open
zewenli98 wants to merge 1 commit into main from polish_attention_api

Conversation

@zewenli98
Collaborator

Description

Support attn_bias for efficient SDPA

Fixes #4129

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@zewenli98 zewenli98 requested review from chohk88 and narendasan March 17, 2026 07:27
@zewenli98 zewenli98 self-assigned this Mar 17, 2026
@meta-cla meta-cla bot added the cla signed label Mar 17, 2026
@github-actions github-actions bot added documentation Improvements or additions to documentation component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Mar 17, 2026
@github-actions github-actions bot requested a review from cehongwang March 17, 2026 07:27
@chohk88
Collaborator

chohk88 commented Mar 18, 2026

Hi Evan, I tested this PR on three HuggingFace models (Qwen2.5-0.5B, Qwen3-0.6B, Llama-3.2-1B) with the iattention backend. Here's what I found.

Background: How HuggingFace uses _scaled_dot_product_efficient_attention

HuggingFace always passes the causal mask as attn_bias with is_causal=False, rather than setting is_causal=True directly:

# HuggingFace's actual call pattern:
aten._scaled_dot_product_efficient_attention(
    Q, K, V,
    attn_bias=[[[ 0,   -inf, -inf, -inf],
                [ 0,    0,   -inf, -inf],
                [ 0,    0,    0,   -inf],
                [ 0,    0,    0,    0  ]]],
    is_causal=False
)

This additive causal mask (0 for attend, -inf for block) is semantically equivalent to is_causal=True, but it forces the converter to handle a large mask tensor.
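The claimed equivalence is easy to check numerically. The following sketch (illustrative shapes, not HF's actual tensors) compares `F.scaled_dot_product_attention` with an additive causal mask against `is_causal=True`:

```python
# Verify that an additive causal mask (0 = attend, -inf = block) is
# numerically equivalent to is_causal=True. Shapes are illustrative.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 1, 2, 4, 8
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Build the additive causal mask: -inf strictly above the diagonal.
mask = torch.zeros(S, S)
mask.masked_fill_(torch.ones(S, S, dtype=torch.bool).triu(1), float("-inf"))

out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out_mask, out_causal, atol=1e-5))  # True
```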

Issue 1: Dynamic head dimension causes add_attention failure

When running with HuggingFace models, torch.export marks the query tensor's head dimension as dynamic (-1), even though it should be static. This is because query goes through .view().transpose() during tracing, while key/value go through an additional repeat_kv (.expand().reshape()) that preserves the static head count.
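For reference, the `repeat_kv` expansion in question looks roughly like the following sketch (modeled on HF transformers' modeling code); the `expand().reshape()` path is what keeps key/value's head count static:

```python
# Sketch of HF transformers' repeat_kv (GQA expansion). The
# expand().reshape() path preserves a static head count under torch.export,
# unlike query's .view().transpose() path.
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

# Example: 2 KV heads expanded 7x -> 14 heads, matching key's static head dim.
kv = torch.randn(1, 2, 5, 64)
expanded = repeat_kv(kv, 7)
print(expanded.shape)  # torch.Size([1, 14, 5, 64])
```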

Observed shapes at converter:
  query:        (1, -1, -1, 64)    # head dim is dynamic
  key:          (1, 14, -1, 64)    # head dim is static (14)
  scaled_query: (1, -1, -1, 64)    # still dynamic after scaling

TRT's add_attention requires a static head dimension for validation, so it fails with:

Error Code 3: API Usage Error ((Unnamed Layer*) [AttentionInput]:
query number of heads must be divisible by key/value number of heads)

I worked around this by reshaping scaled_query to match key's shape (they should be identical after HF's GQA expansion via repeat_kv):

# If query's head dim is dynamic but key's is static, borrow key's runtime
# shape (they should match after HF's GQA expansion via repeat_kv):
if (
    isinstance(scaled_query.shape[1], int)
    and scaled_query.shape[1] < 0      # head dim marked dynamic by torch.export
    and isinstance(key.shape[1], int)
    and key.shape[1] > 0               # key's head dim is static
):
    # Reshape scaled_query to key's runtime shape via an IShuffleLayer.
    shape_layer = ctx.net.add_shape(key)
    shape_layer.name = name + "_key_shape"
    shuffle = ctx.net.add_shuffle(scaled_query)
    shuffle.set_input(1, shape_layer.get_output(0))
    shuffle.name = name + "_fix_head_dim"
    scaled_query = shuffle.get_output(0)

Issue 2: IAttention with mask is slower than the previous manual decomposition

With the dynamic head dim fix applied, the PR's approach (IAttention + mask) does run, but the large causal mask tensor ([1, 14, 2048, 2048] per attention layer) introduces overhead that offsets the IAttention kernel's benefit. Benchmark on A100 80GB, FP16, ISL=2048, OSL=128, Batch=1:

  • PyTorch: 4743 ms
  • iattention (main branch, manual matmul decomposition): 5421 ms
  • iattention (this PR, IAttention + mask): 5633 ms ← slower than main

Observation: Discarding the mask and using is_causal=True gives a dramatic improvement

Since HuggingFace's attn_bias is always a causal mask, discarding it and setting is_causal=True is semantically equivalent. I tested this by adding the following before add_attention:

if mask_tensor is not None:
    mask_tensor = None
    use_causal = True

Full benchmark results (Median Latency in ms):

| Model | PyTorch | iattention (main) | iattention (PR + mask) | iattention (is_causal) | sdpa no cache | sdpa static_v1 | plugin |
|---|---|---|---|---|---|---|---|
| Qwen2.5-0.5B | 4743 | 5421 (0.88x) | 5633 (0.84x) | 2303 (2.06x) | 3271 (1.5x) | 1238 (3.8x) | 421 (11.3x) |
| Qwen3-0.6B | 6891 | 6792 (1.01x) | — | 3097 (2.22x) | 4031 (1.7x) | 1708 (4.0x) | 569 (12.1x) |
| Llama-3.2-1B | 7064 | 8283 (0.85x) | — | 4426 (1.60x) | 5466 (1.3x) | 1379 (5.1x) | 465 (15.2x) |

“—” = not measured.

Parentheses show speedup vs PyTorch.

The is_causal=True approach makes iattention 1.6–2.2x faster than PyTorch and even faster than sdpa no cache (which goes through a separate graph-level lowering pass to achieve the same mask-to-causal conversion).

Trade-off discussion

However, unconditionally discarding the mask means the converter can no longer support arbitrary (non-causal) masks. The existing tests in test_attention_aten.py pass random float tensors as attn_bias, which would produce incorrect results if the mask were discarded.

If the iattention backend is intended primarily for HuggingFace LLM inference (where attn_bias is always a causal mask), then discarding the mask is the right optimization. But if the converter needs to remain general-purpose for arbitrary masks, we need a different approach — perhaps detecting the causal mask pattern at the graph level, or providing a flag (e.g., decompose_attention) to let users control this behavior.
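One way to keep the converter general while still capturing the common case would be a helper that verifies attn_bias really is a causal mask before discarding it. The helper below is hypothetical (not part of this PR) and only handles the square, no-KV-cache-offset case:

```python
# Hypothetical helper (not part of this PR): returns True only if attn_bias
# is exactly a square additive causal mask (0 on/below the diagonal, -inf
# strictly above), in which case it can be safely replaced by is_causal=True.
import torch

def is_causal_mask(attn_bias: torch.Tensor) -> bool:
    if attn_bias.shape[-1] != attn_bias.shape[-2]:
        return False  # only handle the square (no KV-cache offset) case
    s = attn_bias.shape[-1]
    # A causal mask repeats the same 2-D pattern over all batch/head dims.
    m = attn_bias.reshape(-1, s, s)
    causal = torch.zeros(s, s, dtype=attn_bias.dtype)
    causal.masked_fill_(torch.ones(s, s, dtype=torch.bool).triu(1), float("-inf"))
    return bool((m == causal).all())

torch.manual_seed(0)
# HF-style causal bias broadcast over 14 heads of a length-4 sequence.
bias = torch.zeros(4, 4).masked_fill(
    torch.ones(4, 4, dtype=torch.bool).triu(1), float("-inf")
)
print(is_causal_mask(bias.expand(1, 14, 4, 4)))  # True
print(is_causal_mask(torch.randn(1, 14, 4, 4)))  # False
```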

What are your thoughts on how to proceed?

@chohk88
Copy link
Collaborator

chohk88 commented Mar 18, 2026

After further investigation, I think adding an iattention-specific lowering pass would be a cleaner solution than unconditionally discarding the mask inside the converter, as in the snippet above:

if mask_tensor is not None:
    mask_tensor = None
    use_causal = True

The sdpa backend avoids the mask overhead by using a lowering pass (enable_sdpa_converter()) that sets attn_bias=None and is_causal=True at the graph level before conversion. We could take a similar approach for the iattention backend, but with one difference: the lowering pass should modify the args of _efficient_attention without changing the target op, so that the built-in converter with add_attention() (IAttention kernel) is still used.

# sdpa lowering pass (existing):
_efficient_attention(Q, K, V, attn_bias=<mask>, is_causal=False)
  → F.sdpa(Q, K, V, attn_mask=None, is_causal=True)     ← replaces the op
  → sdpa_converter.py (matmul decomposition, no IAttention)

# iattention lowering pass (proposed):
_efficient_attention(Q, K, V, attn_bias=<mask>, is_causal=False)
  → _efficient_attention(Q, K, V, attn_bias=None, is_causal=True)  ← same op, args only
  → attention.py built-in converter (add_attention → IAttention kernel)
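A minimal sketch of such a pass, assuming positional args in the aten schema order (query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal, scale); the index positions should be verified against the target torch version:

```python
# Sketch of the proposed iattention lowering pass: rewrite the args of
# _scaled_dot_product_efficient_attention nodes in place (attn_bias -> None,
# is_causal -> True) WITHOUT changing the target op, so the built-in
# IAttention converter still handles the node.
import torch

def lower_efficient_attention_to_causal(
    gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
    target = torch.ops.aten._scaled_dot_product_efficient_attention.default
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == target:
            args = list(node.args)
            # args[3] is attn_bias, args[6] is is_causal in the aten schema.
            if len(args) > 3 and args[3] is not None:
                args[3] = None  # drop the (assumed causal) attn_bias
                if len(args) > 6:
                    args[6] = True
                else:
                    node.kwargs = {**node.kwargs, "is_causal": True}
                node.args = tuple(args)
    gm.graph.lint()
    gm.recompile()
    return gm
```

In a real pass this would be gated on the mask actually being causal (e.g., by pattern-matching how the graph builds attn_bias), rather than applied unconditionally.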

This keeps the converter general-purpose (arbitrary masks still work when no lowering pass is registered, e.g., in unit tests), while getting the optimal IAttention performance for LLM inference. It also feels cleaner than unconditionally discarding the mask inside the converter itself.

Note: the dynamic head dimension fix from Issue 1 is still needed in the converter regardless, since torch.export can mark Q's head dim as dynamic even when attn_bias=None.

What do you think?


Development

Successfully merging this pull request may close these issues.

[Converter] IAttention converter bypasses TRT native IAttention layer due to HuggingFace causal mask (attn_bias)
