Skip to content

Commit 523c80e

Browse files
committed
Merge remote-tracking branch 'origin/main' into tswast-geo
2 parents 9907c2b + 190f32e commit 523c80e

File tree

5 files changed

+21
-47
lines changed

5 files changed

+21
-47
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ def _aggregate(
646646
def compile_explode(self, node: nodes.ExplodeNode):
647647
assert node.offsets_col is None
648648
df = self.compile_node(node.child)
649-
cols = [pl.col(col.id.sql) for col in node.column_ids]
649+
cols = [col.id.sql for col in node.column_ids]
650650
return df.explode(cols)
651651

652652
@compile_node.register

bigframes/core/indexes/base.py

Lines changed: 13 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,12 @@
2727
import pandas
2828

2929
from bigframes import dtypes
30-
from bigframes.core.array_value import ArrayValue
3130
import bigframes.core.block_transforms as block_ops
3231
import bigframes.core.blocks as blocks
3332
import bigframes.core.expression as ex
34-
import bigframes.core.identifiers as ids
35-
import bigframes.core.nodes as nodes
3633
import bigframes.core.ordering as order
3734
import bigframes.core.utils as utils
3835
import bigframes.core.validations as validations
39-
import bigframes.core.window_spec as window_spec
4036
import bigframes.dtypes
4137
import bigframes.formatting_helpers as formatter
4238
import bigframes.operations as ops
@@ -272,37 +268,20 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
272268
# Get the index column from the block
273269
index_column = self._block.index_columns[0]
274270

275-
# Apply row numbering to the original data
276-
row_number_column_id = ids.ColumnId.unique()
277-
window_node = nodes.WindowOpNode(
278-
child=self._block._expr.node,
279-
expression=ex.NullaryAggregation(agg_ops.RowNumberOp()),
280-
window_spec=window_spec.unbound(),
281-
output_name=row_number_column_id,
282-
never_skip_nulls=True,
283-
)
284-
285-
windowed_array = ArrayValue(window_node)
286-
windowed_block = blocks.Block(
287-
windowed_array,
288-
index_columns=self._block.index_columns,
289-
column_labels=self._block.column_labels.insert(
290-
len(self._block.column_labels), None
291-
),
292-
index_labels=self._block._index_labels,
271+
# Use promote_offsets to get row numbers (similar to argmax/argmin implementation)
272+
block_with_offsets, offsets_id = self._block.promote_offsets(
273+
"temp_get_loc_offsets_"
293274
)
294275

295276
# Create expression to find matching positions
296277
match_expr = ops.eq_op.as_expr(ex.deref(index_column), ex.const(key))
297-
windowed_block, match_col_id = windowed_block.project_expr(match_expr)
278+
block_with_offsets, match_col_id = block_with_offsets.project_expr(match_expr)
298279

299280
# Filter to only rows where the key matches
300-
filtered_block = windowed_block.filter_by_id(match_col_id)
281+
filtered_block = block_with_offsets.filter_by_id(match_col_id)
301282

302-
# Check if key exists at all by counting on the filtered block
303-
count_agg = ex.UnaryAggregation(
304-
agg_ops.count_op, ex.deref(row_number_column_id.name)
305-
)
283+
# Check if key exists at all by counting
284+
count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(offsets_id))
306285
count_result = filtered_block._expr.aggregate([(count_agg, "count")])
307286
count_scalar = self._block.session._executor.execute(
308287
count_result
@@ -313,9 +292,7 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
313292

314293
# If only one match, return integer position
315294
if count_scalar == 1:
316-
min_agg = ex.UnaryAggregation(
317-
agg_ops.min_op, ex.deref(row_number_column_id.name)
318-
)
295+
min_agg = ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id))
319296
position_result = filtered_block._expr.aggregate([(min_agg, "position")])
320297
position_scalar = self._block.session._executor.execute(
321298
position_result
@@ -325,32 +302,24 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
325302
# Handle multiple matches based on index monotonicity
326303
is_monotonic = self.is_monotonic_increasing or self.is_monotonic_decreasing
327304
if is_monotonic:
328-
return self._get_monotonic_slice(filtered_block, row_number_column_id)
305+
return self._get_monotonic_slice(filtered_block, offsets_id)
329306
else:
330307
# Return boolean mask for non-monotonic duplicates
331-
mask_block = windowed_block.select_columns([match_col_id])
332-
# Reset the index to use positional integers instead of original index values
308+
mask_block = block_with_offsets.select_columns([match_col_id])
333309
mask_block = mask_block.reset_index(drop=True)
334-
# Ensure correct dtype and name to match pandas behavior
335310
result_series = bigframes.series.Series(mask_block)
336311
return result_series.astype("boolean")
337312

338-
def _get_monotonic_slice(
339-
self, filtered_block, row_number_column_id: "ids.ColumnId"
340-
) -> slice:
313+
def _get_monotonic_slice(self, filtered_block, offsets_id: str) -> slice:
341314
"""Helper method to get a slice for monotonic duplicates with an optimized query."""
342315
# Combine min and max aggregations into a single query for efficiency
343316
min_max_aggs = [
344317
(
345-
ex.UnaryAggregation(
346-
agg_ops.min_op, ex.deref(row_number_column_id.name)
347-
),
318+
ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)),
348319
"min_pos",
349320
),
350321
(
351-
ex.UnaryAggregation(
352-
agg_ops.max_op, ex.deref(row_number_column_id.name)
353-
),
322+
ex.UnaryAggregation(agg_ops.max_op, ex.deref(offsets_id)),
354323
"max_pos",
355324
),
356325
]

tests/system/small/ml/test_llm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def __eq__(self, other):
251251
return self.equals(other)
252252

253253

254+
@pytest.mark.skip("b/436340035 test failed")
254255
@pytest.mark.parametrize(
255256
(
256257
"model_class",
@@ -393,6 +394,7 @@ def test_text_generator_retry_success(
393394
)
394395

395396

397+
@pytest.mark.skip("b/436340035 test failed")
396398
@pytest.mark.parametrize(
397399
(
398400
"model_class",
@@ -509,6 +511,7 @@ def test_text_generator_retry_no_progress(session, model_class, options, bq_conn
509511
)
510512

511513

514+
@pytest.mark.skip("b/436340035 test failed")
512515
def test_text_embedding_generator_retry_success(session, bq_connection):
513516
# Requests.
514517
df0 = EqCmpAllDataFrame(
@@ -790,13 +793,14 @@ def test_gemini_preview_model_warnings(model_name):
790793
llm.GeminiTextGenerator(model_name=model_name)
791794

792795

796+
# b/436340035 temporarily disable the test to unblock presubmit
793797
@pytest.mark.parametrize(
794798
"model_class",
795799
[
796800
llm.TextEmbeddingGenerator,
797801
llm.MultimodalEmbeddingGenerator,
798802
llm.GeminiTextGenerator,
799-
llm.Claude3TextGenerator,
803+
# llm.Claude3TextGenerator,
800804
],
801805
)
802806
def test_text_embedding_generator_no_default_model_warning(model_class):

tests/unit/test_dataframe_polars.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,7 @@ def test_df_fillna(scalars_dfs, col, fill_value):
11981198
pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
11991199

12001200

1201+
@pytest.mark.skip("b/436316698 unit test failed for python 3.12")
12011202
def test_df_ffill(scalars_dfs):
12021203
scalars_df, scalars_pandas_df = scalars_dfs
12031204
bf_result = scalars_df[["int64_col", "float64_col"]].ffill(limit=1).to_pandas()

third_party/bigframes_vendored/pandas/core/indexes/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ def get_loc(
767767
1 True
768768
2 False
769769
3 True
770-
Name: nan, dtype: boolean
770+
dtype: boolean
771771
772772
Args:
773773
key: Label to get the location for.

0 commit comments

Comments
 (0)