Skip to content

Commit ec63444

Browse files
authored
BUG: Allow is_bool_indexer to recognize NumpyExtensionArray (pandas-dev#63402)
1 parent 080b6b5 commit ec63444

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

pandas/core/common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
ABCExtensionArray,
4646
ABCIndex,
4747
ABCMultiIndex,
48+
ABCNumpyExtensionArray,
4849
ABCSeries,
4950
)
5051
from pandas.core.dtypes.inference import iterable_not_string
@@ -128,7 +129,8 @@ def is_bool_indexer(key: Any) -> bool:
128129
and convert to an ndarray.
129130
"""
130131
if isinstance(
131-
key, (ABCSeries, np.ndarray, ABCIndex, ABCExtensionArray)
132+
key,
133+
(ABCSeries, np.ndarray, ABCIndex, ABCExtensionArray, ABCNumpyExtensionArray),
132134
) and not isinstance(key, ABCMultiIndex):
133135
if key.dtype == np.object_:
134136
key_array = np.asarray(key)

pandas/tests/indexes/ranges/test_range.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,24 @@ def test_take_return_rangeindex():
710710
tm.assert_index_equal(result, expected, exact=True)
711711

712712

713+
def test__getitem__boolean_numpyextensionarray():
714+
ri = RangeIndex(1)
715+
result = ri[pd.arrays.NumpyExtensionArray(np.array([True]))]
716+
tm.assert_index_equal(ri, result)
717+
718+
719+
@pytest.mark.parametrize(
720+
"container",
721+
[np.array, pd.Series, lambda x: pd.arrays.NumpyExtensionArray(np.array(x))],
722+
ids=["numpy-array", "series", "numpy-extension-array"],
723+
)
724+
def test__getitem__boolean_arraylike(container):
725+
ri = RangeIndex(5)
726+
result = ri[container([True, True, False, False, True])]
727+
expected = Index([0, 1, 4], dtype="int64")
728+
tm.assert_index_equal(result, expected)
729+
730+
713731
@pytest.mark.parametrize(
714732
"rng, exp_rng",
715733
[

pandas/tests/test_common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,12 @@ def test_frozenlist(self):
236236
expected = df[[]]
237237
tm.assert_frame_equal(result, expected)
238238

239+
@pytest.mark.parametrize("scalar", [1, True])
240+
def test_numpyextensionarray(self, scalar):
241+
# GH 63391
242+
arr = pd.arrays.NumpyExtensionArray(np.array([scalar]))
243+
assert com.is_bool_indexer(arr) is isinstance(scalar, bool)
244+
239245

240246
@pytest.mark.parametrize("with_exception", [True, False])
241247
def test_temp_setattr(with_exception):

0 commit comments

Comments
 (0)