Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 121 additions & 7 deletions python/src/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@ bool is_none_slice(const nb::slice& in_slice) {

int get_slice_int(nb::object obj, int default_val) {
if (!obj.is_none()) {
if (!nb::isinstance<nb::int_>(obj)) {
// Try to cast to int - this handles Python int, numpy scalars, and other
// int-like types
try {
return nb::cast<int>(obj);
} catch (...) {
throw std::invalid_argument("Slice indices must be integers or None.");
}
return nb::cast<int>(nb::cast<nb::int_>(obj));
}
return default_val;
}
Expand Down Expand Up @@ -50,10 +53,58 @@ mx::array get_int_index(nb::object idx, int axis_size) {
return mx::array(idx_, mx::uint32);
}

// Convert boolean mask to integer indices
// Returns a packed array of indices where mask is True
// Uses a simple sort-based algorithm
std::pair<mx::array, int> boolean_mask_to_indices_and_count(
const mx::array& mask) {
// Flatten the boolean mask if it's multi-dimensional
auto flat_mask = (mask.ndim() > 1) ? flatten(mask) : mask;

auto size = flat_mask.size();

// Count total True values using sum
auto mask_int = astype(flat_mask, mx::int32);
auto num_true_arr = sum(mask_int);
num_true_arr.eval(); // Force evaluation to get the count
int num_true = num_true_arr.item<int>();

if (num_true == 0) {
// Return empty array
return {mx::array({}, mx::uint32), 0};
}

// Create array of all indices [0, 1, 2, ..., size-1]
auto all_indices = arange(0, size, 1, mx::int32);

// Use where to assign indices or large sentinel value, then sort
auto large_value = size; // Use size as sentinel for False positions
auto indexed =
where(flat_mask, all_indices, mx::array(large_value, mx::int32));
auto sorted_result = sort(indexed);

// Slice to get only valid indices (first num_true elements after sorting)
auto result = slice(sorted_result, {0}, {num_true}, {1});

return {astype(result, mx::uint32), num_true};
}

bool is_valid_index_type(const nb::object& obj) {
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
// Fast path: check common types first
if (nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
nb::isinstance<mx::array>(obj) || obj.is_none() ||
nb::ellipsis().is(obj) || nb::isinstance<nb::list>(obj);
nb::ellipsis().is(obj) || nb::isinstance<nb::list>(obj)) {
return true;
}

// Fallback: try to cast to int (handles numpy scalars and other int-like
// types)
try {
nb::cast<int>(obj);
return true;
} catch (...) {
return false;
}
}

mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
Expand Down Expand Up @@ -84,8 +135,33 @@ mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) {
"too many indices for array: array is 0-dimensional");
}

// Handle boolean indexing
if (indices.dtype() == mx::bool_) {
throw std::invalid_argument("boolean indices are not yet supported");
// Boolean indexing: convert boolean mask to integer indices
auto [int_indices, count] = boolean_mask_to_indices_and_count(indices);

if (count == 0) {
// Empty selection - return empty array with appropriate shape
mx::Shape out_shape = {0};
out_shape.insert(
out_shape.end(), src.shape().begin() + 1, src.shape().end());
return zeros(out_shape, src.dtype());
}

// Flatten source if mask is multi-dimensional or doesn't match first dim
if (indices.size() == src.size()) {
// Mask covers entire array - flatten both
auto flat_src = flatten(src);
return take(flat_src, int_indices, 0);
} else if (indices.size() == src.shape(0)) {
// Mask is for first dimension only
return take(src, int_indices, 0);
} else {
throw std::invalid_argument(
"boolean index did not match indexed array; size is " +
std::to_string(indices.size()) + " but corresponding dimension is " +
std::to_string(src.shape(0)));
}
}

// If only one input array is mentioned, we set axis=0 in take
Expand Down Expand Up @@ -442,7 +518,16 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
return mlx_get_item_array(
src, array_from_list(nb::cast<nb::list>(obj), {}));
}
throw std::invalid_argument("Cannot index mlx array using the given type.");

// Fallback: try to treat as integer index (handles numpy scalars and other
// int-like types)
try {
// Convert to Python int first to handle numpy scalars
nb::int_ idx = nb::int_(obj);
return mlx_get_item_int(src, idx);
} catch (...) {
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
}

std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
Expand Down Expand Up @@ -489,6 +574,27 @@ mlx_scatter_args_array(
"too many indices for array: array is 0-dimensional");
}

// Handle boolean indexing for scatter
if (indices.dtype() == mx::bool_) {
auto [int_indices, count] = boolean_mask_to_indices_and_count(indices);

if (count == 0) {
// No elements to update - return empty scatter args
return {{}, src, {}};
}

auto up = squeeze_leading_singletons(update);

// The update shape must broadcast with int_indices.shape + src.shape[1:]
auto up_shape = int_indices.shape();
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
up = broadcast_to(up, up_shape);
up_shape.insert(up_shape.begin() + int_indices.ndim(), 1);
up = reshape(up, up_shape);

return {{int_indices}, up, {0}};
}

auto up = squeeze_leading_singletons(update);

// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
Expand Down Expand Up @@ -757,7 +863,15 @@ mlx_compute_scatter_args(
src, array_from_list(nb::cast<nb::list>(obj), {}), vals);
}

throw std::invalid_argument("Cannot index mlx array using the given type.");
// Fallback: try to treat as integer index (handles numpy scalars and other
// int-like types)
try {
// Convert to Python int first to handle numpy scalars
nb::int_ idx = nb::int_(obj);
return mlx_scatter_args_int(src, idx, vals);
} catch (...) {
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
}

auto mlx_slice_update(
Expand Down
102 changes: 102 additions & 0 deletions python/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,108 @@ def index_fn(x, ind):
self.assertTrue(mx.array_equal(grad_x, expected))
self.assertTrue(mx.array_equal(grad_ind, mx.zeros(ind.shape)))

def test_numpy_scalar_indexing(self):
"""Test indexing with numpy scalar types"""
# Basic numpy scalar indexing
x = mx.array([1, 2, 3, 4, 5])
result = x[np.int64(1)]
self.assertEqual(result.item(), 2)

# Numpy scalar in slice start
result = x[np.int64(1) :]
self.assertTrue(np.array_equal(np.array(result), np.array([2, 3, 4, 5])))

# Numpy scalar in slice stop
result = x[: np.int64(3)]
self.assertTrue(np.array_equal(np.array(result), np.array([1, 2, 3])))

# Other numpy scalar types
result = x[np.int32(2)]
self.assertEqual(result.item(), 3)

# Negative numpy scalar indexing
result = x[np.int64(-1)]
self.assertEqual(result.item(), 5)

# Numpy scalar assignment
x_copy = mx.array([1, 2, 3, 4, 5])
x_copy[np.int64(2)] = 99
self.assertTrue(np.array_equal(np.array(x_copy), np.array([1, 2, 99, 4, 5])))

# Numpy scalar in both slice start and stop
x = mx.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
result = x[np.int64(2) : np.int64(8)]
expected = np.arange(1, 11)[np.int64(2) : np.int64(8)]
self.assertTrue(np.array_equal(np.array(result), expected))

# Test with 2D array
x_2d = mx.array([[1, 2], [3, 4], [5, 6]])
result = x_2d[np.int32(1)]
self.assertTrue(np.array_equal(np.array(result), np.array([3, 4])))

def test_boolean_mask_indexing(self):
"""Test boolean mask indexing"""
# Basic boolean indexing
x = mx.array([1, 2, 3, 4, 5])
mask = x > 2
result = x[mask]
self.assertTrue(np.array_equal(np.array(result), np.array([3, 4, 5])))

# Boolean indexing with all True
x = mx.array([1, 2, 3])
mask = mx.array([True, True, True])
result = x[mask]
self.assertTrue(np.array_equal(np.array(result), np.array([1, 2, 3])))

# Boolean indexing with all False
x = mx.array([1, 2, 3])
mask = mx.array([False, False, False])
result = x[mask]
self.assertEqual(result.size, 0)

# Boolean indexing with alternating pattern
x = mx.array([10, 20, 30, 40, 50])
mask = mx.array([True, False, True, False, True])
result = x[mask]
self.assertTrue(np.array_equal(np.array(result), np.array([10, 30, 50])))

# Boolean assignment
x = mx.array([1, 2, 3, 4, 5])
mask = x > 2
x[mask] = 99
self.assertTrue(np.array_equal(np.array(x), np.array([1, 2, 99, 99, 99])))

# Boolean indexing with 2D array (flatten behavior)
x = mx.array([[1, 2], [3, 4], [5, 6]])
mask = x > 3
result = x[mask]
expected = np.array([4, 5, 6])
self.assertTrue(np.array_equal(np.array(result), expected))

# Boolean indexing with negative values
x = mx.array([-3, -1, 0, 2, 4])
mask = x < 0
result = x[mask]
self.assertTrue(np.array_equal(np.array(result), np.array([-3, -1])))

# Complex boolean condition
x = mx.array([0, 1, 2, 3, 4, 5, 6])
mask = (x > 1) & (x < 5)
result = x[mask]
self.assertTrue(np.array_equal(np.array(result), np.array([2, 3, 4])))

# Empty result from boolean indexing
x = mx.array([1, 2, 3])
mask = x > 10
result = x[mask]
self.assertEqual(result.size, 0)

# Single element from boolean indexing
x = mx.array([1, 2, 3, 4, 5])
mask = x == 3
result = x[mask]
self.assertTrue(np.array_equal(np.array(result), np.array([3])))

def test_setitem(self):
a = mx.array(0)
a[None] = 1
Expand Down