Skip to content

Fix #567: BFloat16 support#568

Open
michel2323 wants to merge 3 commits intomasterfrom
bfloat16
Open

Fix #567: BFloat16 support#568
michel2323 wants to merge 3 commits intomasterfrom
bfloat16

Conversation

@michel2323
Copy link
Copy Markdown
Member

Adds BFloat16 support after JuliaGPU/GPUCompiler.jl#778 is merged.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Mar 31, 2026

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/examples/bfloat16.jl b/examples/bfloat16.jl
index 4b35968..3fd88da 100644
--- a/examples/bfloat16.jl
+++ b/examples/bfloat16.jl
@@ -39,7 +39,7 @@ a = float32_to_bf16.(rand(Float32, n))
 d_a = oneArray(a)
 d_out = oneArray{Core.BFloat16}(undef, n)
 
-@oneapi items=n scale_bf16(d_a, d_out)
+@oneapi items = n scale_bf16(d_a, d_out)
 result = Array(d_out)
 
 # Verify: each output should be 2x the input (in Float32 space)
diff --git a/src/array.jl b/src/array.jl
index c23a46f..03f8995 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -29,19 +29,21 @@ function contains_eltype(T, X)
 end
 
 function _device_supports_bfloat16()
-  # check the driver extension first
-  if haskey(oneL0.extension_properties(driver()),
-            oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME)
-      return true
-  end
-  # some drivers (e.g. older versions on PVC/Max) don't advertise the extension,
-  # but the hardware supports BFloat16 natively. fall back to checking device ID.
-  dev_id = oneL0.properties(device()).deviceId
-  # Intel Data Center GPU Max (Ponte Vecchio): device IDs 0x0BD0-0x0BDB
-  if 0x0BD0 <= dev_id <= 0x0BDB
-      return true
-  end
-  return false
+    # check the driver extension first
+    if haskey(
+            oneL0.extension_properties(driver()),
+            oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME
+        )
+        return true
+    end
+    # some drivers (e.g. older versions on PVC/Max) don't advertise the extension,
+    # but the hardware supports BFloat16 natively. fall back to checking device ID.
+    dev_id = oneL0.properties(device()).deviceId
+    # Intel Data Center GPU Max (Ponte Vecchio): device IDs 0x0BD0-0x0BDB
+    if 0x0BD0 <= dev_id <= 0x0BDB
+        return true
+    end
+    return false
 end
 
 function check_eltype(T)
@@ -55,11 +57,11 @@ function check_eltype(T)
       oneL0.ZE_DEVICE_MODULE_FLAG_FP64
     contains_eltype(T, Float64) && error("Float64 is not supported on this device")
   end
-  @static if isdefined(Core, :BFloat16)
-    if !_device_supports_bfloat16()
-      contains_eltype(T, Core.BFloat16) && error("BFloat16 is not supported on this device")
+    return @static if isdefined(Core, :BFloat16)
+        if !_device_supports_bfloat16()
+            contains_eltype(T, Core.BFloat16) && error("BFloat16 is not supported on this device")
+        end
     end
-  end
 end
 
 """
diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl
index 8015248..120a175 100644
--- a/src/compiler/compilation.jl
+++ b/src/compiler/compilation.jl
@@ -54,7 +54,7 @@ function GPUCompiler.finish_ir!(job::oneAPICompilerJob, mod::LLVM.Module,
     # SPV_KHR_bfloat16, lower all bfloat types to i16 so the translator can
     # handle the module without the extension.
     if @static(isdefined(Core, :BFloat16) && isdefined(LLVM, :BFloatType)) &&
-       _device_supports_bfloat16() && !_driver_supports_bfloat16_spirv()
+            _device_supports_bfloat16() && !_driver_supports_bfloat16_spirv()
         lower_bfloat_to_i16!(mod)
     end
 
@@ -248,8 +248,8 @@ function eliminate_bf16_bitcasts!(mod::LLVM.Module, T_bf16::LLVMType, T_i16::LLV
                         src_ty = value_type(src)
                         dst_ty = value_type(inst)
                         if (src_ty == T_i16 && dst_ty == T_bf16) ||
-                           (src_ty == T_bf16 && dst_ty == T_i16) ||
-                           (src_ty == dst_ty)
+                                (src_ty == T_bf16 && dst_ty == T_i16) ||
+                                (src_ty == dst_ty)
                             LLVM.replace_uses!(inst, src)
                             push!(to_delete, inst)
                             changed = true
@@ -262,6 +262,7 @@ function eliminate_bf16_bitcasts!(mod::LLVM.Module, T_bf16::LLVMType, T_i16::LLV
             end
         end
     end
+    return
 end
 
 
@@ -292,9 +293,11 @@ function compiler_config(dev; kwargs...)
 end
 # Whether the driver's SPIR-V runtime accepts the SPV_KHR_bfloat16 extension.
 function _driver_supports_bfloat16_spirv()
-    @static if isdefined(Core, :BFloat16)
-        haskey(oneL0.extension_properties(driver()),
-               oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME)
+    return @static if isdefined(Core, :BFloat16)
+        haskey(
+            oneL0.extension_properties(driver()),
+            oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME
+        )
     else
         false
     end

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant