Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

* Added support for buffer protocol objects as advanced index keys in `dpnp.ndarray` [#2889](https://github.com/IntelPython/dpnp/pull/2889)

### Changed

### Deprecated
Expand Down
33 changes: 19 additions & 14 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,31 @@ def _unwrap_index_element(x):
"""
Unwrap a single index element for the tensor indexing layer.

Converts dpnp arrays to usm_ndarray and array-like objects (range, list)
to numpy arrays with intp dtype for NumPy-compatible advanced indexing.
Converts dpnp arrays to usm_ndarray and array-like objects (range, list,
buffer protocol objects) to numpy arrays for NumPy-compatible advanced
indexing. Scalars and slices pass through to the tensor layer.

"""

if isinstance(x, dpt.usm_ndarray):
if (
x is None
or x is Ellipsis
or isinstance(x, (dpt.usm_ndarray, slice, numpy.ndarray))
):
return x
if isinstance(x, dpnp_array):
return x.get_array()
if isinstance(x, range):
return numpy.asarray(x, dtype=numpy.intp)
if isinstance(x, list):
# keep boolean lists as boolean
arr = numpy.asarray(x)
# cast empty lists (float64 in NumPy) to intp
# for correct tensor indexing
if arr.size == 0:
arr = arr.astype(numpy.intp)
return arr
return x
# scalars (int, bool, numpy scalars) pass through to the tensor layer
if isinstance(x, (int, numpy.generic)):
return x

# convert array-like objects (range, list, buffer protocol) to numpy
arr = numpy.asarray(x)
# cast empty arrays (float64 in NumPy) to intp
# for correct tensor indexing
if arr.size == 0 and arr.dtype.kind == "f":
arr = arr.astype(numpy.intp)
return arr


def _get_unwrapped_index_key(key):
Expand Down
60 changes: 60 additions & 0 deletions dpnp/tests/test_indexing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import array
import functools

import dpctl
Expand Down Expand Up @@ -406,6 +407,65 @@ def test_array_like_single_index(self, idx):
dp_a = dpnp.arange(24).reshape(2, 3, 4)
assert_array_equal(dp_a[idx], np_a[idx])

def test_buffer_protocol_getitem(self):
inds = array.array("l")
inds.frombytes(numpy.arange(3).tobytes())
np_a = numpy.arange(12).reshape(3, 4)
dp_a = dpnp.arange(12).reshape(3, 4)
assert_array_equal(dp_a[inds], np_a[inds])

def test_buffer_protocol_paired_index(self):
inds = array.array("l")
inds.frombytes(numpy.arange(3).tobytes())
np_a = numpy.arange(12).reshape(3, 4)
dp_a = dpnp.arange(12).reshape(3, 4)
assert_array_equal(dp_a[inds, inds], np_a[inds, inds])

def test_buffer_protocol_setitem(self):
inds = array.array("l")
inds.frombytes(numpy.arange(3).tobytes())
np_a = numpy.arange(12).reshape(3, 4)
dp_a = dpnp.arange(12).reshape(3, 4)
np_a[inds, inds] = 0
dp_a[inds, inds] = 0
assert_array_equal(dp_a, np_a)

def test_memoryview_getitem(self):
inds = memoryview(array.array("l", [0, 1, 2]))
np_a = numpy.arange(12).reshape(3, 4)
dp_a = dpnp.arange(12).reshape(3, 4)
assert_array_equal(dp_a[inds], np_a[inds])

def test_bytearray_getitem(self):
inds = bytearray(b"\x00\x01\x02")
np_a = numpy.arange(10)
dp_a = dpnp.arange(10)
assert_array_equal(dp_a[inds], np_a[inds])

@pytest.mark.parametrize(
"idx",
[
1.0,
1 + 0j,
numpy.float64(1.0),
numpy.complex128(1.0),
"a",
[0.5, 1.5],
],
ids=[
"float",
"complex",
"np.float64",
"np.complex128",
"str",
"float_list",
],
)
def test_invalid_index(self, idx):
dp_a = dpnp.arange(12).reshape(3, 4)
with pytest.raises((IndexError, TypeError)):
dp_a[idx]


class TestIx:
@pytest.mark.parametrize(
Expand Down
Loading