Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
const MODULE: Option<&'static str> = Some("numpy");

fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
unsafe { npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
}

fn is_type_of(ob: &Bound<'_, PyAny>) -> bool {
Expand Down Expand Up @@ -233,7 +233,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
let mut dims = dims.into_dimension();
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
py,
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type),
T::get_dtype(py).into_dtype_ptr(),
dims.ndim_cint(),
dims.as_dims_ptr(),
Expand All @@ -259,7 +259,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
let mut dims = dims.into_dimension();
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
py,
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type),
T::get_dtype(py).into_dtype_ptr(),
dims.ndim_cint(),
dims.as_dims_ptr(),
Expand Down
4 changes: 2 additions & 2 deletions src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use pyo3::{
};

use crate::npyffi::{
NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
self, NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
PyDataType_FLAGS, PyDataType_NAMES, PyDataType_SUBARRAY, NPY_ALIGNED_STRUCT,
NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, PY_ARRAY_API,
};
Expand Down Expand Up @@ -58,7 +58,7 @@ unsafe impl PyTypeInfo for PyArrayDescr {

#[inline]
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
unsafe { npyffi::get_type_object(py, NpyTypes::PyArrayDescr_Type) }
}
}

Expand Down
256 changes: 112 additions & 144 deletions src/npyffi/array.rs

Large diffs are not rendered by default.

34 changes: 17 additions & 17 deletions src/npyffi/flags.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::npy_uint32;
use super::{npy_uint32, npy_uint64};
use std::os::raw::c_int;

pub const NPY_ARRAY_C_CONTIGUOUS: c_int = 0x0001;
Expand All @@ -11,8 +11,8 @@ pub const NPY_ARRAY_ELEMENTSTRIDES: c_int = 0x0080;
pub const NPY_ARRAY_ALIGNED: c_int = 0x0100;
pub const NPY_ARRAY_NOTSWAPPED: c_int = 0x0200;
pub const NPY_ARRAY_WRITEABLE: c_int = 0x0400;
pub const NPY_ARRAY_UPDATEIFCOPY: c_int = 0x1000;
pub const NPY_ARRAY_WRITEBACKIFCOPY: c_int = 0x2000;
pub const NPY_ARRAY_ENSURENOCOPY: c_int = 0x4000;
pub const NPY_ARRAY_BEHAVED: c_int = NPY_ARRAY_ALIGNED | NPY_ARRAY_WRITEABLE;
pub const NPY_ARRAY_BEHAVED_NS: c_int = NPY_ARRAY_BEHAVED | NPY_ARRAY_NOTSWAPPED;
pub const NPY_ARRAY_CARRAY: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED;
Expand All @@ -22,13 +22,14 @@ pub const NPY_ARRAY_FARRAY_RO: c_int = NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_ALIGNE
pub const NPY_ARRAY_DEFAULT: c_int = NPY_ARRAY_CARRAY;
pub const NPY_ARRAY_IN_ARRAY: c_int = NPY_ARRAY_CARRAY_RO;
pub const NPY_ARRAY_OUT_ARRAY: c_int = NPY_ARRAY_CARRAY;
pub const NPY_ARRAY_INOUT_ARRAY: c_int = NPY_ARRAY_CARRAY | NPY_ARRAY_UPDATEIFCOPY;
pub const NPY_ARRAY_INOUT_ARRAY: c_int = NPY_ARRAY_CARRAY;
pub const NPY_ARRAY_INOUT_ARRAY2: c_int = NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
pub const NPY_ARRAY_IN_FARRAY: c_int = NPY_ARRAY_FARRAY_RO;
pub const NPY_ARRAY_OUT_FARRAY: c_int = NPY_ARRAY_FARRAY;
pub const NPY_ARRAY_INOUT_FARRAY: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_UPDATEIFCOPY;
pub const NPY_ARRAY_INOUT_FARRAY: c_int = NPY_ARRAY_FARRAY;
pub const NPY_ARRAY_INOUT_FARRAY2: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
pub const NPY_ARRAY_UPDATE_ALL: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS;
pub const NPY_ARRAY_UPDATE_ALL: c_int =
NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_ALIGNED;

pub const NPY_ITER_C_INDEX: npy_uint32 = 0x00000001;
pub const NPY_ITER_F_INDEX: npy_uint32 = 0x00000002;
Expand Down Expand Up @@ -63,19 +64,18 @@ pub const NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE: npy_uint32 = 0x40000000;
pub const NPY_ITER_GLOBAL_FLAGS: npy_uint32 = 0x0000ffff;
pub const NPY_ITER_PER_OP_FLAGS: npy_uint32 = 0xffff0000;

pub const NPY_ITEM_REFCOUNT: u64 = 0x01;
pub const NPY_ITEM_HASOBJECT: u64 = 0x01;
pub const NPY_LIST_PICKLE: u64 = 0x02;
pub const NPY_ITEM_IS_POINTER: u64 = 0x04;
pub const NPY_NEEDS_INIT: u64 = 0x08;
pub const NPY_NEEDS_PYAPI: u64 = 0x10;
pub const NPY_USE_GETITEM: u64 = 0x20;
pub const NPY_USE_SETITEM: u64 = 0x40;
#[allow(overflowing_literals)]
pub const NPY_ALIGNED_STRUCT: u64 = 0x80;
pub const NPY_FROM_FIELDS: u64 =
pub const NPY_ITEM_REFCOUNT: npy_uint64 = 0x01;
pub const NPY_ITEM_HASOBJECT: npy_uint64 = 0x01;
pub const NPY_LIST_PICKLE: npy_uint64 = 0x02;
pub const NPY_ITEM_IS_POINTER: npy_uint64 = 0x04;
pub const NPY_NEEDS_INIT: npy_uint64 = 0x08;
pub const NPY_NEEDS_PYAPI: npy_uint64 = 0x10;
pub const NPY_USE_GETITEM: npy_uint64 = 0x20;
pub const NPY_USE_SETITEM: npy_uint64 = 0x40;
pub const NPY_ALIGNED_STRUCT: npy_uint64 = 0x80;
pub const NPY_FROM_FIELDS: npy_uint64 =
NPY_NEEDS_INIT | NPY_LIST_PICKLE | NPY_ITEM_REFCOUNT | NPY_NEEDS_PYAPI;
pub const NPY_OBJECT_DTYPE_FLAGS: u64 = NPY_LIST_PICKLE
pub const NPY_OBJECT_DTYPE_FLAGS: npy_uint64 = NPY_LIST_PICKLE
| NPY_USE_GETITEM
| NPY_ITEM_IS_POINTER
| NPY_ITEM_REFCOUNT
Expand Down
114 changes: 74 additions & 40 deletions src/npyffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,44 +11,40 @@

use std::mem::forget;
use std::os::raw::{c_uint, c_void};
use std::ptr::NonNull;

use pyo3::{
ffi::PyTypeObject,
sync::PyOnceLock,
types::{PyAnyMethods, PyCapsule, PyCapsuleMethods, PyModule},
PyResult, Python,
};

pub const API_VERSION_2_0: c_uint = 0x00000012;

static API_VERSION: PyOnceLock<c_uint> = PyOnceLock::new();

fn get_numpy_api<'py>(
py: Python<'py>,
module: &str,
capsule: &str,
) -> PyResult<*const *const c_void> {
) -> PyResult<NonNull<*const c_void>> {
let module = PyModule::import(py, module)?;
let capsule = module.getattr(capsule)?.cast_into::<PyCapsule>()?;

let api = capsule
.pointer_checked(None)?
.cast::<*const c_void>()
.as_ptr()
.cast_const();
let api = capsule.pointer_checked(None)?;

// Intentionally leak a reference to the capsule
// so we can safely cache a pointer into its interior.
forget(capsule);

Ok(api)
Ok(api.cast())
}

/// Returns whether the runtime `numpy` version is 2.0 or greater.
pub fn is_numpy_2<'py>(py: Python<'py>) -> bool {
let api_version = *API_VERSION.get_or_init(py, || unsafe {
PY_ARRAY_API.PyArray_GetNDArrayCFeatureVersion(py)
});
api_version >= API_VERSION_2_0
api_version >= NPY_2_0_API_VERSION
}

// Implements wrappers for NumPy's Array and UFunc API
Expand All @@ -57,52 +53,90 @@ macro_rules! impl_api {
[$offset: expr; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
#[allow(non_snake_case)]
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg : $t), *) $(-> $ret)*;
(*fptr)($($arg), *)
let f: extern "C" fn ($($arg : $t), *) $(-> $ret)* = self.get(py, $offset).cast().read();
f($($arg), *)
}
};
}

// API with version constraints, checked at runtime
[$offset: expr; NumPy1; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
#[allow(non_snake_case)]
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
assert!(
!is_numpy_2(py),
"{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}",
stringify!($fname),
API_VERSION_2_0,
*API_VERSION.get(py).expect("API_VERSION is initialized"),
);
let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg: $t), *) $(-> $ret)*;
(*fptr)($($arg), *)
}
// Define type objects associated with the NumPy API
macro_rules! impl_array_type {
($(($api:ident [ $offset:expr ] , $tname:ident)),* $(,)?) => {
/// All type objects exported by the NumPy API.
#[allow(non_camel_case_types)]
pub enum NpyTypes { $($tname),* }

};
[$offset: expr; NumPy2; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
#[allow(non_snake_case)]
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
assert!(
is_numpy_2(py),
"{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}",
stringify!($fname),
API_VERSION_2_0,
*API_VERSION.get(py).expect("API_VERSION is initialized"),
);
let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg: $t), *) $(-> $ret)*;
(*fptr)($($arg), *)
/// Get a pointer of the type object associated with `ty`.
pub unsafe fn get_type_object<'py>(py: Python<'py>, ty: NpyTypes) -> *mut PyTypeObject {
match ty {
$( NpyTypes::$tname => $api.get(py, $offset).read() as _ ),*
}
}
}
}

};
impl_array_type! {
// Multiarray API
// Slot 1 was never meaningfully used by NumPy
(PY_ARRAY_API[2], PyArray_Type),
(PY_ARRAY_API[3], PyArrayDescr_Type),
// Unused slot 4, was `PyArrayFlags_Type`
(PY_ARRAY_API[5], PyArrayIter_Type),
(PY_ARRAY_API[6], PyArrayMultiIter_Type),
// (PY_ARRAY_API[7], NPY_NUMUSERTYPES) -> c_int,
(PY_ARRAY_API[8], PyBoolArrType_Type),
// (PY_ARRAY_API[9], _PyArrayScalar_BoolValues) -> *mut PyBoolScalarObject,
(PY_ARRAY_API[10], PyGenericArrType_Type),
(PY_ARRAY_API[11], PyNumberArrType_Type),
(PY_ARRAY_API[12], PyIntegerArrType_Type),
(PY_ARRAY_API[13], PySignedIntegerArrType_Type),
(PY_ARRAY_API[14], PyUnsignedIntegerArrType_Type),
(PY_ARRAY_API[15], PyInexactArrType_Type),
(PY_ARRAY_API[16], PyFloatingArrType_Type),
(PY_ARRAY_API[17], PyComplexFloatingArrType_Type),
(PY_ARRAY_API[18], PyFlexibleArrType_Type),
(PY_ARRAY_API[19], PyCharacterArrType_Type),
(PY_ARRAY_API[20], PyByteArrType_Type),
(PY_ARRAY_API[21], PyShortArrType_Type),
(PY_ARRAY_API[22], PyIntArrType_Type),
(PY_ARRAY_API[23], PyLongArrType_Type),
(PY_ARRAY_API[24], PyLongLongArrType_Type),
(PY_ARRAY_API[25], PyUByteArrType_Type),
(PY_ARRAY_API[26], PyUShortArrType_Type),
(PY_ARRAY_API[27], PyUIntArrType_Type),
(PY_ARRAY_API[28], PyULongArrType_Type),
(PY_ARRAY_API[29], PyULongLongArrType_Type),
(PY_ARRAY_API[30], PyFloatArrType_Type),
(PY_ARRAY_API[31], PyDoubleArrType_Type),
(PY_ARRAY_API[32], PyLongDoubleArrType_Type),
(PY_ARRAY_API[33], PyCFloatArrType_Type),
(PY_ARRAY_API[34], PyCDoubleArrType_Type),
(PY_ARRAY_API[35], PyCLongDoubleArrType_Type),
(PY_ARRAY_API[36], PyObjectArrType_Type),
(PY_ARRAY_API[37], PyStringArrType_Type),
(PY_ARRAY_API[38], PyUnicodeArrType_Type),
(PY_ARRAY_API[39], PyVoidArrType_Type),
(PY_ARRAY_API[214], PyTimeIntegerArrType_Type),
(PY_ARRAY_API[215], PyDatetimeArrType_Type),
(PY_ARRAY_API[216], PyTimedeltaArrType_Type),
(PY_ARRAY_API[217], PyHalfArrType_Type),
(PY_ARRAY_API[218], NpyIter_Type),
// UFunc API
(PY_UFUNC_API[0], PyUFunc_Type),
}

pub mod array;
pub mod flags;
mod npy_common;
mod numpyconfig;
pub mod objects;
pub mod types;
pub mod ufunc;

pub use self::array::*;
pub use self::flags::*;
pub use self::npy_common::*;
pub use self::numpyconfig::*;
pub use self::objects::*;
pub use self::types::*;
pub use self::ufunc::*;
8 changes: 8 additions & 0 deletions src/npyffi/npy_common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use std::ffi::c_int;

/// Unknown CPU endianness.
pub const NPY_CPU_UNKNOWN_ENDIAN: c_int = 0;
/// CPU is little-endian.
pub const NPY_CPU_LITTLE: c_int = 1;
/// CPU is big-endian.
pub const NPY_CPU_BIG: c_int = 2;
18 changes: 18 additions & 0 deletions src/npyffi/numpyconfig.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// This file matches the numpyconfig.h header.

use std::ffi::c_uint;

/// The current target ABI version
const NPY_ABI_VERSION: c_uint = 0x02000000;

/// The current target API version (v1.15)
const NPY_API_VERSION: c_uint = 0x0000000c;

pub(super) const NPY_2_0_API_VERSION: c_uint = 0x00000012;

/// The current version of the `ndarray` object (ABI version).
pub const NPY_VERSION: c_uint = NPY_ABI_VERSION;
/// The current version of C API.
pub const NPY_FEATURE_VERSION: c_uint = NPY_API_VERSION;
/// The string representation of current version C API.
pub const NPY_FEATURE_VERSION_STRING: &str = "1.15";
Loading
Loading