Skip to content

Commit 81d35e6

Browse files
Merge remote-tracking branch 'github/main' into cf_name
2 parents a99f835 + 1f9ee37 commit 81d35e6

File tree

21 files changed

+466
-75
lines changed

21 files changed

+466
-75
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@ def generate(
5858
>>> import bigframes.pandas as bpd
5959
>>> import bigframes.bigquery as bbq
6060
>>> country = bpd.Series(["Japan", "Canada"])
61-
>>> bbq.ai.generate(("What's the capital city of ", country, " one word only"))
62-
0 {'result': 'Tokyo\\n', 'full_response': '{"cand...
63-
1 {'result': 'Ottawa\\n', 'full_response': '{"can...
61+
>>> bbq.ai.generate(("What's the capital city of ", country, " one word only")) # doctest: +SKIP
62+
0 {'result': 'Tokyo', 'full_response': '{"cand...
63+
1 {'result': 'Ottawa', 'full_response': '{"can...
6464
dtype: struct<result: string, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
6565
66-
>>> bbq.ai.generate(("What's the capital city of ", country, " one word only")).struct.field("result")
67-
0 Tokyo\\n
68-
1 Ottawa\\n
66+
>>> bbq.ai.generate(("What's the capital city of ", country, " one word only")).struct.field("result") # doctest: +SKIP
67+
0 Tokyo
68+
1 Ottawa
6969
Name: result, dtype: string
7070
7171
You get structured output when the `output_schema` parameter is set:

bigframes/bigquery/_operations/ml.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,3 +393,41 @@ def global_explain(
393393
return bpd.read_gbq_query(sql)
394394
else:
395395
return session.read_gbq_query(sql)
396+
397+
398+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
399+
def transform(
400+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
401+
input_: Union[pd.DataFrame, dataframe.DataFrame, str],
402+
) -> dataframe.DataFrame:
403+
"""
404+
Transforms input data using a BigQuery ML model.
405+
406+
See the `BigQuery ML TRANSFORM function syntax
407+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-transform>`_
408+
for additional reference.
409+
410+
Args:
411+
model (bigframes.ml.base.BaseEstimator or str):
412+
The model to use for transformation.
413+
input_ (Union[bigframes.pandas.DataFrame, str]):
414+
The DataFrame or query to use for transformation.
415+
416+
Returns:
417+
bigframes.pandas.DataFrame:
418+
The transformed data.
419+
"""
420+
import bigframes.pandas as bpd
421+
422+
model_name, session = _get_model_name_and_session(model, input_)
423+
table_sql = _to_sql(input_)
424+
425+
sql = bigframes.core.sql.ml.transform(
426+
model_name=model_name,
427+
table=table_sql,
428+
)
429+
430+
if session is None:
431+
return bpd.read_gbq_query(sql)
432+
else:
433+
return session.read_gbq_query(sql)

bigframes/bigquery/ml.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
explain_predict,
2626
global_explain,
2727
predict,
28+
transform,
2829
)
2930

3031
__all__ = [
@@ -33,4 +34,5 @@
3334
"predict",
3435
"explain_predict",
3536
"global_explain",
37+
"transform",
3638
]

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,17 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
152152
return sge.Coalesce(this=left.expr, expressions=[right.expr])
153153

154154

155-
@register_unary_op(ops.RemoteFunctionOp, pass_op=True)
156-
def _(expr: TypedExpr, op: ops.RemoteFunctionOp) -> sge.Expression:
155+
def _get_remote_function_name(op):
157156
routine_ref = op.function_def.routine_ref
158157
# Quote project, dataset, and routine IDs to avoid keyword clashes.
159-
func_name = (
158+
return (
160159
f"`{routine_ref.project}`.`{routine_ref.dataset_id}`.`{routine_ref.routine_id}`"
161160
)
161+
162+
163+
@register_unary_op(ops.RemoteFunctionOp, pass_op=True)
164+
def _(expr: TypedExpr, op: ops.RemoteFunctionOp) -> sge.Expression:
165+
func_name = _get_remote_function_name(op)
162166
func = sge.func(func_name, expr.expr)
163167

164168
if not op.apply_on_null:
@@ -175,15 +179,16 @@ def _(expr: TypedExpr, op: ops.RemoteFunctionOp) -> sge.Expression:
175179
def _(
176180
left: TypedExpr, right: TypedExpr, op: ops.BinaryRemoteFunctionOp
177181
) -> sge.Expression:
178-
routine_ref = op.function_def.routine_ref
179-
# Quote project, dataset, and routine IDs to avoid keyword clashes.
180-
func_name = (
181-
f"`{routine_ref.project}`.`{routine_ref.dataset_id}`.`{routine_ref.routine_id}`"
182-
)
183-
182+
func_name = _get_remote_function_name(op)
184183
return sge.func(func_name, left.expr, right.expr)
185184

186185

186+
@register_nary_op(ops.NaryRemoteFunctionOp, pass_op=True)
187+
def _(*operands: TypedExpr, op: ops.NaryRemoteFunctionOp) -> sge.Expression:
188+
func_name = _get_remote_function_name(op)
189+
return sge.func(func_name, *(operand.expr for operand in operands))
190+
191+
187192
@register_nary_op(ops.case_when_op)
188193
def _(*cases_and_outputs: TypedExpr) -> sge.Expression:
189194
# Need to upcast BOOL to INT if any output is numeric

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,19 @@ def _(expr: TypedExpr) -> sge.Expression:
9393
def _(expr: TypedExpr) -> sge.Expression:
9494
return sge.Case(
9595
ifs=[
96+
# |x| < 1: The standard formula
97+
sge.If(
98+
this=sge.func("ABS", expr.expr) < sge.convert(1),
99+
true=sge.func("ATANH", expr.expr),
100+
),
101+
# |x| > 1: Returns NaN
96102
sge.If(
97103
this=sge.func("ABS", expr.expr) > sge.convert(1),
98104
true=constants._NAN,
99-
)
105+
),
100106
],
101-
default=sge.func("ATANH", expr.expr),
107+
# |x| = 1: Returns Infinity or -Infinity
108+
default=sge.Mul(this=constants._INF, expression=expr.expr),
102109
)
103110

104111

@@ -145,15 +152,11 @@ def _(expr: TypedExpr) -> sge.Expression:
145152

146153
@register_unary_op(ops.expm1_op)
147154
def _(expr: TypedExpr) -> sge.Expression:
148-
return sge.Case(
149-
ifs=[
150-
sge.If(
151-
this=expr.expr > constants._FLOAT64_EXP_BOUND,
152-
true=constants._INF,
153-
)
154-
],
155-
default=sge.func("EXP", expr.expr),
156-
) - sge.convert(1)
155+
return sge.If(
156+
this=expr.expr > constants._FLOAT64_EXP_BOUND,
157+
true=constants._INF,
158+
false=sge.func("EXP", expr.expr) - sge.convert(1),
159+
)
157160

158161

159162
@register_unary_op(ops.floor_op)
@@ -166,11 +169,22 @@ def _(expr: TypedExpr) -> sge.Expression:
166169
return sge.Case(
167170
ifs=[
168171
sge.If(
169-
this=expr.expr <= sge.convert(0),
172+
this=sge.Is(this=expr.expr, expression=sge.Null()),
173+
true=sge.null(),
174+
),
175+
# |x| > 0: The standard formula
176+
sge.If(
177+
this=expr.expr > sge.convert(0),
178+
true=sge.Ln(this=expr.expr),
179+
),
180+
# |x| < 0: Returns NaN
181+
sge.If(
182+
this=expr.expr < sge.convert(0),
170183
true=constants._NAN,
171-
)
184+
),
172185
],
173-
default=sge.Ln(this=expr.expr),
186+
# |x| == 0: Returns -Infinity
187+
default=constants._NEG_INF,
174188
)
175189

176190

@@ -179,11 +193,22 @@ def _(expr: TypedExpr) -> sge.Expression:
179193
return sge.Case(
180194
ifs=[
181195
sge.If(
182-
this=expr.expr <= sge.convert(0),
196+
this=sge.Is(this=expr.expr, expression=sge.Null()),
197+
true=sge.null(),
198+
),
199+
# |x| > 0: The standard formula
200+
sge.If(
201+
this=expr.expr > sge.convert(0),
202+
true=sge.Log(this=sge.convert(10), expression=expr.expr),
203+
),
204+
# |x| < 0: Returns NaN
205+
sge.If(
206+
this=expr.expr < sge.convert(0),
183207
true=constants._NAN,
184-
)
208+
),
185209
],
186-
default=sge.Log(this=expr.expr, expression=sge.convert(10)),
210+
# |x| == 0: Returns -Infinity
211+
default=constants._NEG_INF,
187212
)
188213

189214

@@ -192,11 +217,22 @@ def _(expr: TypedExpr) -> sge.Expression:
192217
return sge.Case(
193218
ifs=[
194219
sge.If(
195-
this=expr.expr <= sge.convert(-1),
220+
this=sge.Is(this=expr.expr, expression=sge.Null()),
221+
true=sge.null(),
222+
),
223+
# Domain: |x| > -1 (The standard formula)
224+
sge.If(
225+
this=expr.expr > sge.convert(-1),
226+
true=sge.Ln(this=sge.convert(1) + expr.expr),
227+
),
228+
# Out of Domain: |x| < -1 (Returns NaN)
229+
sge.If(
230+
this=expr.expr < sge.convert(-1),
196231
true=constants._NAN,
197-
)
232+
),
198233
],
199-
default=sge.Ln(this=sge.convert(1) + expr.expr),
234+
# Boundary: |x| == -1 (Returns -Infinity)
235+
default=constants._NEG_INF,
200236
)
201237

202238

@@ -608,7 +644,7 @@ def isfinite(arg: TypedExpr) -> sge.Expression:
608644
return sge.Not(
609645
this=sge.Or(
610646
this=sge.IsInf(this=arg.expr),
611-
right=sge.IsNan(this=arg.expr),
647+
expression=sge.IsNan(this=arg.expr),
612648
),
613649
)
614650

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
674674
expressions=[_literal(value=v, dtype=value_type) for v in value]
675675
)
676676
return values if len(value) > 0 else _cast(values, sqlglot_type)
677-
elif pd.isna(value):
677+
elif pd.isna(value) or (isinstance(value, pa.Scalar) and not value.is_valid):
678678
return _cast(sge.Null(), sqlglot_type)
679679
elif dtype == dtypes.JSON_DTYPE:
680680
return sge.ParseJSON(this=sge.convert(str(value)))

bigframes/core/logging/data_types.py

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,115 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
import functools
1518

1619
from bigframes import dtypes
20+
from bigframes.core import agg_expressions, bigframe_node, expression, nodes
21+
from bigframes.core.rewrite import schema_binding
22+
23+
IGNORED_NODES = (
24+
nodes.SelectionNode,
25+
nodes.ReadLocalNode,
26+
nodes.ReadTableNode,
27+
nodes.ConcatNode,
28+
nodes.RandomSampleNode,
29+
nodes.FromRangeNode,
30+
nodes.PromoteOffsetsNode,
31+
nodes.ReversedNode,
32+
nodes.SliceNode,
33+
nodes.ResultNode,
34+
)
35+
36+
37+
def encode_type_refs(root: bigframe_node.BigFrameNode) -> str:
38+
return f"{root.reduce_up(_encode_type_refs_from_node):x}"
39+
40+
41+
def _encode_type_refs_from_node(
42+
node: bigframe_node.BigFrameNode, child_results: tuple[int, ...]
43+
) -> int:
44+
child_result = functools.reduce(lambda x, y: x | y, child_results, 0)
45+
46+
curr_result = 0
47+
if isinstance(node, nodes.FilterNode):
48+
curr_result = _encode_type_refs_from_expr(node.predicate, node.child)
49+
elif isinstance(node, nodes.ProjectionNode):
50+
for assignment in node.assignments:
51+
expr = assignment[0]
52+
if isinstance(expr, (expression.DerefOp)):
53+
# Ignore direct assignments in projection nodes.
54+
continue
55+
curr_result = curr_result | _encode_type_refs_from_expr(
56+
assignment[0], node.child
57+
)
58+
elif isinstance(node, nodes.OrderByNode):
59+
for by in node.by:
60+
curr_result = curr_result | _encode_type_refs_from_expr(
61+
by.scalar_expression, node.child
62+
)
63+
elif isinstance(node, nodes.JoinNode):
64+
for left, right in node.conditions:
65+
curr_result = (
66+
curr_result
67+
| _encode_type_refs_from_expr(left, node.left_child)
68+
| _encode_type_refs_from_expr(right, node.right_child)
69+
)
70+
elif isinstance(node, nodes.InNode):
71+
curr_result = _encode_type_refs_from_expr(node.left_col, node.left_child)
72+
elif isinstance(node, nodes.AggregateNode):
73+
for agg, _ in node.aggregations:
74+
curr_result = curr_result | _encode_type_refs_from_expr(agg, node.child)
75+
elif isinstance(node, nodes.WindowOpNode):
76+
for grouping_key in node.window_spec.grouping_keys:
77+
curr_result = curr_result | _encode_type_refs_from_expr(
78+
grouping_key, node.child
79+
)
80+
for ordering_expr in node.window_spec.ordering:
81+
curr_result = curr_result | _encode_type_refs_from_expr(
82+
ordering_expr.scalar_expression, node.child
83+
)
84+
for col_def in node.agg_exprs:
85+
curr_result = curr_result | _encode_type_refs_from_expr(
86+
col_def.expression, node.child
87+
)
88+
elif isinstance(node, nodes.ExplodeNode):
89+
for col_id in node.column_ids:
90+
curr_result = curr_result | _encode_type_refs_from_expr(col_id, node.child)
91+
elif isinstance(node, IGNORED_NODES):
92+
# Do nothing
93+
pass
94+
else:
95+
# For unseen nodes, do not raise errors as this is the logging path, but
96+
# we should cover those nodes either in the branches above, or place them
97+
# in the IGNORED_NODES collection.
98+
pass
99+
100+
return child_result | curr_result
101+
102+
103+
def _encode_type_refs_from_expr(
104+
expr: expression.Expression, child_node: bigframe_node.BigFrameNode
105+
) -> int:
106+
# TODO(b/409387790): Remove this branch once SQLGlot compiler fully replaces Ibis compiler
107+
if not expr.is_resolved:
108+
if isinstance(expr, agg_expressions.Aggregation):
109+
expr = schema_binding._bind_schema_to_aggregation_expr(expr, child_node)
110+
else:
111+
expr = expression.bind_schema_fields(expr, child_node.field_by_id)
17112

113+
result = _get_dtype_mask(expr.output_type)
114+
for child_expr in expr.children:
115+
result = result | _encode_type_refs_from_expr(child_expr, child_node)
18116

19-
def _add_data_type(existing_types: int, curr_type: dtypes.Dtype) -> int:
20-
return existing_types | _get_dtype_mask(curr_type)
117+
return result
21118

22119

23-
def _get_dtype_mask(dtype: dtypes.Dtype) -> int:
120+
def _get_dtype_mask(dtype: dtypes.Dtype | None) -> int:
121+
if dtype is None:
122+
# If the dtype is not given, ignore
123+
return 0
24124
if dtype == dtypes.INT_DTYPE:
25125
return 1 << 1
26126
if dtype == dtypes.FLOAT_DTYPE:

bigframes/core/sql/ml.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,14 @@ def global_explain(
213213
sql += _build_struct_sql(struct_options)
214214
sql += ")\n"
215215
return sql
216+
217+
218+
def transform(
219+
model_name: str,
220+
table: str,
221+
) -> str:
222+
"""Encode the ML.TRANSFORM statement.
223+
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-transform for reference.
224+
"""
225+
sql = f"SELECT * FROM ML.TRANSFORM(MODEL {googlesql.identifier(model_name)}, ({table}))\n"
226+
return sql

0 commit comments

Comments
 (0)