feat: Add TensorRT Edge-LLM AttentionPlugin backend support #4108
Conversation
narendasan left a comment
Overall, I think it's close. @zewenli98 should take a pass, but we can merge nearly as is. I also want to think about how we might next create lowering passes that insert the placeholder ops programmatically. Evan is about to disable decomposition by default for SDPA, so we can basically dynamically insert a pass that keys on those ops.
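To make the idea above concrete, here is a minimal, hypothetical sketch of a lowering pass that keys on SDPA ops and rewrites them to a placeholder plugin op. The `Node`/`Graph` classes are simplified stand-ins for `torch.fx` structures, and the op-name strings are illustrative assumptions, not the project's actual registrations.

```python
# Hypothetical sketch: scan a graph's call nodes and swap SDPA ops for a
# placeholder plugin op. Node/Graph are simplified stand-ins for torch.fx
# structures; op names here are illustrative only.
from dataclasses import dataclass, field

@dataclass
class Node:
    op: str            # e.g. "call_function"
    target: str        # fully qualified op name
    args: tuple = ()

@dataclass
class Graph:
    nodes: list = field(default_factory=list)

# The set of ops the pass keys on (assumed names).
SDPA_OPS = {
    "aten.scaled_dot_product_attention",
    "aten._scaled_dot_product_efficient_attention",
}

def insert_placeholder_ops(graph: Graph,
                           placeholder: str = "tensorrt_edge_llm.xqa_attn") -> int:
    """Rewrite every SDPA call node to target the placeholder op; return count."""
    replaced = 0
    for node in graph.nodes:
        if node.op == "call_function" and node.target in SDPA_OPS:
            node.target = placeholder
            replaced += 1
    return replaced

g = Graph(nodes=[
    Node("call_function", "aten.linear"),
    Node("call_function", "aten.scaled_dot_product_attention"),
])
print(insert_placeholder_ops(g))  # 1
```

A real implementation would operate on a `torch.fx.GraphModule` and preserve node arguments and metadata, but the matching-and-rewrite shape would be the same.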
```python
    trt_timings.append(elapsed_ms / 1000.0)
else:
    # SDPA backend (default)
    if args.cache == "static_v1":
```
We have a few threads here: backend, cache, and, with @zewenli98's PR, core Attention. Can we merge these settings so it's easy to understand when you will get TRT-Edge-LLM, when you get native IAttention, and when you get Static KV Cache?
@narendasan @zewenli98 Thanks for the reviews!
Now `--backend` supports three options (`sdpa`, `iattention`, and `plugin`), and the README is updated to explain what each provides:

- `sdpa`: SDPA lowering pass + optional `--cache static_v1/v2`
- `iattention`: TRT native IAttention converters from #4104 ("converter: add sdpa, flash-sdpa, efficient-sdpa, and cudnn-sdpa converters"); no KV cache yet
- `plugin`: Edge-LLM attention plugin with built-in KV cache
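A rough sketch of how that three-way switch might be wired. The flag names (`--backend`, `--cache`) and choices follow the PR description, but the defaults and the cross-flag validation are assumptions, not the PR's exact code.

```python
# Sketch of the --backend / --cache CLI wiring; defaults and validation
# below are assumptions, not the PR's actual run_llm.py code.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--backend",
                    choices=["sdpa", "iattention", "plugin"],
                    default="sdpa")
parser.add_argument("--cache",
                    choices=["static_v1", "static_v2"],
                    default=None)

args = parser.parse_args(["--backend", "sdpa", "--cache", "static_v1"])

# --cache only applies to the sdpa backend: the plugin carries its own
# KV cache, and iattention has no KV cache support yet.
if args.cache and args.backend != "sdpa":
    raise SystemExit("--cache is only valid with --backend sdpa")

print(args.backend, args.cache)
```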
Additionally, I've added --backend iattention and it works correctly (outputs match PyTorch for all tested models). However, I found that the TRT IAttention layer is never actually used with HF models — HF always passes attn_bias (causal mask tensor) to _scaled_dot_product_efficient_attention, which causes the converter to take the manual matmul+softmax decomposition path instead of ctx.net.add_attention(). Filed #4129 to track this.
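The fallback described above can be sketched as a simple dispatch: when a dense `attn_bias` tensor is present, the converter cannot map the call onto TRT's native IAttention layer and takes the manual decomposition path. The function and return values below are hypothetical simplifications of the real converter logic.

```python
# Illustrative sketch of the converter's dispatch (hypothetical names):
# HF always supplies a causal-mask tensor as attn_bias, so the
# decomposition branch is always taken, per the issue filed as #4129.
def convert_sdpa(query, key, value, attn_bias=None, is_causal=False):
    if attn_bias is None:
        # No explicit mask tensor: eligible for the native IAttention
        # layer (ctx.net.add_attention() in the real converter).
        return "iattention"
    # A dense mask tensor forces the manual matmul+softmax path.
    return "matmul_softmax_decomposition"

print(convert_sdpa("q", "k", "v", attn_bias="causal_mask"))
```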
narendasan left a comment
Do we have lowering passes to insert the tensorrt edge llm ops in place of pytorch ops?
```python
    trt_timings.append(elapsed_ms / 1000.0)
else:
    # SDPA backend (default)
    if args.cache == "static_v1":
```
Add plugin backend as an alternative to the default SDPA lowering for LLM inference, providing ~1.5x-1.8x speedup over SDPA and ~8x-11x over PyTorch eager execution.

Supported Models: Llama 3.x (3.1/3.2), Qwen 2.5, Qwen 3

Changes:

- `examples/dynamo/attention_plugin_example.py`: Standalone plugin demo with correctness validation against PyTorch SDPA
- `examples/dynamo/end_to_end_llm_generation_example.py`: End-to-end LLM generation example with plugin integration and benchmarks
- `tools/llm/plugin_utils.py`: Model-agnostic plugin utilities including op registration (`tensorrt_edge_llm::xqa_attn`), TensorRT converter, PluginAttention module, LLMPluginWrapper, compilation and generation
- `tools/llm/run_llm.py`: Add `--backend plugin/sdpa` selection with plugin workflow integration
- `tools/llm/README.md`: Plugin backend documentation with build guide, usage examples, and performance summary

Plugin library built from TensorRT-Edge-LLM 0.4.0: https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime
Force-pushed from 26aaeeb to d6f5873
…review comments

- Add `--backend iattention` using TRT native IAttention layer (PR #4104)
- Move plugin converter to `plugin_converter.py` (per review)
- Fix `compile_plugin_model`: drop `enabled_precisions` with `use_explicit_typing`, use Debugger
- Update README: document 3 backends, add `MAX_JOBS=8`, simplify performance section

Note: the `iattention` backend currently falls back to the manual matmul+softmax decomposition because HuggingFace passes `attn_bias` (a causal-mask tensor) to `_scaled_dot_product_efficient_attention`, causing the converter to bypass TRT's native IAttention layer. See #4129 for tracking.
Not yet. Currently the plugin backend does module-level replacement before compilation. Implementing this as an FX lowering pass instead would require pattern-matching the full attention subgraph in the exported graph.
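The module-level replacement mentioned here can be sketched as a recursive walk that swaps attention submodules for plugin-backed ones before export. `Module`, `Attention`, and `PluginAttention` below are toy stand-ins for `torch.nn.Module` and the helpers in `tools/llm/plugin_utils.py`; the real code operates on the HF model's module tree.

```python
# Toy sketch of pre-export module-level replacement; class names are
# stand-ins for torch.nn.Module and the PR's PluginAttention helper.
class Module:
    def __init__(self, **children):
        self._children = dict(children)
    def named_children(self):
        return self._children.items()
    def set_child(self, name, mod):
        self._children[name] = mod

class Attention(Module):
    pass

class PluginAttention(Module):
    """Wraps an attention module with the plugin-backed implementation."""
    def __init__(self, original):
        super().__init__()
        self.original = original

def replace_attention(module):
    """Recursively swap every Attention child for PluginAttention; return count."""
    count = 0
    for name, child in list(module.named_children()):
        if isinstance(child, Attention):
            module.set_child(name, PluginAttention(child))
            count += 1
        else:
            count += replace_attention(child)
    return count

model = Module(layer0=Module(attn=Attention(), mlp=Module()),
               layer1=Module(attn=Attention(), mlp=Module()))
print(replace_attention(model))  # 2
```

An FX lowering pass would instead rewrite the traced graph, which is why it needs to match the whole attention computation rather than a named submodule.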
@narendasan @zewenli98 I ran correctness and benchmark tests across all backend × model combinations to verify the changes work as expected. All backends produce outputs identical to PyTorch, and performance numbers are in a reasonable range. Note that the specific performance numbers in the README have been removed since they are not from an official benchmarking setup; the README now describes approximate speedup ranges instead.

Correctness (PyTorch vs TensorRT output match)
Benchmark (Median Latency ms, A100 80GB, FP16, ISL=2048, OSL=128, Batch=1)
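For reference, a median-latency number like the one in the benchmark header above could be collected roughly as follows. `time.perf_counter` and `statistics.median` are real stdlib APIs; the `generate` stub is a placeholder assumption for the actual compiled-model decode call.

```python
# Sketch of median-latency measurement; generate() is a placeholder for
# the real compiled-LLM call being benchmarked.
import time
import statistics

def generate(_prompt):
    time.sleep(0.001)  # stand-in for a real decode step

def median_latency_ms(fn, arg, warmup=2, iters=5):
    """Warm up, then time repeated calls and return the median in ms."""
    for _ in range(warmup):
        fn(arg)
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(arg)
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings)

print(f"{median_latency_ms(generate, 'hello'):.2f} ms")
```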
Add plugin backend as an alternative to the default SDPA lowering for LLM inference, providing ~1.7x-3.3x speedup over SDPA and ~8x-11x over PyTorch eager execution.
Supported Models: Llama 3.x (3.1/3.2), Qwen 2.5, Qwen 3
Changes:

- `examples/dynamo/attention_plugin_example.py`: Standalone plugin demo with correctness validation against PyTorch SDPA
- `examples/dynamo/end_to_end_llm_generation_example.py`: End-to-end LLM generation example with plugin integration and benchmarks
- `tools/llm/plugin_utils.py`: Model-agnostic plugin utilities including op registration (`tensorrt_edge_llm::xqa_attn`), TensorRT converter, PluginAttention module, LLMPluginWrapper, compilation and generation
- `tools/llm/run_llm.py`: Add `--backend plugin/sdpa` selection with plugin workflow integration
- `tools/llm/README.md`: Plugin backend documentation with build guide, usage examples, and performance summary
Plugin library built from TensorRT-Edge-LLM 0.4.0: https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime