Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 22 additions & 30 deletions crates/cust/src/memory/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ pub enum ArrayFormat {
I16,
/// Signed 32-bit integer
I32,
/// Half-precision floating point number
/// Half-precision floating point number (f16)
F16,
/// Single-precision floating point number (f32)
F32,
/// Single-precision floating point number
F64,
}

impl ArrayFormat {
Expand All @@ -52,9 +52,8 @@ impl ArrayFormat {

match self {
U8 | I8 => 1,
U16 | I16 => 2,
U16 | I16 | F16 => 2,
U32 | I32 | F32 => 4,
F64 => 8,
}
}
}
Expand All @@ -74,7 +73,6 @@ impl private::Sealed for i8 {}
impl private::Sealed for i16 {}
impl private::Sealed for i32 {}
impl private::Sealed for f32 {}
impl private::Sealed for f64 {}

impl ArrayPrimitive for u8 {
fn array_format() -> ArrayFormat {
Expand Down Expand Up @@ -118,12 +116,6 @@ impl ArrayPrimitive for f32 {
}
}

impl ArrayPrimitive for f64 {
fn array_format() -> ArrayFormat {
ArrayFormat::F64
}
}

impl ArrayFormat {
/// Creates ArrayFormat from the CUDA Driver API enum
pub fn from_raw(raw: CUarray_format) -> Self {
Expand All @@ -134,8 +126,8 @@ impl ArrayFormat {
CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT8 => ArrayFormat::I8,
CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT16 => ArrayFormat::I16,
CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT32 => ArrayFormat::I32,
CUarray_format_enum::CU_AD_FORMAT_HALF => ArrayFormat::F32,
CUarray_format_enum::CU_AD_FORMAT_FLOAT => ArrayFormat::F64,
CUarray_format_enum::CU_AD_FORMAT_HALF => ArrayFormat::F16,
CUarray_format_enum::CU_AD_FORMAT_FLOAT => ArrayFormat::F32,
// there are literally no docs on what nv12 is???
// it seems to be something with multiplanar arrays, needs some investigation
CUarray_format_enum::CU_AD_FORMAT_NV12 => panic!("nv12 is not supported yet"),
Expand All @@ -152,8 +144,8 @@ impl ArrayFormat {
ArrayFormat::I8 => CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT8,
ArrayFormat::I16 => CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT16,
ArrayFormat::I32 => CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT32,
ArrayFormat::F32 => CUarray_format_enum::CU_AD_FORMAT_HALF,
ArrayFormat::F64 => CUarray_format_enum::CU_AD_FORMAT_FLOAT,
ArrayFormat::F16 => CUarray_format_enum::CU_AD_FORMAT_HALF,
ArrayFormat::F32 => CUarray_format_enum::CU_AD_FORMAT_FLOAT,
}
}
}
Expand Down Expand Up @@ -921,11 +913,11 @@ mod test {
fn descriptor_round_trip() {
let _context = crate::quick_init().unwrap();

let obj = ArrayObject::new([1, 2, 3], ArrayFormat::F64, 2).unwrap();
let obj = ArrayObject::new([1, 2, 3], ArrayFormat::F32, 2).unwrap();

let descriptor = obj.descriptor().unwrap();
assert_eq!([1, 2, 3], descriptor.dims());
assert_eq!(ArrayFormat::F64, descriptor.format());
assert_eq!(ArrayFormat::F32, descriptor.format());
assert_eq!(2, descriptor.num_channels());
assert_eq!(ArrayObjectFlags::default(), descriptor.flags());
}
Expand All @@ -934,7 +926,7 @@ mod test {
fn allow_1d_arrays() {
let _context = crate::quick_init().unwrap();

let obj = ArrayObject::new([10, 0, 0], ArrayFormat::F64, 1).unwrap();
let obj = ArrayObject::new([10, 0, 0], ArrayFormat::F32, 1).unwrap();

let descriptor = obj.descriptor().unwrap();
assert_eq!([10, 0, 0], descriptor.dims());
Expand All @@ -944,7 +936,7 @@ mod test {
fn allow_2d_arrays() {
let _context = crate::quick_init().unwrap();

let obj = ArrayObject::new([10, 20, 0], ArrayFormat::F64, 1).unwrap();
let obj = ArrayObject::new([10, 20, 0], ArrayFormat::F32, 1).unwrap();

let descriptor = obj.descriptor().unwrap();
assert_eq!([10, 20, 0], descriptor.dims());
Expand All @@ -954,7 +946,7 @@ mod test {
fn allow_1d_layered_arrays() {
let _context = crate::quick_init().unwrap();

let obj = ArrayObject::new_layered([10, 0], 20, ArrayFormat::F64, 1).unwrap();
let obj = ArrayObject::new_layered([10, 0], 20, ArrayFormat::F32, 1).unwrap();

let descriptor = obj.descriptor().unwrap();
assert_eq!([10, 0, 20], descriptor.dims());
Expand All @@ -965,7 +957,7 @@ mod test {
fn allow_cubemaps() {
let _context = crate::quick_init().unwrap();

let obj = ArrayObject::new_cubemap(4, ArrayFormat::F64, 1).unwrap();
let obj = ArrayObject::new_cubemap(4, ArrayFormat::F32, 1).unwrap();

let descriptor = obj.descriptor().unwrap();
assert_eq!([4, 4, 6], descriptor.dims());
Expand All @@ -976,7 +968,7 @@ mod test {
fn allow_layered_cubemaps() {
let _context = crate::quick_init().unwrap();

let obj = ArrayObject::new_layered_cubemap(4, 4, ArrayFormat::F64, 1).unwrap();
let obj = ArrayObject::new_layered_cubemap(4, 4, ArrayFormat::F32, 1).unwrap();

let descriptor = obj.descriptor().unwrap();
assert_eq!([4, 4, 24], descriptor.dims());
Expand All @@ -991,23 +983,23 @@ mod test {
fn fail_on_zero_width_1d_array() {
let _context = crate::quick_init().unwrap();

let _ = ArrayObject::new_1d(0, ArrayFormat::F64, 1).unwrap();
let _ = ArrayObject::new_1d(0, ArrayFormat::F32, 1).unwrap();
}

#[test]
#[should_panic]
fn fail_on_zero_size_widths() {
let _context = crate::quick_init().unwrap();

let _ = ArrayObject::new([0, 10, 20], ArrayFormat::F64, 1).unwrap();
let _ = ArrayObject::new([0, 10, 20], ArrayFormat::F32, 1).unwrap();
}

#[test]
#[should_panic]
fn fail_cubemaps_with_unmatching_width_height() {
let _context = crate::quick_init().unwrap();

let mut descriptor = ArrayDescriptor::from_dims_format([2, 3, 6], ArrayFormat::F64);
let mut descriptor = ArrayDescriptor::from_dims_format([2, 3, 6], ArrayFormat::F32);
descriptor.set_flags(ArrayObjectFlags::CUBEMAP);

let _ = ArrayObject::from_descriptor(&descriptor).unwrap();
Expand All @@ -1018,7 +1010,7 @@ mod test {
fn fail_cubemaps_with_non_six_depth() {
let _context = crate::quick_init().unwrap();

let mut descriptor = ArrayDescriptor::from_dims_format([4, 4, 5], ArrayFormat::F64);
let mut descriptor = ArrayDescriptor::from_dims_format([4, 4, 5], ArrayFormat::F32);
descriptor.set_flags(ArrayObjectFlags::CUBEMAP);

let _ = ArrayObject::from_descriptor(&descriptor).unwrap();
Expand All @@ -1029,7 +1021,7 @@ mod test {
fn fail_cubemaps_with_non_six_multiple_depth() {
let _context = crate::quick_init().unwrap();

let mut descriptor = ArrayDescriptor::from_dims_format([4, 4, 10], ArrayFormat::F64);
let mut descriptor = ArrayDescriptor::from_dims_format([4, 4, 10], ArrayFormat::F32);
descriptor.set_flags(ArrayObjectFlags::LAYERED | ArrayObjectFlags::CUBEMAP);

let _ = ArrayObject::from_descriptor(&descriptor).unwrap();
Expand All @@ -1040,14 +1032,14 @@ mod test {
fn fail_with_depth_without_height() {
let _context = crate::quick_init().unwrap();

let _ = ArrayObject::new([10, 0, 20], ArrayFormat::F64, 1).unwrap();
let _ = ArrayObject::new([10, 0, 20], ArrayFormat::F32, 1).unwrap();
}

#[test]
#[should_panic]
fn fails_on_invalid_num_channels() {
let _context = crate::quick_init().unwrap();

let _ = ArrayObject::new([1, 2, 3], ArrayFormat::F64, 3).unwrap();
let _ = ArrayObject::new([1, 2, 3], ArrayFormat::F32, 3).unwrap();
}
}
5 changes: 0 additions & 5 deletions crates/cust/src/texture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,6 @@ impl ResourceViewFormat {
format_impl!(num_channels, I16, I16x1, I16x2, I16x4);
format_impl!(num_channels, I32, I32x1, I32x2, I32x4);
format_impl!(num_channels, F32, F32x1, F32x2, F32x4);
assert_ne!(
format,
ArrayFormat::F64,
"CUDA Does not have 64 bit float textures, you can instead use int textures with 2 channels then cast the ints to a double in the kernel"
);
unreachable!()
}
}
Expand Down