Skip to content

Commit 73f30dc

Browse files
fix multiindex eq, mypy
1 parent 4196dd4 commit 73f30dc

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

bigframes/core/indexes/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ def item(self):
754754
# Docstring is in third_party/bigframes_vendored/pandas/core/indexes/base.py
755755
return self.to_series().peek(2).item()
756756

757-
def __eq__(self, other) -> Index:
757+
def __eq__(self, other) -> Index: # type: ignore
758758
return self._apply_binop(other, ops.eq_op)
759759

760760
def _apply_binop(self, other, op: ops.BinaryOp) -> Index:
@@ -802,7 +802,7 @@ def _apply_binop(self, other, op: ops.BinaryOp) -> Index:
802802
labels=[None] * self.nlevels,
803803
drop=True,
804804
)
805-
return Index(block)
805+
return Index(block.set_index(block.value_columns))
806806
else:
807807
return NotImplemented
808808

bigframes/core/indexes/multi.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import bigframes_vendored.pandas.core.indexes.multi as vendored_pandas_multindex
2020
import pandas
2121

22+
from bigframes.core import blocks
23+
from bigframes.core import expression as ex
2224
from bigframes.core.indexes.base import Index
2325

2426

@@ -47,13 +49,25 @@ def from_arrays(
4749
# Index.__new__ should detect multiple levels and properly create a multiindex
4850
return cast(MultiIndex, Index(pd_index))
4951

50-
def __eg__(self, other) -> Index:
52+
def __eq__(self, other) -> Index: # type: ignore
5153
import bigframes.operations as ops
5254
import bigframes.operations.aggregations as agg_ops
5355

54-
eq_result = self._apply_binop(other, ops.eq_op)
56+
eq_result = self._apply_binop(other, ops.eq_op)._block.expr
57+
58+
as_array = ops.ToArrayOp().as_expr(
59+
*(
60+
ops.fillna_op.as_expr(col, ex.const(False))
61+
for col in eq_result.column_ids
62+
)
63+
)
64+
reduced = ops.ArrayReduceOp(agg_ops.all_op).as_expr(as_array)
65+
result_expr, result_ids = eq_result.compute_values([reduced])
5566
return Index(
56-
eq_result._block.aggregate_all_and_stack(
57-
agg_ops.all_op, axis=1, dropna=False
67+
blocks.Block(
68+
result_expr.select_columns(result_ids),
69+
index_columns=result_ids,
70+
column_labels=(),
71+
index_labels=[None],
5872
)
5973
)

tests/system/small/test_multiindex.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1481,4 +1481,6 @@ def test_multiindex_eq_const(scalars_df_index, scalars_pandas_df_index):
14811481
bf_result = scalars_df_index.set_index(col_name).index == (2, False)
14821482
pd_result = scalars_pandas_df_index.set_index(col_name).index == (2, False)
14831483

1484-
assert bf_result == pd_result
1484+
pandas.testing.assert_index_equal(
1485+
pandas.Index(pd_result, dtype="boolean"), bf_result.to_pandas()
1486+
)

0 commit comments

Comments
 (0)