Skip to content
Draft
10 changes: 10 additions & 0 deletions python_bindings/src/halide/halide_/PyEnums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,16 @@ void define_enums(py::module &m) {
.value("AVX10_1", Target::Feature::AVX10_1)
.value("X86APX", Target::Feature::X86APX)
.value("Simulator", Target::Feature::Simulator)
.value("D3D12ComputeSM60", Target::Feature::D3D12ComputeSM60)
.value("D3D12ComputeSM61", Target::Feature::D3D12ComputeSM61)
.value("D3D12ComputeSM62", Target::Feature::D3D12ComputeSM62)
.value("D3D12ComputeSM63", Target::Feature::D3D12ComputeSM63)
.value("D3D12ComputeSM64", Target::Feature::D3D12ComputeSM64)
.value("D3D12ComputeSM65", Target::Feature::D3D12ComputeSM65)
.value("D3D12ComputeSM66", Target::Feature::D3D12ComputeSM66)
.value("D3D12ComputeSM67", Target::Feature::D3D12ComputeSM67)
.value("D3D12ComputeSM68", Target::Feature::D3D12ComputeSM68)
.value("D3D12ComputeSM69", Target::Feature::D3D12ComputeSM69)
.value("FeatureEnd", Target::Feature::FeatureEnd);

py::enum_<halide_type_code_t>(m, "TypeCode")
Expand Down
567 changes: 513 additions & 54 deletions src/CodeGen_D3D12Compute_Dev.cpp

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion src/DeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,11 @@ Expr make_device_interface_call(DeviceAPI device_api, MemoryType memory_type) {
interface_name = "halide_hexagon_dma_device_interface";
break;
case DeviceAPI::D3D12Compute:
interface_name = "halide_d3d12compute_device_interface";
if (memory_type == MemoryType::GPUTexture) {
interface_name = "halide_d3d12compute_image_device_interface";
} else {
interface_name = "halide_d3d12compute_device_interface";
}
break;
case DeviceAPI::Vulkan:
interface_name = "halide_vulkan_device_interface";
Expand Down
135 changes: 130 additions & 5 deletions src/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,16 @@ const std::map<std::string, Target::Feature> feature_name_map = {
{"trace_realizations", Target::TraceRealizations},
{"trace_pipeline", Target::TracePipeline},
{"d3d12compute", Target::D3D12Compute},
{"d3d12compute_sm60", Target::D3D12ComputeSM60},
{"d3d12compute_sm61", Target::D3D12ComputeSM61},
{"d3d12compute_sm62", Target::D3D12ComputeSM62},
{"d3d12compute_sm63", Target::D3D12ComputeSM63},
{"d3d12compute_sm64", Target::D3D12ComputeSM64},
{"d3d12compute_sm65", Target::D3D12ComputeSM65},
{"d3d12compute_sm66", Target::D3D12ComputeSM66},
{"d3d12compute_sm67", Target::D3D12ComputeSM67},
{"d3d12compute_sm68", Target::D3D12ComputeSM68},
{"d3d12compute_sm69", Target::D3D12ComputeSM69},
{"strict_float", Target::StrictFloat},
{"tsan", Target::TSAN},
{"asan", Target::ASAN},
Expand Down Expand Up @@ -1135,6 +1145,22 @@ void Target::validate_features() const {
VSX,
});
}

// D3D12Compute SM version features require D3D12Compute to also be set.
if (!has_feature(D3D12Compute)) {
do_check_bad(*this, {
D3D12ComputeSM60,
D3D12ComputeSM61,
D3D12ComputeSM62,
D3D12ComputeSM63,
D3D12ComputeSM64,
D3D12ComputeSM65,
D3D12ComputeSM66,
D3D12ComputeSM67,
D3D12ComputeSM68,
D3D12ComputeSM69,
});
}
}

Target::Target(const std::string &target) {
Expand Down Expand Up @@ -1378,6 +1404,43 @@ int Target::get_vulkan_capability_lower_bound() const {
return 10;
}

int Target::get_d3d12compute_capability_lower_bound() const {
if (!has_feature(Target::D3D12Compute)) {
return -1;
}
if (has_feature(Target::D3D12ComputeSM60)) {
return 60;
}
if (has_feature(Target::D3D12ComputeSM61)) {
return 61;
}
if (has_feature(Target::D3D12ComputeSM62)) {
return 62;
}
if (has_feature(Target::D3D12ComputeSM63)) {
return 63;
}
if (has_feature(Target::D3D12ComputeSM64)) {
return 64;
}
if (has_feature(Target::D3D12ComputeSM65)) {
return 65;
}
if (has_feature(Target::D3D12ComputeSM66)) {
return 66;
}
if (has_feature(Target::D3D12ComputeSM67)) {
return 67;
}
if (has_feature(Target::D3D12ComputeSM68)) {
return 68;
}
if (has_feature(Target::D3D12ComputeSM69)) {
return 69;
}
return 51; // default: SM 5.1 (FXC)
}

int Target::get_arm_v8_lower_bound() const {
if (has_feature(Target::ARMv8a)) {
return 80;
Expand Down Expand Up @@ -1416,13 +1479,13 @@ bool Target::supports_type(const Type &t) const {
if (t.bits() == 64) {
if (t.is_float()) {
return (!has_feature(Metal) &&
!has_feature(D3D12Compute) &&
(!has_feature(D3D12Compute) || get_d3d12compute_capability_lower_bound() >= 60) &&
(!has_feature(Target::OpenCL) || has_feature(Target::CLDoubles)) &&
(!has_feature(Vulkan) || has_feature(Target::VulkanFloat64)) &&
!has_feature(WebGPU));
} else {
return (!has_feature(Metal) &&
!has_feature(D3D12Compute) &&
(!has_feature(D3D12Compute) || get_d3d12compute_capability_lower_bound() >= 60) &&
(!has_feature(Vulkan) || has_feature(Target::VulkanInt64)) &&
!has_feature(WebGPU));
}
Expand Down Expand Up @@ -1450,9 +1513,18 @@ bool Target::supports_type(const Type &t, DeviceAPI device) const {
return has_feature(Target::CLDoubles);
}
} else if (device == DeviceAPI::D3D12Compute) {
// Shader Model 5.x can optionally support double-precision; 64-bit int
// types are not supported.
return t.bits() < 64;
// SM 5.1 (FXC): no 64-bit types. float16 and int8 work via widening.
// SM 6.0+: 64-bit int and float (double, int64_t, uint64_t) supported.
// SM 6.2+: native 16-bit float (float16_t) and int (int16_t, uint16_t).
// SM 6.6+: native 8-bit int (int8_t, uint8_t). Earlier SMs widen to int32.
// SM 6.9+: long vectors (5–1024 lanes) via vector<T, N> syntax.
if (t.bits() == 64) {
return get_d3d12compute_capability_lower_bound() >= 60;
}
if (t.lanes() > 4) {
return get_d3d12compute_capability_lower_bound() >= 69;
}
return true;
} else if (device == DeviceAPI::Vulkan) {
if (t.is_float() && t.bits() == 64) {
return has_feature(Target::VulkanFloat64);
Expand Down Expand Up @@ -1653,6 +1725,17 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result)
VulkanV12,
VulkanV13,

D3D12ComputeSM60,
D3D12ComputeSM61,
D3D12ComputeSM62,
D3D12ComputeSM63,
D3D12ComputeSM64,
D3D12ComputeSM65,
D3D12ComputeSM66,
D3D12ComputeSM67,
D3D12ComputeSM68,
D3D12ComputeSM69,

ARMv8a,
ARMv81a,
ARMv82a,
Expand Down Expand Up @@ -1787,6 +1870,43 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result)
output.features.reset(VulkanV13);
}

// Pick tight lower bound for D3D12Compute SM version. Use fall-through to clear redundant features
int d3d12_sm_a = get_d3d12compute_capability_lower_bound();
int d3d12_sm_b = other.get_d3d12compute_capability_lower_bound();

// Same trick as CUDA: -1 (unused) becomes large when cast to unsigned, so min gives the true lower bound.
int d3d12_sm = std::min((unsigned)d3d12_sm_a, (unsigned)d3d12_sm_b);
if (d3d12_sm < 60) {
output.features.reset(D3D12ComputeSM60);
}
if (d3d12_sm < 61) {
output.features.reset(D3D12ComputeSM61);
}
if (d3d12_sm < 62) {
output.features.reset(D3D12ComputeSM62);
}
if (d3d12_sm < 63) {
output.features.reset(D3D12ComputeSM63);
}
if (d3d12_sm < 64) {
output.features.reset(D3D12ComputeSM64);
}
if (d3d12_sm < 65) {
output.features.reset(D3D12ComputeSM65);
}
if (d3d12_sm < 66) {
output.features.reset(D3D12ComputeSM66);
}
if (d3d12_sm < 67) {
output.features.reset(D3D12ComputeSM67);
}
if (d3d12_sm < 68) {
output.features.reset(D3D12ComputeSM68);
}
if (d3d12_sm < 69) {
output.features.reset(D3D12ComputeSM69);
}

// Pick tight lower bound for HVX version. Use fall-through to clear redundant features
int hvx_a = get_hvx_lower_bound(*this);
int hvx_b = get_hvx_lower_bound(other);
Expand Down Expand Up @@ -1874,6 +1994,11 @@ void target_test() {
{{"hexagon-32-qurt-hvx_v62", "hexagon-32-qurt", "hexagon-32-qurt"}},
{{"hexagon-32-qurt-hvx_v62-hvx", "hexagon-32-qurt", ""}},
{{"hexagon-32-qurt-hvx_v62-hvx", "hexagon-32-qurt-hvx", "hexagon-32-qurt-hvx"}},
{{"x86-64-windows-d3d12compute-d3d12compute_sm66", "x86-64-windows-d3d12compute", "x86-64-windows-d3d12compute"}},
{{"x86-64-windows-d3d12compute-d3d12compute_sm66", "x86-64-windows-d3d12compute-d3d12compute_sm60", "x86-64-windows-d3d12compute-d3d12compute_sm60"}},
{{"x86-64-windows-d3d12compute-d3d12compute_sm62", "x86-64-windows-d3d12compute-d3d12compute_sm62", "x86-64-windows-d3d12compute-d3d12compute_sm62"}},
{{"x86-64-windows-d3d12compute-d3d12compute_sm69", "x86-64-windows-d3d12compute", "x86-64-windows-d3d12compute"}},
{{"x86-64-windows-d3d12compute-d3d12compute_sm69", "x86-64-windows-d3d12compute-d3d12compute_sm60", "x86-64-windows-d3d12compute-d3d12compute_sm60"}},
};

for (const auto &test : gcd_tests) {
Expand Down
15 changes: 15 additions & 0 deletions src/Target.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,16 @@ struct Target {
AVX10_1 = halide_target_feature_avx10_1,
X86APX = halide_target_feature_x86_apx,
Simulator = halide_target_feature_simulator,
D3D12ComputeSM60 = halide_target_feature_d3d12compute_sm60,
D3D12ComputeSM61 = halide_target_feature_d3d12compute_sm61,
D3D12ComputeSM62 = halide_target_feature_d3d12compute_sm62,
D3D12ComputeSM63 = halide_target_feature_d3d12compute_sm63,
D3D12ComputeSM64 = halide_target_feature_d3d12compute_sm64,
D3D12ComputeSM65 = halide_target_feature_d3d12compute_sm65,
D3D12ComputeSM66 = halide_target_feature_d3d12compute_sm66,
D3D12ComputeSM67 = halide_target_feature_d3d12compute_sm67,
D3D12ComputeSM68 = halide_target_feature_d3d12compute_sm68,
D3D12ComputeSM69 = halide_target_feature_d3d12compute_sm69,
FeatureEnd = halide_target_feature_end
};
Target() = default;
Expand Down Expand Up @@ -349,6 +359,11 @@ struct Target {
* features are set. */
int get_vulkan_capability_lower_bound() const;

/** Get the minimum D3D12Compute Shader Model version as an integer
* (e.g. 60 for SM 6.0, 62 for SM 6.2). Returns 51 (SM 5.1, FXC path)
* if no SM 6.x features are set, or -1 if D3D12Compute is not enabled. */
int get_d3d12compute_capability_lower_bound() const;

/** Get the minimum ARM v8.x capability found as an integer. Returns
* -1 if no ARM v8.x features are set. */
int get_arm_v8_lower_bound() const;
Expand Down
10 changes: 10 additions & 0 deletions src/runtime/HalideRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,16 @@ typedef enum halide_target_feature_t {
halide_target_feature_avx10_1, ///< Intel AVX10 version 1 support. vector_bits is used to indicate width.
halide_target_feature_x86_apx, ///< Intel x86 APX support. Covers initial set of features released as APX: egpr,push2pop2,ppx,ndd .
halide_target_feature_simulator, ///< Target is for a simulator environment. Currently only applies to iOS.
halide_target_feature_d3d12compute_sm60, ///< Enable D3D12 Shader Model 6.0 (DXIL, 64-bit types, wave intrinsics). Requires d3d12compute. Uses DXC compiler.
halide_target_feature_d3d12compute_sm61, ///< Enable D3D12 Shader Model 6.1
halide_target_feature_d3d12compute_sm62, ///< Enable D3D12 Shader Model 6.2 (native 16-bit scalar types with -enable-16bit-types)
halide_target_feature_d3d12compute_sm63, ///< Enable D3D12 Shader Model 6.3
halide_target_feature_d3d12compute_sm64, ///< Enable D3D12 Shader Model 6.4
halide_target_feature_d3d12compute_sm65, ///< Enable D3D12 Shader Model 6.5
halide_target_feature_d3d12compute_sm66, ///< Enable D3D12 Shader Model 6.6 (64-bit atomics, packed 8-bit types)
halide_target_feature_d3d12compute_sm67, ///< Enable D3D12 Shader Model 6.7
halide_target_feature_d3d12compute_sm68, ///< Enable D3D12 Shader Model 6.8
halide_target_feature_d3d12compute_sm69, ///< Enable D3D12 Shader Model 6.9 (long vectors 5-1024 lanes, native 16-bit/wave/int64 required)
halide_target_feature_end ///< A sentinel. Every target is considered to have this feature, and setting this feature does nothing.
} halide_target_feature_t;

Expand Down
3 changes: 3 additions & 0 deletions src/runtime/HalideRuntimeD3D12Compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
extern "C" {
#endif

#define HALIDE_RUNTIME_D3D12COMPUTE

/** \file
* Routines specific to the Halide Direct3D 12 Compute runtime.
*/

extern const struct halide_device_interface_t *halide_d3d12compute_device_interface();
extern const struct halide_device_interface_t *halide_d3d12compute_image_device_interface();

/** These are forward declared here to allow clients to override the
* Halide Direct3D 12 Compute runtime. Do not call them. */
Expand Down
Loading