Skip to content

Using the Attention Python component to build Multi Head Attention (MHA), the inference result is incorrect #4732

@zz-jin

Description

@zz-jin

Description

Environment

TensorRT Version: TensorRT-10.16.0.72

NVIDIA GPU:NVIDIA GeForce RTX 3090

NVIDIA Driver Version:570.133.07

CUDA Version:cuda_12.8.r12.8/compiler.35583870_0

CUDNN Version:cuDNN version: 90701

Operating System:
NAME="Ubuntu"
VERSION="20.04.3 LTS (Focal Fossa)"
ID=ubuntu
ID_LIKE=debian
PRETTY_NAME="Ubuntu 20.04.3 LTS"
VERSION_ID="20.04"
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
VERSION_CODENAME=focal
UBUNTU_CODENAME=focal

Python Version (if applicable):Python 3.10.19

Tensorflow Version (if applicable):

PyTorch Version (if applicable):
torch 2.7.1+cu128
torchaudio 2.7.1+cu128
torchprofile 0.0.4
torchvision 0.22.1+cu128

Baremetal or Container (if so, version):

Relevant Files

Model link:

Steps To Reproduce

Commands or scripts:

The runnable reproduction script follows:

#!/usr/bin/env python3
"""
TensorRT Attention Operator 完整测试脚本

基于 NVIDIA TensorRT 文档 Attention 算子示例:
https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/operators/Attention.html#examples

支持两种模式:

  • TensorRT >= 10.0: 使用原生 add_attention API
  • TensorRT < 10.0: 使用 MatrixMultiply + Softmax 手动实现等效逻辑
    """

import sys
import numpy as np
import torch # 新增:用于加载 .pt 文件

try:
import tensorrt as trt
except ImportError:
print("错误: 未安装 tensorrt,请先安装 TensorRT Python 包。")
sys.exit(1)

try:
import pycuda.driver as cuda
import pycuda.autoinit
except ImportError:
print("错误: 未安装 pycuda,请运行: pip install pycuda")
sys.exit(1)

# Parsed (major, minor) of the installed TensorRT — selects the native vs manual attention path below.
TRT_VERSION = tuple(int(x) for x in trt.version.split(".")[:2])
# Shared logger for builder and runtime; WARNING keeps build output quiet.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_attention_network_v8(builder: trt.Builder, q_shape, kv_shape, mask_shape):
    """Build scaled dot-product attention manually for TensorRT < 10.0.

    Implements Attention(Q, K, V) = softmax(Q @ K^T / sqrt(h) + mask) @ V
    out of MatrixMultiply, ElementWise and Softmax layers.

    Input layout: [b, d, s, h] = (batch, num_heads, seq_len, head_dim).
    The mask is additive: 0 means "attend", -inf means "blocked".
    """
    network = builder.create_network(
        flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    _b, _d, _s_kv, head_dim = kv_shape

    query = network.add_input("query", dtype=trt.float16, shape=q_shape)
    key = network.add_input("key", dtype=trt.float16, shape=kv_shape)
    value = network.add_input("value", dtype=trt.float16, shape=kv_shape)
    mask = network.add_input("mask", dtype=trt.float16, shape=mask_shape)

    # Q @ K^T -> [b, d, s_q, s_kv]
    qk = network.add_matrix_multiply(
        query, trt.MatrixOperation.NONE,
        key, trt.MatrixOperation.TRANSPOSE,
    )
    qk.name = "BMM1_QK"

    # Multiply the raw scores by 1 / sqrt(head_dim).
    scale = network.add_constant(
        (1, 1, 1, 1), np.array([1.0 / np.sqrt(head_dim)], dtype=np.float16)
    )
    scaled = network.add_elementwise(
        qk.get_output(0), scale.get_output(0), trt.ElementWiseOperation.PROD
    )
    scaled.name = "Scale"

    # Apply the additive mask (0.0 = attend, -inf = block).
    masked = network.add_elementwise(
        scaled.get_output(0), mask, trt.ElementWiseOperation.SUM
    )
    masked.name = "MaskAdd"

    # Softmax over the last (s_kv) dimension.
    probs = network.add_softmax(masked.get_output(0))
    probs.axes = 1 << 3  # reduce over axis 3
    probs.name = "Softmax"

    # attention weights @ V -> [b, d, s_q, head_dim]
    av = network.add_matrix_multiply(
        probs.get_output(0), trt.MatrixOperation.NONE,
        value, trt.MatrixOperation.NONE,
    )
    av.name = "BMM2_AV"

    network.mark_output(av.get_output(0))
    return network

def build_attention_network_v10(builder: trt.Builder, q_shape, kv_shape, mask_shape):
    """Build attention with the native add_attention API (TensorRT >= 10.0).

    Query and key/value may use different sequence lengths. The mask input is
    boolean: True = attend, False = blocked.
    """
    network = builder.create_network(
        flags=1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    )
    q = network.add_input("query", dtype=trt.float16, shape=q_shape)
    k = network.add_input("key", dtype=trt.float16, shape=kv_shape)
    v = network.add_input("value", dtype=trt.float16, shape=kv_shape)
    m = network.add_input("mask", dtype=trt.bool, shape=mask_shape)

    attn = network.add_attention(
        q, k, v, trt.AttentionNormalizationOp.SOFTMAX, False
    )
    attn.mask = m
    network.mark_output(attn.get_output(0))
    return network

def prepare_test_data_from_files(q_shape, kv_shape, mask_shape, use_native_api: bool,
                                 *, load_from_files: bool = False):
    """Prepare Q/K/V/mask host inputs for the attention engine.

    Args:
        q_shape: query shape (B, H, S_q, D).
        kv_shape: key/value shape (B, H, S_kv, D).
        mask_shape: attention-mask shape, broadcastable to (B, H, S_q, S_kv).
        use_native_api: True when the engine was built with add_attention,
            which expects a boolean mask; otherwise an additive float16 mask.
        load_from_files: load Q/K/V from q2.pt/k2.pt/v2.pt instead of
            generating random data. Replaces the previous hard-coded
            `if 0:` toggle, which was dead code.

    Returns:
        dict mapping input tensor names to float16 (Q/K/V) and bool/float16
        (mask) NumPy arrays.
    """
    if load_from_files:
        q_tensor = torch.load("q2.pt").to(torch.float16)
        k_tensor = torch.load("k2.pt").to(torch.float16)
        v_tensor = torch.load("v2.pt").to(torch.float16)
    else:
        # Generate on the CPU: these values are only consumed via
        # .numpy() below, so allocating them on CUDA (as before) just
        # forced an extra device round-trip and a hard GPU dependency.
        q_tensor = torch.randn(q_shape, dtype=torch.float16)
        k_tensor = torch.randn(kv_shape, dtype=torch.float16)
        v_tensor = torch.randn(kv_shape, dtype=torch.float16)

    # Compare as plain tuples so torch.Size vs tuple never matters.
    assert tuple(q_tensor.shape) == tuple(q_shape), f"Q shape mismatch: {q_tensor.shape} vs {q_shape}"
    assert tuple(k_tensor.shape) == tuple(kv_shape), f"K shape mismatch: {k_tensor.shape} vs {kv_shape}"
    assert tuple(v_tensor.shape) == tuple(kv_shape), f"V shape mismatch: {v_tensor.shape} vs {kv_shape}"

    inputs = {
        "query": q_tensor.cpu().numpy().astype(np.float16),
        "key": k_tensor.cpu().numpy().astype(np.float16),
        "value": v_tensor.cpu().numpy().astype(np.float16),
    }

    if use_native_api:
        # Native add_attention takes a boolean mask; all-True = attend to all.
        inputs["mask"] = np.ones(mask_shape, dtype=np.bool_)
    else:
        # Manual path uses an additive mask; all-zero = attend to all.
        inputs["mask"] = np.zeros(mask_shape, dtype=np.float16)

    return inputs

def numpy_reference_attention(query, key, value):
    """NumPy reference implementation (Q and KV seq_len may differ)."""
    q = query.astype(np.float32)   # (B, H, S_q, D)
    k = key.astype(np.float32)     # (B, H, S_kv, D)
    v = value.astype(np.float32)   # (B, H, S_kv, D)
    head_dim = q.shape[-1]
    # Scaled scores: (B, H, S_q, S_kv)
    logits = np.matmul(q, k.transpose(0, 1, 3, 2)) / np.sqrt(head_dim)
    # Numerically stable softmax over the S_kv axis.
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weighted sum of values: (B, H, S_q, D)
    return np.matmul(weights, v).astype(np.float16)

def build_engine(builder, network):
    """Serialize the network into a plan and deserialize it into an engine."""
    config = builder.create_builder_config()
    # config.set_flag(trt.BuilderFlag.FP16)
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB workspace

    print("正在构建 TensorRT 引擎 ...")
    plan = builder.build_serialized_network(network, config)
    if plan is None:
        raise RuntimeError("引擎构建失败!")

    engine = trt.Runtime(TRT_LOGGER).deserialize_cuda_engine(plan)
    if engine is None:
        raise RuntimeError("引擎反序列化失败!")
    print("引擎构建成功。")
    return engine

def _get_binding_info(engine, i):
    """Return (name, shape, numpy dtype, is_input) for I/O tensor / binding i.

    Works on both API generations: the newer tensor-name API when available,
    otherwise the legacy binding-index API.
    """
    if hasattr(engine, "get_tensor_name"):
        # Tensor-name API.
        name = engine.get_tensor_name(i)
        shape = engine.get_tensor_shape(name)
        trt_dtype = engine.get_tensor_dtype(name)
        is_input = engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
    else:
        # Legacy binding-index API.
        name = engine.get_binding_name(i)
        shape = engine.get_binding_shape(i)
        trt_dtype = engine.get_binding_dtype(i)
        is_input = engine.binding_is_input(i)
    return name, tuple(shape), trt.nptype(trt_dtype), is_input

def allocate_buffers(engine, inputs_dict):
    """Allocate device buffers, upload input data, and collect output slots.

    Returns (bindings, device_buffers, output_buffers) where bindings is the
    positional address list, device_buffers maps name -> allocation, and
    output_buffers maps name -> (shape, numpy dtype, allocation).
    """
    count = engine.num_io_tensors if hasattr(engine, "num_io_tensors") else engine.num_bindings
    bindings = []
    device_buffers = {}
    output_buffers = {}

    for idx in range(count):
        name, shape, np_dtype, is_input = _get_binding_info(engine, idx)
        nbytes = int(np.prod(shape)) * np.dtype(np_dtype).itemsize

        alloc = cuda.mem_alloc(max(nbytes, 1))  # never request a 0-byte allocation
        device_buffers[name] = alloc
        bindings.append(int(alloc))

        if is_input:
            cuda.memcpy_htod(alloc, np.ascontiguousarray(inputs_dict[name]))
        else:
            output_buffers[name] = (shape, np_dtype, alloc)

    return bindings, device_buffers, output_buffers

def run_inference(engine, inputs_dict):
    """Execute one inference pass and return {output_name: host ndarray}."""
    context = engine.create_execution_context()
    bindings, device_buffers, output_buffers = allocate_buffers(engine, inputs_dict)
    stream = cuda.Stream()

    if hasattr(context, "execute_async_v3"):
        # Newer API: buffer addresses are bound per tensor name.
        for name, alloc in device_buffers.items():
            context.set_tensor_address(name, int(alloc))
        context.execute_async_v3(stream_handle=stream.handle)
    else:
        # Legacy API: addresses are passed as a positional bindings list.
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
    stream.synchronize()

    results = {}
    for name, (shape, np_dtype, alloc) in output_buffers.items():
        host = np.empty(shape, dtype=np_dtype)
        cuda.memcpy_dtoh(host, alloc)
        results[name] = host

    # Release all device allocations before returning.
    for alloc in device_buffers.values():
        alloc.free()

    return results

def prepare_test_data(qkv_shape, mask_shape, use_native_api: bool):
    """Build deterministic test inputs plus the expected output.

    Q, K, V and the expected output all share one pattern: ones everywhere,
    except position [0, i, 0, :] which holds i + 1 for each head i.
    """
    def patterned():
        # One fresh copy of the shared fill pattern.
        data = np.ones(qkv_shape, dtype=np.float16)
        for head in range(qkv_shape[1]):
            data[0, head, 0, :] = head + 1
        return data

    inputs = {
        "query": patterned(),
        "key": patterned(),
        "value": patterned(),
    }

    if use_native_api:
        inputs["mask"] = np.ones(mask_shape, dtype=np.bool_)    # boolean: True = attend
    else:
        inputs["mask"] = np.zeros(mask_shape, dtype=np.float16)  # additive: 0 = attend

    return inputs, patterned()

def numpy_reference_attention(query, key, value):
    """NumPy reference attention used to validate the TensorRT output.

    NOTE(review): this re-definition shadows the identical function defined
    earlier in the file; one of the two copies should be removed.
    """
    q = query.astype(np.float32)
    k = key.astype(np.float32)
    v = value.astype(np.float32)
    dim = q.shape[-1]
    # Q @ K^T / sqrt(dim) -> (B, H, S_q, S_kv)
    scores = np.matmul(q, k.transpose(0, 1, 3, 2)) / np.sqrt(dim)
    # Numerically stable softmax along the key axis.
    shifted = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = shifted / shifted.sum(axis=-1, keepdims=True)
    return np.matmul(probs, v).astype(np.float16)

def main():
    """Build the attention engine, run inference, and compare with NumPy."""
    print(f"TensorRT 版本: {trt.version}")
    print(f"GPU: {cuda.Device(0).name()}")
    print("-" * 60)

    q_shape = (1, 8, 1, 16)
    kv_shape = (1, 8, 2, 16)
    mask_shape = (1, 1, 1, 2)  # (B, 1, S_q, S_kv) — broadcastable to (B, H, S_q, S_kv)

    use_native = TRT_VERSION >= (10, 0)
    # use_native = False

    print("模式: 原生 add_attention API (TRT >= 10.0)")
    print(f"Query 形状: {q_shape}")
    print(f"Key/Value 形状: {kv_shape}")
    print(f"Mask 形状: {mask_shape}  # (B, 1, S_q, S_kv)")
    print("-" * 60)

    builder = trt.Builder(TRT_LOGGER)
    if use_native:
        network = build_attention_network_v10(builder, q_shape, kv_shape, mask_shape)
    else:
        network = build_attention_network_v8(builder, q_shape, kv_shape, mask_shape)
    engine = build_engine(builder, network)

    print("use_native : ", use_native)
    inputs = prepare_test_data_from_files(q_shape, kv_shape, mask_shape, use_native)
    results = run_inference(engine, inputs)

    output_name = next(iter(results))
    actual = results[output_name]
    print("推理输出形状:", actual.shape)
    print("actual : ", actual)

    # Compare the engine output against the NumPy reference.
    np_ref = numpy_reference_attention(inputs["query"], inputs["key"], inputs["value"])
    print("np_ref : ", np_ref)
    print("NumPy 参考形状:", np_ref.shape)
    diff = np.abs(actual - np_ref).max()
    print(f"最大绝对误差: {diff:.6f}")

# Bug fix: the guard previously read `if name == "main":`, which raises
# NameError at import time (likely markdown stripping the double underscores).
if __name__ == "__main__":
    torch.set_printoptions(precision=4, sci_mode=False)
    np.set_printoptions(precision=4)
    sys.exit(main())

Example output:

TensorRT 版本: 10.16.0.72
GPU: NVIDIA GeForce RTX 3090

模式: 原生 add_attention API (TRT >= 10.0)
Query 形状: (1, 8, 1, 16)
Key/Value 形状: (1, 8, 2, 16)
Mask 形状: (1, 1, 1, 2) # (B, 1, S_q, S_kv)

正在构建 TensorRT 引擎 ...
引擎构建成功。
use_native : True
推理输出形状: (1, 8, 1, 16)
actual : [[[[-1.053 -0.1176 -2.38 1.424 1.275 -0.0709 0.182 0.697
0.2712 -0.908 0.1036 0.0455 0.5835 0.1721 -0.4185 -0.5874]]

[[ 0.555 0.1472 0.8066 -0.7705 -1.095 -0.5796 -0.9536 -1.927
-0.1566 1.756 -0.5464 -0.7583 0.4456 0.1359 0.687 -0.7617]]

[[ 0.985 1.23 0.095 0.1102 0.5396 0.97 -0.1757 0.2944
1.61 -0.7217 -1.174 -1.98 0.7227 0.812 -2.092 -0.4463]]

[[-1.138 0.2874 -1.549 1.1875 -0.901 1.2705 -0.994 -2.168
-0.404 1.953 0.0195 -1.066 -0.9834 -0.3882 1.168 -1.047 ]]

[[ 0.7705 0.949 -0.8696 -0.8716 -1.713 0.866 0.478 0.0445
-0.507 -1.389 -0.668 -0.2405 0.771 0.6133 -2.225 -1.054 ]]

[[ 0.696 0.8525 1.826 2.209 -0.5874 -1.98 -0.3994 -1.179
-0.1019 1.7705 -0.0214 -1.665 -0.9194 -0.4307 2.287 -1.268 ]]

[[-0.6055 -1.472 0.1969 0.3088 -0.0927 0.5107 0.538 -0.0392
-0.335 0.1619 1.105 0.1758 -0.1315 -0.2086 0.5674 0.0462]]

[[ 0.454 -1.278 0.865 -1.368 -0.5356 1.137 0.1066 -0.3372
0.8057 -1.239 0.9463 -0.2053 -0.2837 0.3354 -0.222 -0.609 ]]]]
np_ref : [[[[-0.8384 -0.0892 -1.551 1.578 0.743 -0.322 0.238 0.9575
0.0112 -0.687 0.1752 -0.185 0.544 0.33 -0.0775 -0.817 ]]

[[ 0.5117 0.0112 1.129 -0.446 -0.4597 -0.0243 -1.07 -1.345
0.1693 1.604 -0.8037 -0.9507 0.1526 0.3232 0.2659 -0.4111]]

[[ 0.5913 0.639 0.1238 0.3093 0.3838 0.494 -0.4536 0.384
1.372 -0.3115 -1.023 -1.474 0.658 0.4717 -1.5625 -0.1462]]

[[-0.7495 0.2434 -1.166 0.709 -0.734 0.346 -0.9263 -1.314
-0.2878 1.04 0.1853 -0.6963 -0.6006 -0.4006 0.623 -1.113 ]]

[[ 0.678 0.8867 -0.8994 -0.8896 -1.707 0.81 0.4343 0.1217
-0.427 -1.4375 -0.6934 -0.2612 0.786 0.575 -2.172 -0.982 ]]

[[ 0.58 0.86 1.419 2.031 -0.5522 -1.761 -0.477 -1.125
0.0249 1.505 -0.2788 -1.526 -0.932 -0.3494 2.049 -1.056 ]]

[[-0.6714 -1.203 0.0881 0.4436 -0.2834 0.477 0.637 0.2717
-0.607 0.1923 0.833 0.1099 0.0457 -0.2534 0.5576 0.066 ]]

[[ 0.324 -0.8438 0.5767 -1.107 -0.301 0.85 0.0753 0.179
0.127 -0.636 0.95 -0.0675 0.3687 -0.0655 -0.3174 -0.1871]]]]
NumPy 参考形状: (1, 8, 1, 16)
最大绝对误差: 0.924805

As shown above, the maximum absolute error between the TensorRT output and the NumPy reference is ~0.92 — far beyond fp16 rounding tolerance, so the attention result appears incorrect.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Module: Accuracy — Output mismatch between TensorRT and other frameworks

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions