Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions examples/bfloat16.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using oneAPI, Test

# Core.BFloat16 only exists as a primitive type on newer Julia releases;
# bail out gracefully instead of erroring on older ones.
@static if !isdefined(Core, :BFloat16)
    @info "BFloat16 requires Julia 1.12+, skipping."
    exit()
end

# Ask oneAPI whether the current device can handle BFloat16 (driver extension
# or known-capable hardware — see `_device_supports_bfloat16`).
bfloat16_supported = oneAPI._device_supports_bfloat16()

@info "BFloat16 support: $bfloat16_supported"

# Skip (rather than fail) on devices without BFloat16 support.
if !bfloat16_supported
    @info "Device does not support BFloat16, skipping."
    exit()
end

# Conversions: Core.BFloat16 in Julia 1.12 may not have Float32 constructors yet
float32_to_bf16(x::Float32) = reinterpret(Core.BFloat16, (reinterpret(UInt32, x) >> 16) % UInt16)
bf16_to_float32(x::Core.BFloat16) = reinterpret(Float32, UInt32(reinterpret(UInt16, x)) << 16)

# Device kernel: double each BFloat16 element.
# Since Core.BFloat16 has no arithmetic, each value is widened to Float32
# (shift the 16-bit pattern into the high half), scaled, and narrowed again
# (keep the upper 16 bits of the Float32 pattern).
function scale_bf16(input, output)
    idx = get_global_id()
    @inbounds begin
        raw = reinterpret(UInt16, input[idx])
        # widen BFloat16 -> Float32
        wide = reinterpret(Float32, UInt32(raw) << 16)
        wide *= 2.0f0
        # narrow Float32 -> BFloat16 by truncation
        output[idx] = reinterpret(Core.BFloat16, (reinterpret(UInt32, wide) >> 16) % UInt16)
    end
    return
end

len = 1024
host_in = float32_to_bf16.(rand(Float32, len))

dev_in = oneArray(host_in)
dev_out = oneArray{Core.BFloat16}(undef, len)

# launch one work-item per element
@oneapi items=len scale_bf16(dev_in, dev_out)
gpu_result = Array(dev_out)

# Verify in Float32 space: every output must be twice the corresponding input.
@test bf16_to_float32.(gpu_result) ≈ bf16_to_float32.(host_in) .* 2.0f0
@info "BFloat16 scale-by-2 kernel passed!"
21 changes: 21 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ function contains_eltype(T, X)
return false
end

# Whether the current device can handle BFloat16 values.
# Returns `true` if the Level Zero driver advertises the bfloat16-conversions
# extension, or if the device ID falls in the Intel Data Center GPU Max
# (Ponte Vecchio) range — some older drivers on that hardware omit the
# extension even though the silicon supports BFloat16 natively.
function _device_supports_bfloat16()
    exts = oneL0.extension_properties(driver())
    haskey(exts, oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME) && return true
    # fallback: PVC device IDs 0x0BD0-0x0BDB
    id = oneL0.properties(device()).deviceId
    return 0x0BD0 <= id <= 0x0BDB
end

function check_eltype(T)
Base.allocatedinline(T) || error("oneArray only supports element types that are stored inline")
Base.isbitsunion(T) && error("oneArray does not yet support isbits-union arrays")
Expand All @@ -39,6 +55,11 @@ function check_eltype(T)
oneL0.ZE_DEVICE_MODULE_FLAG_FP64
contains_eltype(T, Float64) && error("Float64 is not supported on this device")
end
@static if isdefined(Core, :BFloat16)
if !_device_supports_bfloat16()
contains_eltype(T, Core.BFloat16) && error("BFloat16 is not supported on this device")
end
end
end

"""
Expand Down
126 changes: 125 additions & 1 deletion src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ function GPUCompiler.finish_ir!(job::oneAPICompilerJob, mod::LLVM.Module,
# indices (e.g., "1 0") corrupts adjacent struct fields.
flatten_nested_insertvalue!(mod)

# When the device supports BFloat16 but the SPIR-V runtime doesn't accept
# SPV_KHR_bfloat16, lower all bfloat types to i16 so the translator can
# handle the module without the extension.
if @static(isdefined(Core, :BFloat16) && isdefined(LLVM, :BFloatType)) &&
_device_supports_bfloat16() && !_driver_supports_bfloat16_spirv()
lower_bfloat_to_i16!(mod)
end

return entry
end

Expand Down Expand Up @@ -158,6 +166,105 @@ function flatten_insert!(inst::LLVM.Instruction)
end


# Lower bfloat types to i16 in the LLVM IR.
# This is needed when the device supports BFloat16 but the SPIR-V runtime/translator
# doesn't support SPV_KHR_bfloat16. Since sizeof(bfloat)==sizeof(i16)==2, the memory
# layout is identical.
#
# TODO: Julia 1.12's Core.BFloat16 is a bare primitive (no Float32 conversion, no
# arithmetic), so fptrunc/fpext instructions never appear in practice. If Julia adds
# BFloat16 conversion methods in the future, this pass should be extended to handle
# fptrunc float→bfloat and fpext bfloat→float, either via inline RNE bit manipulation
# or calls to __devicelib_ConvertFToBF16INTEL / __devicelib_ConvertBF16ToFINTEL.
# Rewrite every bfloat-typed memory operation in `mod` to its i16 equivalent,
# so the SPIR-V translator never sees a bfloat type. Safe because both types
# occupy 2 bytes with identical layout. Mutates `mod` in place; returns `true`.
function lower_bfloat_to_i16!(mod::LLVM.Module)
    T_bf16 = LLVM.BFloatType()
    T_i16 = LLVM.Int16Type()

    # Phase 1: Eliminate all bitcasts between i16 and bfloat (same bit width).
    eliminate_bf16_bitcasts!(mod, T_bf16, T_i16)

    # Phase 2: Replace remaining bfloat GEPs, loads, and stores with i16 equivalents.
    for f in functions(mod)
        # declarations have no body to rewrite
        isempty(blocks(f)) && continue
        for bb in blocks(f)
            # collect first, mutate after: erasing while iterating a block's
            # instruction list would invalidate the iteration
            to_replace = LLVM.Instruction[]
            for inst in instructions(bb)
                opcode = LLVM.API.LLVMGetInstructionOpcode(inst)
                if opcode == LLVM.API.LLVMGetElementPtr
                    # GEPs carry the pointee type explicitly (opaque pointers)
                    src_ty = LLVMType(LLVM.API.LLVMGetGEPSourceElementType(inst))
                    src_ty == T_bf16 && push!(to_replace, inst)
                elseif opcode == LLVM.API.LLVMLoad
                    value_type(inst) == T_bf16 && push!(to_replace, inst)
                elseif opcode == LLVM.API.LLVMStore
                    # operand 1 is the stored value, operand 2 the pointer
                    value_type(LLVM.operands(inst)[1]) == T_bf16 && push!(to_replace, inst)
                end
            end

            for inst in to_replace
                opcode = LLVM.API.LLVMGetInstructionOpcode(inst)
                builder = LLVM.IRBuilder()
                # insert replacements immediately before the instruction they replace
                LLVM.position!(builder, inst)

                if opcode == LLVM.API.LLVMGetElementPtr
                    ptr = LLVM.operands(inst)[1]
                    indices = LLVM.Value[LLVM.operands(inst)[i] for i in 2:length(LLVM.operands(inst))]
                    # preserve the inbounds flag of the original GEP
                    new_gep = if LLVM.API.LLVMIsInBounds(inst) != 0
                        LLVM.inbounds_gep!(builder, T_i16, ptr, indices)
                    else
                        LLVM.gep!(builder, T_i16, ptr, indices)
                    end
                    LLVM.replace_uses!(inst, new_gep)
                elseif opcode == LLVM.API.LLVMLoad
                    ptr = LLVM.operands(inst)[1]
                    new_load = LLVM.load!(builder, T_i16, ptr)
                    LLVM.replace_uses!(inst, new_load)
                elseif opcode == LLVM.API.LLVMStore
                    # NOTE(review): this re-emits a store with the SAME bfloat-typed
                    # value operand — it does not retype the stored value to i16.
                    # Presumably phase 1 leaves the value i16-typed in practice, or
                    # a bfloat-typed store through an untyped pointer is acceptable
                    # to the translator — TODO confirm, otherwise this branch is a no-op.
                    val = LLVM.operands(inst)[1]
                    ptr = LLVM.operands(inst)[2]
                    LLVM.store!(builder, val, ptr)
                end

                # stores have no uses, so erasing without replace_uses! is fine there
                LLVM.API.LLVMInstructionEraseFromParent(inst)
                LLVM.dispose(builder)
            end
        end
    end

    return true
end

# Strip bitcasts between i16 and bfloat (identical bit representation), plus
# degenerate same-type bitcasts, repeating full sweeps until a pass over the
# whole module makes no change (removals can expose further trivial casts).
function eliminate_bf16_bitcasts!(mod::LLVM.Module, T_bf16::LLVMType, T_i16::LLVMType)
    dirty = true
    while dirty
        dirty = false
        for fn in functions(mod)
            isempty(blocks(fn)) && continue
            for blk in blocks(fn)
                # gather doomed casts first; erase only after iteration finishes
                doomed = LLVM.Instruction[]
                for inst in instructions(blk)
                    LLVM.API.LLVMGetInstructionOpcode(inst) == LLVM.API.LLVMBitCast || continue
                    operand = LLVM.operands(inst)[1]
                    from_ty = value_type(operand)
                    to_ty = value_type(inst)
                    trivial = from_ty == to_ty ||
                              (from_ty == T_i16 && to_ty == T_bf16) ||
                              (from_ty == T_bf16 && to_ty == T_i16)
                    if trivial
                        # forward all uses straight to the cast's source
                        LLVM.replace_uses!(inst, operand)
                        push!(doomed, inst)
                        dirty = true
                    end
                end
                for inst in doomed
                    LLVM.API.LLVMInstructionEraseFromParent(inst)
                end
            end
        end
    end
    return nothing
end


## compiler implementation (cache, configure, compile, and link)

# cache of compilation caches, per device
Expand All @@ -183,18 +290,35 @@ function compiler_config(dev; kwargs...)
end
return config
end
# Whether the driver's SPIR-V runtime accepts the SPV_KHR_bfloat16 extension.
# Always `false` on Julia builds without Core.BFloat16.
function _driver_supports_bfloat16_spirv()
    @static if isdefined(Core, :BFloat16)
        return haskey(oneL0.extension_properties(driver()),
                      oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME)
    else
        return false
    end
end

# Build a CompilerConfig for `dev`. `compiler_config` caches the result per
# device, so device capabilities MUST be read from `dev` — the original code
# queried the task-global `device()`, which can differ from `dev` and would
# bake the wrong capabilities into the cached config.
@noinline function _compiler_config(dev; kernel=true, name=nothing, always_inline=false, kwargs...)
    # query the module properties of the device we are configuring for (once)
    mod_props = oneL0.module_properties(dev)
    supports_fp16 = mod_props.fp16flags & oneL0.ZE_DEVICE_MODULE_FLAG_FP16 == oneL0.ZE_DEVICE_MODULE_FLAG_FP16
    supports_fp64 = mod_props.fp64flags & oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == oneL0.ZE_DEVICE_MODULE_FLAG_FP64
    # Allow BFloat16 in IR if the device supports it (even if the SPIR-V runtime doesn't
    # advertise the extension). We lower bfloat→i16 in finish_ir! when needed.
    # TODO(review): these helpers consult the task-global device/driver; thread
    # `dev` through them so per-device caching is fully correct.
    supports_bfloat16 = _device_supports_bfloat16()

    # TODO: emit printf format strings in constant memory
    extensions = String[
        "SPV_EXT_relaxed_printf_string_address_space",
        "SPV_EXT_shader_atomic_float_add"
    ]
    # Only add the SPIR-V extension if the runtime actually supports it
    if _driver_supports_bfloat16_spirv()
        push!(extensions, "SPV_KHR_bfloat16")
    end

    # create GPUCompiler objects
    target = SPIRVCompilerTarget(; extensions, supports_fp16, supports_fp64, supports_bfloat16, kwargs...)
    params = oneAPICompilerParams()
    CompilerConfig(target, params; kernel, name, always_inline)
end
Expand Down
8 changes: 8 additions & 0 deletions test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ const float64_supported = oneL0.module_properties(device()).fp64flags & oneL0.ZE
if float64_supported
append!(eltypes, [Float64, ComplexF64])
end
# Register Core.BFloat16 as a testable element type when both the Julia build
# defines it and the device supports it; otherwise record it as unsupported.
@static if isdefined(Core, :BFloat16)
    const bfloat16_supported = oneAPI._device_supports_bfloat16()
    bfloat16_supported && push!(eltypes, Core.BFloat16)
else
    const bfloat16_supported = false
end
TestSuite.supported_eltypes(::Type{<:oneArray}) = eltypes

const validation_layer = parse(Bool, get(ENV, "ZE_ENABLE_VALIDATION_LAYER", "false"))
Expand Down
Loading