Skip to content

Commit 757c2e2

Browse files
committed
fix: Resolve the validation issue for the `other` argument in the DataFrame `where` method
1 parent 209d0d4 commit 757c2e2

File tree

4 files changed

+78
-3
lines changed

4 files changed

+78
-3
lines changed

bigframes/dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2876,9 +2876,6 @@ def _apply_callable(self, condition):
28762876
return condition
28772877

28782878
def where(self, cond, other=None):
2879-
if isinstance(other, bigframes.series.Series):
2880-
raise ValueError("Seires is not a supported replacement type!")
2881-
28822879
if self.columns.nlevels > 1:
28832880
raise NotImplementedError(
28842881
"The dataframe.where() method does not support multi-column."
@@ -2889,6 +2886,9 @@ def where(self, cond, other=None):
28892886
cond = self._apply_callable(cond)
28902887
other = self._apply_callable(other)
28912888

2889+
if isinstance(other, bigframes.series.Series):
2890+
raise ValueError("Seires is not a supported replacement type!")
2891+
28922892
aligned_block, (_, _) = self._block.join(cond._block, how="left")
28932893
# No left join is needed when 'other' is None or constant.
28942894
if isinstance(other, bigframes.dataframe.DataFrame):

tests/system/large/functions/test_managed_function.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,37 @@ def func_for_other(x):
11031103
)
11041104

11051105

1106+
def test_managed_function_df_where_other_issue(session, dataset_id, scalars_df_index):
    """Check that DataFrame.where rejects a callable ``other`` that yields a Series.

    A managed function taking a whole row (``bigframes.series.Series`` input)
    is passed as ``other``. Per the inline note below, executing that callable
    returns a Series, so ``where`` is expected to raise ``ValueError``.

    The ``match`` pattern reproduces the library's error message verbatim,
    including its "Seires" spelling; it must stay in sync with the message
    raised by ``DataFrame.where``.
    """

    def the_sum(s):
        return s["int64_col"] + s["int64_too"]

    # Create the managed function *before* entering the try block. If creation
    # fails there is nothing bound to clean up, and referencing the unbound
    # name in `finally` would raise UnboundLocalError in place of the real
    # creation error.
    the_sum_mf = session.udf(
        input_types=bigframes.series.Series,
        output_type=int,
        dataset=dataset_id,
        name=prefixer.create_prefix(),
    )(the_sum)

    try:
        int64_cols = ["int64_col", "int64_too"]

        bf_int64_df = scalars_df_index[int64_cols]
        bf_int64_df_filtered = bf_int64_df.dropna()

        with pytest.raises(
            ValueError,
            match="Seires is not a supported replacement type!",
        ):
            # The execution of the callable other=the_sum_mf will return a
            # Series, which is not a supported replacement type.
            bf_int64_df_filtered.where(cond=bf_int64_df_filtered, other=the_sum_mf)

    finally:
        # Clean up the gcp assets created for the managed function.
        cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False)
1135+
1136+
11061137
def test_managed_function_series_where_mask(session, dataset_id, scalars_dfs):
11071138
try:
11081139

tests/system/large/functions/test_remote_function.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2896,6 +2896,38 @@ def is_sum_positive(a, b):
28962896
)
28972897

28982898

2899+
@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_df_where_other_issue(session, dataset_id, scalars_df_index):
    """Check that DataFrame.where rejects a remote-function ``other`` yielding a Series.

    A two-argument remote function is passed as ``other``. Per the inline note
    below, executing that callable returns a Series, so ``where`` is expected
    to raise ``ValueError``.

    The ``match`` pattern reproduces the library's error message verbatim,
    including its "Seires" spelling; it must stay in sync with the message
    raised by ``DataFrame.where``.
    """

    def the_sum(a, b):
        return a + b

    # Deploy the remote function *before* entering the try block. If the
    # deployment call raises, the name is never bound, and referencing it in
    # `finally` would raise UnboundLocalError in place of the real error.
    the_sum_mf = session.remote_function(
        input_types=[int, float],
        output_type=float,
        dataset=dataset_id,
        reuse=False,
        cloud_function_service_account="default",
    )(the_sum)

    try:
        int64_cols = ["int64_col", "float64_col"]
        bf_int64_df = scalars_df_index[int64_cols]
        bf_int64_df_filtered = bf_int64_df.dropna()

        with pytest.raises(
            ValueError,
            match="Seires is not a supported replacement type!",
        ):
            # The execution of the callable other=the_sum_mf will return a
            # Series, which is not a supported replacement type.
            bf_int64_df_filtered.where(cond=bf_int64_df > 100, other=the_sum_mf)

    finally:
        # Clean up the gcp assets created for the remote function.
        cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False)
2929+
2930+
28992931
@pytest.mark.flaky(retries=2, delay=120)
29002932
def test_remote_function_df_where_mask_series(session, dataset_id, scalars_dfs):
29012933
try:

tests/system/small/test_dataframe.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,18 @@ def func(x):
570570
pandas.testing.assert_frame_equal(bf_result, pd_result)
571571

572572

573+
def test_where_series_other(scalars_df_index):
    """DataFrame.where must raise ValueError when ``other`` is a Series."""
    bf_df = scalars_df_index[["int64_col", "float64_col"]]

    # Must mirror the library's message verbatim (including its spelling).
    expected_message = "Seires is not a supported replacement type!"

    with pytest.raises(ValueError, match=expected_message):
        bf_df.where(bf_df > 0, bf_df["int64_col"])
583+
584+
573585
def test_drop_column(scalars_dfs):
574586
scalars_df, scalars_pandas_df = scalars_dfs
575587
col_name = "int64_col"

0 commit comments

Comments
 (0)