
Commit d56df99

Merge branch 'googleapis:main' into output_schema
2 parents: 5ec67a6 + 46994d7

13 files changed: +355 -63 lines


bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 6 additions & 1 deletion
@@ -22,6 +22,7 @@
 import sqlglot.expressions as sge
 
 from bigframes import operations as ops
+from bigframes.core.compile.constants import UNIT_TO_US_CONVERSION_FACTORS
 from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
 from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
 
@@ -618,7 +619,11 @@ def _(op: ops.ToTimestampOp, expr: TypedExpr) -> sge.Expression:
 
 @UNARY_OP_REGISTRATION.register(ops.ToTimedeltaOp)
 def _(op: ops.ToTimedeltaOp, expr: TypedExpr) -> sge.Expression:
-    return sge.Interval(this=expr.expr, unit=sge.Identifier(this="SECOND"))
+    value = expr.expr
+    factor = UNIT_TO_US_CONVERSION_FACTORS[op.unit]
+    if factor != 1:
+        value = sge.Mul(this=value, expression=sge.convert(factor))
+    return sge.Interval(this=value, unit=sge.Identifier(this="MICROSECOND"))
 
 
 @UNARY_OP_REGISTRATION.register(ops.UnixMicros)
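
For context, a minimal standalone sketch of what the new ToTimedeltaOp lowering does: scale the input to microseconds using the unit's conversion factor, then wrap it in a MICROSECOND interval. The factor mapping below is a hypothetical stand-in; the real UNIT_TO_US_CONVERSION_FACTORS is imported from bigframes.core.compile.constants and is not shown in this diff.

import sqlglot.expressions as sge

# Hypothetical stand-in for bigframes' UNIT_TO_US_CONVERSION_FACTORS.
UNIT_TO_US_CONVERSION_FACTORS = {"us": 1, "ms": 1_000, "s": 1_000_000}

def to_timedelta_interval(column: str, unit: str) -> sge.Expression:
    # Mirror the compiler change: multiply by the per-unit factor (skipped
    # when it is 1) and emit an INTERVAL ... MICROSECOND expression.
    value: sge.Expression = sge.to_identifier(column)
    factor = UNIT_TO_US_CONVERSION_FACTORS[unit]
    if factor != 1:
        value = sge.Mul(this=value, expression=sge.convert(factor))
    return sge.Interval(this=value, unit=sge.Identifier(this="MICROSECOND"))

# Renders roughly as: INTERVAL int64_col * 1000000 MICROSECOND
print(to_timedelta_interval("int64_col", "s").sql(dialect="bigquery"))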

bigframes/dataframe.py

Lines changed: 10 additions & 3 deletions
@@ -2797,10 +2797,17 @@ def where(self, cond, other=None):
         )
 
         # Execute it with the DataFrame when cond or/and other is callable.
+        # It can be either a plain python function or remote/managed function.
         if callable(cond):
-            cond = cond(self)
+            if hasattr(cond, "bigframes_bigquery_function"):
+                cond = self.apply(cond, axis=1)
+            else:
+                cond = cond(self)
         if callable(other):
-            other = other(self)
+            if hasattr(other, "bigframes_bigquery_function"):
+                other = self.apply(other, axis=1)
+            else:
+                other = other(self)
 
         aligned_block, (_, _) = self._block.join(cond._block, how="left")
         # No left join is needed when 'other' is None or constant.
@@ -2813,7 +2820,7 @@ def where(self, cond, other=None):
         labels = aligned_block.column_labels[:self_len]
         self_col = {x: ex.deref(y) for x, y in zip(labels, ids)}
 
-        if isinstance(cond, bigframes.series.Series) and cond.name in self_col:
+        if isinstance(cond, bigframes.series.Series):
             # This is when 'cond' is a valid series.
             y = aligned_block.value_columns[self_len]
             cond_col = {x: ex.deref(y) for x in self_col.keys()}
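
A short usage sketch of the new dispatch, assuming an authenticated BigQuery DataFrames session; the DataFrame below is illustrative, while is_sum_positive_mf refers to the kind of managed/remote function defined in the tests added later in this commit.

import bigframes.pandas as bpd

# Illustrative data; the added system tests use the `scalar_types` sample table.
df = bpd.DataFrame({"int64_col": [1, -2, 3], "int64_too": [4, 5, -6]})

# A plain Python callable keeps the old behaviour: it is called with the
# DataFrame itself.
masked = df.where(lambda d: d > 0)

# A managed/remote function exposes the `bigframes_bigquery_function`
# attribute, so `where` now evaluates it row-wise via df.apply(cond, axis=1)
# before building the mask (uncomment once such a function exists):
# masked_udf = df.where(is_sum_positive_mf, 0)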

bigframes/functions/_function_session.py

Lines changed: 0 additions & 4 deletions
@@ -555,10 +555,6 @@ def wrapper(func):
             warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
             py_sig = py_sig.replace(return_annotation=output_type)
 
-        # Try to get input types via type annotations.
-
-        # The function will actually be receiving a pandas Series, but allow both
-        # BigQuery DataFrames and pandas object types for compatibility.
         # The function will actually be receiving a pandas Series, but allow
         # both BigQuery DataFrames and pandas object types for compatibility.
         is_row_processor = False

tests/system/large/functions/test_managed_function.py

Lines changed: 112 additions & 0 deletions
@@ -963,3 +963,115 @@ def float_parser(row):
         cleanup_function_assets(
             float_parser_mf, session.bqclient, ignore_failures=False
         )
+
+
+def test_managed_function_df_where(session, dataset_id, scalars_dfs):
+    try:
+
+        # The return type has to be bool type for callable where condition.
+        def is_sum_positive(a, b):
+            return a + b > 0
+
+        is_sum_positive_mf = session.udf(
+            input_types=[int, int],
+            output_type=bool,
+            dataset=dataset_id,
+            name=prefixer.create_prefix(),
+        )(is_sum_positive)
+
+        scalars_df, scalars_pandas_df = scalars_dfs
+        int64_cols = ["int64_col", "int64_too"]
+
+        bf_int64_df = scalars_df[int64_cols]
+        bf_int64_df_filtered = bf_int64_df.dropna()
+        pd_int64_df = scalars_pandas_df[int64_cols]
+        pd_int64_df_filtered = pd_int64_df.dropna()
+
+        # Use callable condition in dataframe.where method.
+        bf_result = bf_int64_df_filtered.where(is_sum_positive_mf).to_pandas()
+        # Pandas doesn't support such case, use following as workaround.
+        pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0)
+
+        # Ignore any dtype difference.
+        pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
+
+        # Make sure the read_gbq_function path works for this function.
+        is_sum_positive_ref = session.read_gbq_function(
+            function_name=is_sum_positive_mf.bigframes_bigquery_function
+        )
+
+        bf_result_gbq = bf_int64_df_filtered.where(
+            is_sum_positive_ref, -bf_int64_df_filtered
+        ).to_pandas()
+        pd_result_gbq = pd_int64_df_filtered.where(
+            pd_int64_df_filtered.sum(axis=1) > 0, -pd_int64_df_filtered
+        )
+
+        # Ignore any dtype difference.
+        pandas.testing.assert_frame_equal(
+            bf_result_gbq, pd_result_gbq, check_dtype=False
+        )
+
+    finally:
+        # Clean up the gcp assets created for the managed function.
+        cleanup_function_assets(
+            is_sum_positive_mf, session.bqclient, ignore_failures=False
+        )
+
+
+def test_managed_function_df_where_series(session, dataset_id, scalars_dfs):
+    try:
+
+        # The return type has to be bool type for callable where condition.
+        def is_sum_positive_series(s):
+            return s["int64_col"] + s["int64_too"] > 0
+
+        is_sum_positive_series_mf = session.udf(
+            input_types=bigframes.series.Series,
+            output_type=bool,
+            dataset=dataset_id,
+            name=prefixer.create_prefix(),
+        )(is_sum_positive_series)
+
+        scalars_df, scalars_pandas_df = scalars_dfs
+        int64_cols = ["int64_col", "int64_too"]
+
+        bf_int64_df = scalars_df[int64_cols]
+        bf_int64_df_filtered = bf_int64_df.dropna()
+        pd_int64_df = scalars_pandas_df[int64_cols]
+        pd_int64_df_filtered = pd_int64_df.dropna()
+
+        # Use callable condition in dataframe.where method.
+        bf_result = bf_int64_df_filtered.where(is_sum_positive_series).to_pandas()
+        pd_result = pd_int64_df_filtered.where(is_sum_positive_series)
+
+        # Ignore any dtype difference.
+        pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
+
+        # Make sure the read_gbq_function path works for this function.
+        is_sum_positive_series_ref = session.read_gbq_function(
+            function_name=is_sum_positive_series_mf.bigframes_bigquery_function,
+            is_row_processor=True,
+        )
+
+        # This is for callable `other` arg in dataframe.where method.
+        def func_for_other(x):
+            return -x
+
+        bf_result_gbq = bf_int64_df_filtered.where(
+            is_sum_positive_series_ref, func_for_other
+        ).to_pandas()
+        pd_result_gbq = pd_int64_df_filtered.where(
+            is_sum_positive_series, func_for_other
+        )
+
+        # Ignore any dtype difference.
+        pandas.testing.assert_frame_equal(
+            bf_result_gbq, pd_result_gbq, check_dtype=False
+        )
+
+    finally:
+        # Clean up the gcp assets created for the managed function.
+        cleanup_function_assets(
+            is_sum_positive_series_mf, session.bqclient, ignore_failures=False
+        )

tests/system/large/functions/test_remote_function.py

Lines changed: 83 additions & 0 deletions
@@ -2847,3 +2847,86 @@ def foo(x: int) -> int:
     finally:
         # clean up the gcp assets created for the remote function
         cleanup_function_assets(foo, session.bqclient, session.cloudfunctionsclient)
+
+
+@pytest.mark.flaky(retries=2, delay=120)
+def test_remote_function_df_where(session, dataset_id, scalars_dfs):
+    try:
+
+        # The return type has to be bool type for callable where condition.
+        def is_sum_positive(a, b):
+            return a + b > 0
+
+        is_sum_positive_mf = session.remote_function(
+            input_types=[int, int],
+            output_type=bool,
+            dataset=dataset_id,
+            reuse=False,
+            cloud_function_service_account="default",
+        )(is_sum_positive)
+
+        scalars_df, scalars_pandas_df = scalars_dfs
+        int64_cols = ["int64_col", "int64_too"]
+
+        bf_int64_df = scalars_df[int64_cols]
+        bf_int64_df_filtered = bf_int64_df.dropna()
+        pd_int64_df = scalars_pandas_df[int64_cols]
+        pd_int64_df_filtered = pd_int64_df.dropna()
+
+        # Use callable condition in dataframe.where method.
+        bf_result = bf_int64_df_filtered.where(is_sum_positive_mf, 0).to_pandas()
+        # Pandas doesn't support such case, use following as workaround.
+        pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0, 0)
+
+        # Ignore any dtype difference.
+        pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
+
+    finally:
+        # Clean up the gcp assets created for the remote function.
+        cleanup_function_assets(
+            is_sum_positive_mf, session.bqclient, ignore_failures=False
+        )
+
+
+@pytest.mark.flaky(retries=2, delay=120)
+def test_remote_function_df_where_series(session, dataset_id, scalars_dfs):
+    try:
+
+        # The return type has to be bool type for callable where condition.
+        def is_sum_positive_series(s):
+            return s["int64_col"] + s["int64_too"] > 0
+
+        is_sum_positive_series_mf = session.remote_function(
+            input_types=bigframes.series.Series,
+            output_type=bool,
+            dataset=dataset_id,
+            reuse=False,
+            cloud_function_service_account="default",
+        )(is_sum_positive_series)
+
+        scalars_df, scalars_pandas_df = scalars_dfs
+        int64_cols = ["int64_col", "int64_too"]
+
+        bf_int64_df = scalars_df[int64_cols]
+        bf_int64_df_filtered = bf_int64_df.dropna()
+        pd_int64_df = scalars_pandas_df[int64_cols]
+        pd_int64_df_filtered = pd_int64_df.dropna()
+
+        # This is for callable `other` arg in dataframe.where method.
+        def func_for_other(x):
+            return -x
+
+        # Use callable condition in dataframe.where method.
+        bf_result = bf_int64_df_filtered.where(
+            is_sum_positive_series, func_for_other
+        ).to_pandas()
+        pd_result = pd_int64_df_filtered.where(is_sum_positive_series, func_for_other)
+
+        # Ignore any dtype difference.
+        pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
+
+    finally:
+        # Clean up the gcp assets created for the remote function.
+        cleanup_function_assets(
+            is_sum_positive_series_mf, session.bqclient, ignore_failures=False
+        )

tests/unit/core/compile/sqlglot/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ def scalar_types_table_schema() -> typing.Sequence[bigquery.SchemaField]:
         bigquery.SchemaField("string_col", "STRING"),
         bigquery.SchemaField("time_col", "TIME"),
         bigquery.SchemaField("timestamp_col", "TIMESTAMP"),
+        bigquery.SchemaField("duration_col", "INTEGER"),
     ]
 
 
tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_timedelta/out.sql

Lines changed: 26 additions & 14 deletions
@@ -2,30 +2,42 @@ WITH `bfcte_0` AS (
   SELECT
     `int64_col` AS `bfcol_0`,
     `rowindex` AS `bfcol_1`,
-    `timestamp_col` AS `bfcol_2`
+    `timestamp_col` AS `bfcol_2`,
+    `duration_col` AS `bfcol_3`
   FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
 ), `bfcte_1` AS (
   SELECT
     *,
-    `bfcol_1` AS `bfcol_6`,
-    `bfcol_2` AS `bfcol_7`,
-    `bfcol_0` AS `bfcol_8`,
-    CAST(FLOOR(86400000000 * `bfcol_0`) AS INT64) AS `bfcol_9`
+    `bfcol_1` AS `bfcol_8`,
+    `bfcol_2` AS `bfcol_9`,
+    `bfcol_0` AS `bfcol_10`,
+    INTERVAL `bfcol_3` MICROSECOND AS `bfcol_11`
   FROM `bfcte_0`
 ), `bfcte_2` AS (
   SELECT
     *,
-    `bfcol_6` AS `bfcol_14`,
-    `bfcol_7` AS `bfcol_15`,
     `bfcol_8` AS `bfcol_16`,
     `bfcol_9` AS `bfcol_17`,
-    CAST(FLOOR(`bfcol_8` * 86400000000) AS INT64) AS `bfcol_18`
+    `bfcol_10` AS `bfcol_18`,
+    `bfcol_11` AS `bfcol_19`,
+    CAST(FLOOR(`bfcol_11` * `bfcol_10`) AS INT64) AS `bfcol_20`
   FROM `bfcte_1`
+), `bfcte_3` AS (
+  SELECT
+    *,
+    `bfcol_16` AS `bfcol_26`,
+    `bfcol_17` AS `bfcol_27`,
+    `bfcol_18` AS `bfcol_28`,
+    `bfcol_19` AS `bfcol_29`,
+    `bfcol_20` AS `bfcol_30`,
+    CAST(FLOOR(`bfcol_18` * `bfcol_19`) AS INT64) AS `bfcol_31`
+  FROM `bfcte_2`
 )
 SELECT
-  `bfcol_14` AS `rowindex`,
-  `bfcol_15` AS `timestamp_col`,
-  `bfcol_16` AS `int64_col`,
-  `bfcol_17` AS `timedelta_mul_numeric`,
-  `bfcol_18` AS `numeric_mul_timedelta`
-FROM `bfcte_2`
+  `bfcol_26` AS `rowindex`,
+  `bfcol_27` AS `timestamp_col`,
+  `bfcol_28` AS `int64_col`,
+  `bfcol_29` AS `duration_col`,
+  `bfcol_30` AS `timedelta_mul_numeric`,
+  `bfcol_31` AS `numeric_mul_timedelta`
+FROM `bfcte_3`
