Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 43 additions & 16 deletions arkouda/pandas/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,12 @@

from arkouda.numpy.dtypes import bool_ as akbool
from arkouda.numpy.dtypes import bool_scalars
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.numpy.manipulation_functions import flip as ak_flip
from arkouda.numpy.pdarrayclass import RegistrationError, pdarray
from arkouda.numpy.pdarraysetops import argsort, in1d
from arkouda.numpy.sorting import coargsort
from arkouda.numpy.util import convert_if_categorical, generic_concat, get_callback
from arkouda.pandas.groupbyclass import GroupBy, unique
from arkouda.pandas.groupbyclass import GroupBy, groupable, unique


__all__ = [
Expand Down Expand Up @@ -1202,12 +1201,13 @@ def lookup(self, key):
Returns
-------
pdarray
A boolean array indicating which elements of `key` are present in the Index.
A boolean array of length ``len(self)``, indicating which entries of
the Index are present in `key`.

Raises
------
TypeError
If `key` is not a scalar or a pdarray.
If `key` cannot be converted to an arkouda array.

"""
from arkouda.numpy.pdarrayclass import pdarray
Expand Down Expand Up @@ -2139,15 +2139,22 @@ def concat(self, other):
idx = [generic_concat([ix1, ix2], ordered=True) for ix1, ix2 in zip(self.index, other.index)]
return MultiIndex(idx)

def lookup(self, key):
def lookup(self, key: Union[List, Tuple]) -> groupable:
"""
Perform element-wise lookup on the MultiIndex.

Parameters
----------
key : list or tuple
A sequence of values, one for each level of the MultiIndex. Values may be scalars
or pdarrays. If scalars, they are cast to the appropriate Arkouda array type.
A sequence of values, one for each level of the MultiIndex.

- If the elements are scalars (e.g., ``(1, "red")``), they are
treated as a single row key: the result is a boolean mask over
rows where all levels match the corresponding scalar.
- If the elements are arkouda arrays (e.g., list of pdarrays /
Strings), they must align one-to-one with the levels, and the
lookup is delegated to ``in1d(self.index, key)`` for multi-column
membership.

Returns
-------
Expand All @@ -2157,21 +2164,41 @@ def lookup(self, key):
Raises
------
TypeError
If `key` is not a list or tuple, or if its elements cannot be converted to pdarrays.
If `key` is not a list or tuple.
ValueError
If the length of `key` does not match the number of levels.

"""
from arkouda.numpy import cast as akcast
from arkouda.numpy.pdarrayclass import pdarray
from arkouda.numpy.pdarraycreation import array
from arkouda.numpy.strings import Strings

if not isinstance(key, (list, tuple)):
raise TypeError("MultiIndex.lookup expects a list or tuple of keys, one per level")

if len(key) != self.nlevels:
raise ValueError(
f"MultiIndex.lookup key length {len(key)} must match number of levels {self.nlevels}"
)

# Case 1: user passed per-level arkouda arrays.
# We assume they are already the correct types and lengths.
if isinstance(key[0], (pdarray, Strings)):
return in1d(self.index, key)

# Case 2: user passed scalars (e.g., (1, "red")).
# Convert each scalar to a length-1 arkouda array, preserving per-level dtypes.
scalar_key_arrays = []
for i, v in enumerate(key):
lvl = self.levels[i]

# Determine the dtype for this level
dt = lvl.dtype

if not isinstance(key, list) and not isinstance(key, tuple):
raise TypeError("MultiIndex lookup failure")
# if individual vals convert to pdarrays
if not isinstance(key[0], pdarray):
dt = self.levels[0].dtype if isinstance(self.levels[0], pdarray) else akint64
key = [akcast(array([x]), dt) for x in key]
a = array([v], dtype=dt) # make length-1 array
scalar_key_arrays.append(a)

return in1d(self.index, key)
return in1d(self.index, scalar_key_arrays)

def to_hdf(
self,
Expand Down
14 changes: 14 additions & 0 deletions tests/pandas/index_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,3 +846,17 @@ def test_round_trip_conversion_categorical(self, size):
i1 = ak.Index(ak.Categorical(strings1))
i2 = ak.Index(i1.to_pandas())
assert_index_equal(i1, i2)

def test_multiindex_lookup_tuple_mixed_dtypes(self):
# Level 0: int
lvl0 = ak.array([1, 1, 2, 3])
# Level 1: strings
lvl1 = ak.array(["red", "blue", "red", "red"])

midx = ak.MultiIndex([lvl0, lvl1], names=["num", "color"])

# Tuple key mixes int + str and should NOT trigger castStringsTo<int64> on "red"
mask = midx.lookup((1, "red"))

# Expect exactly the first row to match
assert mask.to_ndarray().tolist() == [True, False, False, False]