Summary
When the ONNX opset-23 Attention op is used with bfloat16 inputs, ORT fails in two ways:
- CPU EP – The BFloat16 kernel is not registered, so ORT falls back to AOT function-body inlining. The inlined graph contains an Expand(13) node for which the CPU EP has no bfloat16 kernel:

  ```
  NOT_IMPLEMENTED: Could not find an implementation for Expand(13) node with name ''
  ```
- CUDA EP (function-body fallback) – When Flash/MEA are unavailable and the CUDA kernel returns NOT_IMPLEMENTED, ORT tries to inline the function body. The ONNX Attention function body (in onnx/defs/nn/utils.cc) creates FloatNegInf and ScalarZero as float32 constants, then uses them in:

  ```
  MaskTri = Where(BoolMaskTri, FloatNegInf, ScalarZero)   # float32
  AttnBiasCausalOrNot = Add(AttnBias, MaskTri)            # bfloat16 + float32 → TYPE MISMATCH
  ```

  Error:

  ```
  Type parameter (T) of Optype (Add) bound to different types (tensor(bfloat16) and tensor(float))
  ```
Minimal Repro
```python
import numpy as np, onnx, onnxruntime as ort
from onnx import helper, TensorProto

B, S, H, d_k = 1, 4, 2, 8
inputs = [helper.make_tensor_value_info(n, TensorProto.BFLOAT16, [B, S, H, d_k]) for n in ['Q', 'K', 'V']]
out = helper.make_tensor_value_info('Y', TensorProto.BFLOAT16, [B, S, H, d_k])
node = helper.make_node('Attention', ['Q', 'K', 'V'], ['Y'], q_num_heads=H, kv_num_heads=H)
graph = helper.make_graph([node], 'g', inputs, [out])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 23)])
model.ir_version = 8

sess = ort.InferenceSession(model.SerializeToString(), providers=['CPUExecutionProvider'])
# → NOT_IMPLEMENTED: Could not find an implementation for Expand(13) node with name ''
```
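For the CUDA-side failure, the causal path is what pulls the float32 mask constants into the function body (see the root-cause analysis below), so is_causal=1 needs to be set. The following sketch continues from the script above (it reuses inputs, out and H) and assumes an onnxruntime-gpu build on which Flash/MEA decline these inputs; the exact failure point and message can vary by build.

```python
# Hypothetical CUDA-EP variant of the repro above; requires a CUDA build of ORT.
node_causal = helper.make_node('Attention', ['Q', 'K', 'V'], ['Y'],
                               q_num_heads=H, kv_num_heads=H, is_causal=1)
graph_causal = helper.make_graph([node_causal], 'g_causal', inputs, [out])
model_causal = helper.make_model(graph_causal, opset_imports=[helper.make_opsetid('', 23)])
model_causal.ir_version = 8

# When the CUDA kernel bails out and the function body is inlined, session creation
# is expected to fail with:
# → Type parameter (T) of Optype (Add) bound to different types (tensor(bfloat16) and tensor(float))
sess_cuda = ort.InferenceSession(model_causal.SerializeToString(),
                                 providers=['CUDAExecutionProvider'])
```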
Root Cause Analysis
CPU EP
onnxruntime/core/providers/cpu/llm/attention.cc registers Attention for float and MLFloat16 only — BFloat16 is missing. When no kernel matches, ORT falls back to function-body inlining via TryGetFunctionProto. The inlined graph uses Expand (and other ops) for which the CPU EP has no bfloat16 kernels.
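The Expand gap can be checked in isolation with a single-node model. This is a sketch (graph and tensor names are made up for the example); if the bfloat16 kernel coverage is as described above, it should fail with the same NOT_IMPLEMENTED error:

```python
import onnx
import onnxruntime as ort
from onnx import helper, TensorProto

# Single Expand node with a bfloat16 data input; under an opset-23 import the node
# resolves to Expand(13), the version named in the error message above.
x = helper.make_tensor_value_info('X', TensorProto.BFLOAT16, [1, 1])
shape_in = helper.make_tensor_value_info('S', TensorProto.INT64, [2])
y = helper.make_tensor_value_info('Y', TensorProto.BFLOAT16, [2, 2])
expand_node = helper.make_node('Expand', ['X', 'S'], ['Y'])
expand_graph = helper.make_graph([expand_node], 'expand_bf16', [x, shape_in], [y])
expand_model = helper.make_model(expand_graph, opset_imports=[helper.make_opsetid('', 23)])
expand_model.ir_version = 8  # mirrors the Attention repro above

# Expected: NOT_IMPLEMENTED: Could not find an implementation for Expand(13) node ...
ort.InferenceSession(expand_model.SerializeToString(), providers=['CPUExecutionProvider'])
```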
Function body type mismatch
The ONNX Attention function body builder (onnx/defs/nn/utils.cc::AttentionAppendFunctionCausalMask) creates causal-mask constants as hard-coded float32:
```cpp
float neg_inf = -std::numeric_limits<float>::infinity();
builder.Const1D("FloatNegInf", neg_inf);  // float32
builder.Const1D("ScalarZero", 0.f);       // float32
```
When attn_mask is provided as bfloat16 (or the bias is bfloat16 via the ConstantOfShape path) and is_causal=1, the downstream Add(AttnBias, MaskTri) receives one bfloat16 and one float32 input.
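The type-binding failure is not specific to Attention: any Add whose inputs mix bfloat16 and float32 violates Add's single type parameter T, and ORT's graph resolution rejects it with the same message. A minimal sketch (standalone, with illustrative names):

```python
import onnx
import onnxruntime as ort
from onnx import helper, TensorProto

# Add binds both inputs to one type parameter T; mixing bfloat16 and float32
# reproduces the error the inlined Attention function body runs into.
bias = helper.make_tensor_value_info('AttnBias', TensorProto.BFLOAT16, [2])
mask = helper.make_tensor_value_info('MaskTri', TensorProto.FLOAT, [2])
out_vi = helper.make_tensor_value_info('Out', TensorProto.BFLOAT16, [2])
add_node = helper.make_node('Add', ['AttnBias', 'MaskTri'], ['Out'])
mixed_graph = helper.make_graph([add_node], 'mixed_add', [bias, mask], [out_vi])
mixed_model = helper.make_model(mixed_graph, opset_imports=[helper.make_opsetid('', 23)])
mixed_model.ir_version = 8

# Expected: Type parameter (T) of Optype (Add) bound to different types
#           (tensor(bfloat16) and tensor(float))
ort.InferenceSession(mixed_model.SerializeToString(), providers=['CPUExecutionProvider'])
```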
Proposed Fix
CPU EP: Register BFloat16 for the CPU Attention kernel (opset 23 & 24), upcasting to float32 for internal computation (same pattern as MLFloat16).
Function body: Insert a CastLike so that MaskTri matches the type of AttnBias before the Add:

```
MaskTriTyped = CastLike(MaskTri, AttnBias)
AttnBiasCausalOrNot = Add(AttnBias, MaskTriTyped)
```
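To illustrate the intended shape of that change, here is a standalone sketch of the causal-mask fragment with the CastLike inserted. The node names mirror the function body, but the graph below is hand-built with toy shapes, not the actual generated FunctionProto:

```python
import onnx
from onnx import helper, TensorProto

seq = 4
bool_mask = helper.make_tensor_value_info('BoolMaskTri', TensorProto.BOOL, [seq, seq])
attn_bias = helper.make_tensor_value_info('AttnBias', TensorProto.BFLOAT16, [seq, seq])
out_vi = helper.make_tensor_value_info('AttnBiasCausalOrNot', TensorProto.BFLOAT16, [seq, seq])

nodes = [
    # float32 constants, as the current function body creates them
    helper.make_node('Constant', [], ['FloatNegInf'],
                     value=helper.make_tensor('neg_inf', TensorProto.FLOAT, [1], [float('-inf')])),
    helper.make_node('Constant', [], ['ScalarZero'],
                     value=helper.make_tensor('zero', TensorProto.FLOAT, [1], [0.0])),
    helper.make_node('Where', ['BoolMaskTri', 'FloatNegInf', 'ScalarZero'], ['MaskTri']),
    # proposed fix: align MaskTri's element type with AttnBias before the Add
    helper.make_node('CastLike', ['MaskTri', 'AttnBias'], ['MaskTriTyped']),
    helper.make_node('Add', ['AttnBias', 'MaskTriTyped'], ['AttnBiasCausalOrNot']),
]
fix_graph = helper.make_graph(nodes, 'causal_mask_castlike', [bool_mask, attn_bias], [out_vi])
fix_model = helper.make_model(fix_graph, opset_imports=[helper.make_opsetid('', 23)])

# Well-formed; the Add now sees two bfloat16 inputs instead of bfloat16 + float32.
onnx.checker.check_model(fix_model)
```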
Environment
- ORT 1.24.4 (main branch at commit aadf724)
- ONNX 1.20.1
- CPU EP (x86_64 Linux)