Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::builder_spirv::{SpirvFunctionCursor, SpirvValue, SpirvValueExt};
use crate::spirv_type::SpirvType;
use rspirv::dr::Operand;
use rspirv::spirv::{
Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word,
BuiltIn, Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word,
};
use rustc_abi::FieldsShape;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods, MiscCodegenMethods as _};
Expand Down Expand Up @@ -916,6 +916,11 @@ impl<'tcx> CodegenCx<'tcx> {
);
}

// Check builtin-specific type requirements.
if let Some(builtin) = attrs.builtin {
self.check_builtin_type(hir_param.ty_span, value_layout.ty, builtin);
}

if let Ok(storage_class) = storage_class {
self.check_for_bad_types(
execution_model,
Expand Down Expand Up @@ -1083,4 +1088,15 @@ impl<'tcx> CodegenCx<'tcx> {
}
}
}

/// Check that builtin variables have the correct type.
fn check_builtin_type(&self, span: Span, rust_ty: Ty<'tcx>, builtin: Spanned<BuiltIn>) {
// LocalInvocationIndex must be a u32.
if builtin.value == BuiltIn::LocalInvocationIndex && rust_ty != self.tcx.types.u32 {
self.tcx.dcx().span_err(
span,
format!("`#[spirv(local_invocation_index)]` must be a `u32`, not `{rust_ty}`"),
);
}
}
}
4 changes: 2 additions & 2 deletions tests/compiletests/ui/arch/shared/dce_shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ pub fn main(
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut f32,
#[spirv(workgroup)] used_shared: &mut f32,
#[spirv(workgroup)] dce_shared: &mut [i32; 2],
#[spirv(local_invocation_index)] inv_id: UVec3,
#[spirv(local_invocation_index)] inv_id: u32,
) {
unsafe {
let inv_id = inv_id.x as usize;
let inv_id = inv_id as usize;
if inv_id == 0 {
*used_shared = *input;
}
Expand Down
19 changes: 9 additions & 10 deletions tests/compiletests/ui/arch/shared/dce_shared.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@ OpDecorate %4 BuiltIn LocalInvocationIndex
%11 = OpTypePointer Workgroup %9
%12 = OpTypeInt 32 0
%13 = OpConstant %12 2
%14 = OpTypeVector %12 3
%15 = OpTypePointer Input %14
%16 = OpTypeVoid
%17 = OpTypeFunction %16
%18 = OpTypePointer StorageBuffer %9
%14 = OpTypePointer Input %12
%15 = OpTypeVoid
%16 = OpTypeFunction %15
%17 = OpTypePointer StorageBuffer %9
%2 = OpVariable %10 StorageBuffer
%19 = OpConstant %12 0
%18 = OpConstant %12 0
%3 = OpVariable %10 StorageBuffer
%4 = OpVariable %15 Input
%20 = OpTypeBool
%4 = OpVariable %14 Input
%19 = OpTypeBool
%5 = OpVariable %11 Workgroup
%21 = OpConstant %12 264
%22 = OpConstant %12 1
%20 = OpConstant %12 264
%21 = OpConstant %12 1
4 changes: 2 additions & 2 deletions tests/compiletests/ui/arch/shared/reduction_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value],
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value,
#[spirv(workgroup)] shared: &mut [Value; WG_SIZE],
#[spirv(local_invocation_index)] inv_id: UVec3,
#[spirv(local_invocation_index)] inv_id: u32,
) {
unsafe {
let inv_id = inv_id.x as usize;
let inv_id = inv_id as usize;
shared[inv_id] = input[inv_id];
workgroup_memory_barrier_with_group_sync();

Expand Down
4 changes: 2 additions & 2 deletions tests/compiletests/ui/arch/shared/reduction_big_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value],
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value,
#[spirv(workgroup)] shared: &mut [Value; WG_SIZE],
#[spirv(local_invocation_index)] inv_id: UVec3,
#[spirv(local_invocation_index)] inv_id: u32,
) {
unsafe {
let inv_id = inv_id.x as usize;
let inv_id = inv_id as usize;
shared[inv_id] = input[inv_id];
workgroup_memory_barrier_with_group_sync();

Expand Down
4 changes: 2 additions & 2 deletions tests/compiletests/ui/arch/shared/reduction_u32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value],
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value,
#[spirv(workgroup)] shared: &mut [Value; WG_SIZE],
#[spirv(local_invocation_index)] inv_id: UVec3,
#[spirv(local_invocation_index)] inv_id: u32,
) {
unsafe {
let inv_id = inv_id.x as usize;
let inv_id = inv_id as usize;
shared[inv_id] = input[inv_id];
workgroup_memory_barrier_with_group_sync();

Expand Down
4 changes: 2 additions & 2 deletions tests/compiletests/ui/arch/shared/reduction_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ pub fn main(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value],
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value,
#[spirv(workgroup)] shared: &mut [Value; WG_SIZE],
#[spirv(local_invocation_index)] inv_id: UVec3,
#[spirv(local_invocation_index)] inv_id: u32,
) {
unsafe {
let inv_id = inv_id.x as usize;
let inv_id = inv_id as usize;
shared[inv_id] = input[inv_id];
workgroup_memory_barrier_with_group_sync();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub struct Zst;

#[spirv(compute(threads(32)))]
pub fn main(
#[spirv(local_invocation_index)] inv_id: UVec3,
#[spirv(local_invocation_id)] inv_id: UVec3,
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut UVec3,
) {
unsafe {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ fn disassembly(my_struct: MyStruct) -> bool {

#[spirv(compute(threads(32)))]
pub fn main(
#[spirv(local_invocation_index)] inv_id: UVec3,
#[spirv(local_invocation_index)] inv_id: u32,
#[spirv(local_invocation_id)] inv_id_3d: UVec3,
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut u32,
) {
unsafe {
let my_struct = MyStruct {
a: inv_id.x as f32,
b: inv_id,
c: Nested(5i32 - inv_id.x as i32),
a: inv_id as f32,
b: inv_id_3d,
c: Nested(5i32 - inv_id as i32),
d: Zst,
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ fn disassembly(my_struct: MyEnum, id: u32) -> MyEnum {

#[spirv(compute(threads(32)))]
pub fn main(
#[spirv(local_invocation_index)] inv_id: UVec3,
#[spirv(local_invocation_index)] inv_id: u32,
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut MyEnum,
) {
unsafe {
let my_enum = MyEnum::from(inv_id.x % 3);
let my_enum = MyEnum::from(inv_id % 3);
*output = disassembly(my_enum, 5);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ fn disassembly(my_struct: MyStruct, id: u32) -> MyStruct {

#[spirv(compute(threads(32)))]
pub fn main(
#[spirv(local_invocation_index)] inv_id: UVec3,
#[spirv(local_invocation_index)] inv_id: u32,
#[spirv(local_invocation_id)] inv_id_3d: UVec3,
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut MyStruct,
) {
unsafe {
let my_struct = MyStruct {
a: inv_id.x as f32,
b: inv_id,
c: Nested(5i32 - inv_id.x as i32),
a: inv_id as f32,
b: inv_id_3d,
c: Nested(5i32 - inv_id as i32),
d: Zst,
};

Expand Down
2 changes: 1 addition & 1 deletion tests/compiletests/ui/spirv-attr/all-builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub fn vertex(
#[spirv(frag_stencil_ref_ext)] frag_stencil_ref_ext: &mut u32,
#[spirv(instance_index)] instance_index: u32,
#[spirv(layer_per_view_nv)] layer_per_view_nv: u32,
#[spirv(local_invocation_index)] local_invocation_index: UVec3,
#[spirv(local_invocation_index)] local_invocation_index: u32,
#[spirv(mesh_view_count_nv)] mesh_view_count_nv: u32,
#[spirv(mesh_view_indices_nv)] mesh_view_indices_nv: u32,
#[spirv(point_size)] point_size: &mut u32,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// build-fail

use spirv_std::glam::UVec3;
use spirv_std::spirv;

#[spirv(compute(threads(1)))]
pub fn main(#[spirv(local_invocation_index)] index: UVec3) {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
error: `#[spirv(local_invocation_index)]` must be a `u32`, not `spirv_std::glam::UVec3`
--> $DIR/local-invocation-index-type.rs:7:53
|
LL | pub fn main(#[spirv(local_invocation_index)] index: UVec3) {}
| ^^^^^

error: aborting due to 1 previous error

Loading