
feat: Add TensorRT Edge-LLM AttentionPlugin backend support#4108

Open
chohk88 wants to merge 2 commits into main from attn-plugin-workflow

Conversation


@chohk88 chohk88 commented Mar 3, 2026

Add plugin backend as an alternative to the default SDPA lowering for LLM inference, providing ~1.7x-3.3x speedup over SDPA and ~8x-11x over PyTorch eager execution.

Supported Models: Llama 3.x (3.1/3.2), Qwen 2.5, Qwen 3

Changes:

  • examples/dynamo/attention_plugin_example.py: Standalone plugin demo with correctness validation against PyTorch SDPA
  • examples/dynamo/end_to_end_llm_generation_example.py: End-to-end LLM generation example with plugin integration and benchmarks
  • tools/llm/plugin_utils.py: Model-agnostic plugin utilities including op registration (tensorrt_edge_llm::xqa_attn), TensorRT converter, PluginAttention module, LLMPluginWrapper, compilation and generation
  • tools/llm/run_llm.py: Add --backend plugin/sdpa selection with plugin workflow integration
  • tools/llm/README.md: Plugin backend documentation with build guide, usage examples, and performance summary

Plugin library built from TensorRT-Edge-LLM 0.4.0: https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime


Type of change


  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@chohk88 chohk88 requested review from narendasan and zewenli98 March 3, 2026 13:54
@chohk88 chohk88 self-assigned this Mar 3, 2026
@meta-cla meta-cla bot added the cla signed label Mar 3, 2026
@narendasan narendasan left a comment


Overall, I think it's close. @zewenli98 should take a pass, but we can merge it nearly as is. Next, I want to think about how we might create lowering passes that insert the placeholder ops programmatically. Evan is about to disable decomposition by default for SDPA, so we can dynamically insert a pass that keys on those ops.

```python
    trt_timings.append(elapsed_ms / 1000.0)
else:
    # SDPA backend (default)
    if args.cache == "static_v1":
```

We have a few threads here: backend, cache, and, with @zewenli98's PR, core Attention. Can we merge these settings so it's easy to understand when you will get TRT-Edge-LLM, when you get native IAttention, and when you get Static KV Cache?


@chohk88 I implemented the converters for some attention variants in #4104. Can you take a look how to integrate?


@narendasan @zewenli98 Thanks for the reviews!

Now --backend supports three options (sdpa, iattention, and plugin), and I've updated the README to document what each provides.

Additionally, I've added --backend iattention and it works correctly (outputs match PyTorch for all tested models). However, I found that the TRT IAttention layer is never actually used with HF models: HF always passes attn_bias (a causal-mask tensor) to _scaled_dot_product_efficient_attention, which causes the converter to take the manual matmul+softmax decomposition path instead of ctx.net.add_attention(). Filed #4129 to track this.
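The fallback behavior can be sketched as a simple dispatch decision. This is a hypothetical illustration only; the real converter in Torch-TensorRT has a different signature and more conditions:

```python
# Hypothetical sketch of the dispatch decision described above; not the real
# Torch-TensorRT converter. `attn_bias` stands for the causal-mask tensor that
# HuggingFace passes to _scaled_dot_product_efficient_attention.
def choose_attention_path(attn_bias) -> str:
    """Return which lowering path a converter like this would take."""
    if attn_bias is not None:
        # A materialized mask tensor cannot be routed through the native
        # IAttention layer here, so fall back to matmul + softmax + matmul.
        return "manual_decomposition"
    # No explicit bias: the native TensorRT IAttention layer can be used.
    return "iattention"

# HF models always supply a causal mask, so the fallback is always taken:
print(choose_attention_path(attn_bias=object()))  # manual_decomposition
print(choose_attention_path(attn_bias=None))      # iattention
```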

@narendasan narendasan left a comment


Do we have lowering passes to insert the tensorrt edge llm ops in place of pytorch ops?


Add plugin backend as an alternative to the default SDPA lowering for
LLM inference, providing ~1.5x-1.8x speedup over SDPA and ~8x-11x
over PyTorch eager execution.

Supported Models: Llama 3.x (3.1/3.2), Qwen 2.5, Qwen 3

Changes:
- examples/dynamo/attention_plugin_example.py: Standalone plugin demo
  with correctness validation against PyTorch SDPA
- examples/dynamo/end_to_end_llm_generation_example.py: End-to-end LLM
  generation example with plugin integration and benchmarks
- tools/llm/plugin_utils.py: Model-agnostic plugin utilities including
  op registration (tensorrt_edge_llm::xqa_attn), TensorRT converter,
  PluginAttention module, LLMPluginWrapper, compilation and generation
- tools/llm/run_llm.py: Add --backend plugin/sdpa selection with plugin
  workflow integration
- tools/llm/README.md: Plugin backend documentation with build guide,
  usage examples, and performance summary

Plugin library built from TensorRT-Edge-LLM 0.4.0:
https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime
@chohk88 chohk88 force-pushed the attn-plugin-workflow branch from 26aaeeb to d6f5873 on March 15, 2026 08:45
…review comments

- Add --backend iattention using TRT native IAttention layer (PR #4104)
- Move plugin converter to plugin_converter.py (per review)
- Fix compile_plugin_model: drop enabled_precisions with use_explicit_typing, use Debugger
- Update README: document 3 backends, add MAX_JOBS=8, simplify performance section

Note: iattention backend currently falls back to manual matmul+softmax
decomposition because HuggingFace passes attn_bias (causal mask tensor) to
_scaled_dot_product_efficient_attention, causing the converter to bypass
TRT's native IAttention layer. See #4129 for tracking.

chohk88 commented Mar 16, 2026

Do we have lowering passes to insert the tensorrt edge llm ops in place of pytorch ops?

Not yet. Currently the plugin backend does module-level replacement before torch.export: it iterates over the HF model's decoder layers (model.model.layers[i].self_attn) and swaps each attention module with a PluginAttention that calls the Edge-LLM custom op.

Implementing this as an FX lowering pass instead would require pattern-matching the full Q/K/V projection → reshape → RoPE → SDPA op sequence in the exported graph, which is non-trivial since these patterns differ across architectures (Qwen, Llama, Gemma, etc.).
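The module-level replacement described above can be sketched with stand-in classes. All names here (PluginAttention, the toy model layout) are simplified stand-ins for the real HF modules and the Edge-LLM wrapper:

```python
# Toy sketch of the module-level swap performed before torch.export.
# These classes are stand-ins, not the real HF or plugin implementations.
class StockAttention:
    """Stand-in for a HuggingFace self-attention module."""

class PluginAttention:
    """Stand-in for the wrapper that calls the Edge-LLM custom op."""
    def __init__(self, original):
        self.original = original  # retain the original module's weights/config

class DecoderLayer:
    def __init__(self):
        self.self_attn = StockAttention()

class ToyModel:
    def __init__(self, num_layers):
        self.layers = [DecoderLayer() for _ in range(num_layers)]

def swap_attention(model):
    # Iterate decoder layers and replace each attention module in place,
    # mirroring the model.model.layers[i].self_attn replacement.
    for layer in model.layers:
        layer.self_attn = PluginAttention(layer.self_attn)
    return model

model = swap_attention(ToyModel(num_layers=4))
print(all(isinstance(l.self_attn, PluginAttention) for l in model.layers))  # True
```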


chohk88 commented Mar 16, 2026

@narendasan @zewenli98 I ran correctness and benchmark tests across all backend × model combinations to verify the changes work as expected. All backends produce outputs identical to PyTorch, and performance numbers are in a reasonable range. Note that the specific performance numbers in the README have been removed since these are not from an official benchmarking setup — the README now describes approximate speedup ranges instead.

Correctness (PyTorch vs TensorRT output match)

| Model | sdpa (no cache) | sdpa (static_v1) | plugin | iattention |
|---|---|---|---|---|
| Qwen2.5-0.5B | ✓ | ✓ | ✓ | ✓ |
| Qwen3-0.6B | ✓ | ✓ | ✓ | ✓ |
| Llama-3.2-1B | ✓ | ✓ | ✓ | ✓ |

Benchmark (Median Latency ms, A100 80GB, FP16, ISL=2048, OSL=128, Batch=1)

| Model | PyTorch | sdpa (no cache) | sdpa (static_v1) | plugin | iattention |
|---|---|---|---|---|---|
| Qwen2.5-0.5B | 4751 | 3271 | 1238 | 421 | 5421 |
| Qwen3-0.6B | 6875 | 4031 | 1708 | 569 | 6792 |
| Llama-3.2-1B | 7053 | 5466 | 1379 | 465 | 8283 |
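As a rough sanity check, the plugin-vs-sdpa(static_v1) speedups implied by these medians can be computed directly; for these three models they land around 2.9x-3.0x, within the ~1.7x-3.3x range quoted in the PR description:

```python
# Speedup of the plugin backend over sdpa (static_v1), derived from the
# median latencies (ms) in the benchmark table above.
latency_ms = {
    "Qwen2.5-0.5B": {"sdpa_static_v1": 1238, "plugin": 421},
    "Qwen3-0.6B":   {"sdpa_static_v1": 1708, "plugin": 569},
    "Llama-3.2-1B": {"sdpa_static_v1": 1379, "plugin": 465},
}
for model, t in latency_ms.items():
    speedup = t["sdpa_static_v1"] / t["plugin"]
    print(f"{model}: {speedup:.1f}x")
```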

Note on iattention: currently on par with or slightly slower than PyTorch for autoregressive generation because the TRT IAttention layer is not actually used: HF passes attn_bias (causal mask) to efficient_attention, causing the converter to take the manual decomposition path. Tracked in #4129.


Development

Successfully merging this pull request may close these issues.

[Converter] IAttention converter bypasses TRT native IAttention layer due to HuggingFace causal mask (attn_bias)
