Skip to content

Commit 9a774ac

Browse files
committed
update some unit tests
1 parent 99bb8ec commit 9a774ac

File tree

10 files changed

+158
-110
lines changed

10 files changed

+158
-110
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from bigframes import series, session
2929
from bigframes.core import convert
3030
from bigframes.core.logging import log_adapter
31+
import bigframes.core.sql.literals
3132
from bigframes.ml import core as ml_core
3233
from bigframes.operations import ai_ops, output_schemas
3334

@@ -394,9 +395,11 @@ def generate_embedding(
394395
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
395396
*,
396397
output_dimensionality: Optional[int] = None,
398+
task_type: Optional[str] = None,
397399
start_second: Optional[float] = None,
398400
end_second: Optional[float] = None,
399401
interval_seconds: Optional[float] = None,
402+
trial_id: Optional[int] = None,
400403
) -> dataframe.DataFrame:
401404
"""
402405
Creates embeddings that describe an entity—for example, a piece of text or an image.
@@ -414,32 +417,49 @@ def generate_embedding(
414417
Args:
415418
model_name (str):
416419
The name of a remote model over a Vertex AI multimodalembedding@001 model.
417-
data (DataFrame or Series):
418-
The data to generate embeddings for. If a Series is provided, it is treated as the 'content' column.
419-
If a DataFrame is provided, it must contain a 'content' column, or you must rename the column you wish to embed to 'content'.
420+
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
421+
The data to generate embeddings for. If a Series is provided, it is
422+
treated as the 'content' column. If a DataFrame is provided, it
423+
must contain a 'content' column, or you must rename the column you
424+
wish to embed to 'content'.
420425
output_dimensionality (int, optional):
421-
The number of dimensions to use when generating embeddings. Valid values are 128, 256, 512, and 1408. The default value is 1408.
426+
An INT64 value that specifies the number of dimensions to use when
427+
generating embeddings. For example, if you specify 256 AS
428+
output_dimensionality, then the embedding output column contains a
429+
256-dimensional embedding for each input value. To find the
430+
supported range of output dimensions, read about the available
431+
`Google text embedding models <https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#google-models>`_.
432+
task_type (str, optional):
433+
A STRING literal that specifies the intended downstream application to
434+
help the model produce better quality embeddings. For a list of
435+
supported task types and how to choose which one to use, see `Choose an
436+
embeddings task type <https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types>`_.
422437
start_second (float, optional):
423438
The second in the video at which to start the embedding. The default value is 0.
424439
end_second (float, optional):
425440
The second in the video at which to end the embedding. The default value is 120.
426441
interval_seconds (float, optional):
427442
The interval to use when creating embeddings. The default value is 16.
443+
trial_id (int, optional):
444+
An INT64 value that identifies the hyperparameter tuning trial that
445+
you want the function to evaluate. The function uses the optimal
446+
trial by default. Only specify this argument if you ran
447+
hyperparameter tuning when creating the model.
428448
429449
Returns:
430-
bigframes.dataframe.DataFrame:
431-
A new DataFrame with the generated embeddings. It contains the input table columns and the following columns:
432-
* "embedding": an ARRAY<FLOAT64> value that contains the generated embedding vector.
433-
* "status": a STRING value that contains the API response status for the corresponding row.
434-
* "video_start_sec": for video content, an INT64 value that contains the starting second.
435-
* "video_end_sec": for video content, an INT64 value that contains the ending second.
450+
bigframes.pandas.DataFrame:
451+
A new DataFrame with the generated embeddings. See the `SQL
452+
reference for AI.GENERATE_EMBEDDING
453+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-generate-embedding#output>`_
454+
for details.
436455
"""
437456
if isinstance(data, (pd.DataFrame, pd.Series)):
438457
data = bpd.read_pandas(data)
439458

440459
if isinstance(data, series.Series):
441-
# Rename series to 'content' and convert to DataFrame
442-
data_df = data.rename("content").to_frame()
460+
data = data.copy()
461+
data.name = "content"
462+
data_df = data.to_frame()
443463
elif isinstance(data, dataframe.DataFrame):
444464
data_df = data
445465
else:
@@ -448,25 +468,27 @@ def generate_embedding(
448468
# We need to get the SQL for the input data to pass as a subquery to the TVF
449469
source_sql = data_df.sql
450470

451-
struct_fields = []
471+
struct_fields = {}
452472
if output_dimensionality is not None:
453-
struct_fields.append(f"{output_dimensionality} AS output_dimensionality")
473+
struct_fields["OUTPUT_DIMENSIONALITY"] = output_dimensionality
474+
if task_type is not None:
475+
struct_fields["TASK_TYPE"] = task_type
454476
if start_second is not None:
455-
struct_fields.append(f"{start_second} AS start_second")
477+
struct_fields["START_SECOND"] = start_second
456478
if end_second is not None:
457-
struct_fields.append(f"{end_second} AS end_second")
479+
struct_fields["END_SECOND"] = end_second
458480
if interval_seconds is not None:
459-
struct_fields.append(f"{interval_seconds} AS interval_seconds")
460-
461-
struct_args = ", ".join(struct_fields)
481+
struct_fields["INTERVAL_SECONDS"] = interval_seconds
482+
if trial_id is not None:
483+
struct_fields["TRIAL_ID"] = trial_id
462484

463485
# Construct the TVF query
464486
query = f"""
465487
SELECT *
466488
FROM AI.GENERATE_EMBEDDING(
467489
MODEL `{model_name}`,
468490
({source_sql}),
469-
STRUCT({struct_args})
491+
{bigframes.core.sql.literals.struct_literal(struct_fields)}
470492
)
471493
"""
472494

bigframes/core/pyformat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from bigframes.core import utils
3030
import bigframes.core.local_data
31+
import bigframes.core.sql.literals
3132
from bigframes.core.tools import bigquery_schema
3233
import bigframes.session
3334

@@ -120,7 +121,7 @@ def _validate_type(name: str, value: Any):
120121

121122
supported_types = (
122123
typing.get_args(_BQ_TABLE_TYPES)
123-
+ typing.get_args(bigframes.core.sql.SIMPLE_LITERAL_TYPES)
124+
+ typing.get_args(bigframes.core.sql.literals.SIMPLE_LITERAL_TYPES)
124125
+ (bigframes.dataframe.DataFrame,)
125126
+ (pandas.DataFrame,)
126127
)

bigframes/core/sql/__init__.py

Lines changed: 1 addition & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -17,91 +17,19 @@
1717
Utility functions for SQL construction.
1818
"""
1919

20-
import datetime
21-
import decimal
2220
import json
23-
import math
2421
from typing import cast, Collection, Iterable, Mapping, Optional, TYPE_CHECKING, Union
2522

26-
import shapely.geometry.base # type: ignore
27-
2823
import bigframes.core.compile.googlesql as googlesql
24+
from bigframes.core.sql.literals import simple_literal
2925

3026
if TYPE_CHECKING:
3127
import google.cloud.bigquery as bigquery
3228

3329
import bigframes.core.ordering
3430

3531

36-
# shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0.
37-
try:
38-
from shapely.io import to_wkt # type: ignore
39-
except ImportError:
40-
from shapely.wkt import dumps # type: ignore
41-
42-
to_wkt = dumps
43-
44-
45-
SIMPLE_LITERAL_TYPES = Union[
46-
bytes,
47-
str,
48-
int,
49-
bool,
50-
float,
51-
datetime.datetime,
52-
datetime.date,
53-
datetime.time,
54-
decimal.Decimal,
55-
list,
56-
]
57-
58-
5932
### Writing SQL Values (literals, column references, table references, etc.)
60-
def simple_literal(value: Union[SIMPLE_LITERAL_TYPES, None]) -> str:
61-
"""Return quoted input string."""
62-
63-
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#literals
64-
if value is None:
65-
return "NULL"
66-
elif isinstance(value, str):
67-
# Single quoting seems to work nicer with ibis than double quoting
68-
return f"'{googlesql._escape_chars(value)}'"
69-
elif isinstance(value, bytes):
70-
return repr(value)
71-
elif isinstance(value, (bool, int)):
72-
return str(value)
73-
elif isinstance(value, float):
74-
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#floating_point_literals
75-
if math.isnan(value):
76-
return 'CAST("nan" as FLOAT)'
77-
if value == math.inf:
78-
return 'CAST("+inf" as FLOAT)'
79-
if value == -math.inf:
80-
return 'CAST("-inf" as FLOAT)'
81-
return str(value)
82-
# Check datetime first as it is a subclass of date
83-
elif isinstance(value, datetime.datetime):
84-
if value.tzinfo is None:
85-
return f"DATETIME('{value.isoformat()}')"
86-
else:
87-
return f"TIMESTAMP('{value.isoformat()}')"
88-
elif isinstance(value, datetime.date):
89-
return f"DATE('{value.isoformat()}')"
90-
elif isinstance(value, datetime.time):
91-
return f"TIME(DATETIME('1970-01-01 {value.isoformat()}'))"
92-
elif isinstance(value, shapely.geometry.base.BaseGeometry):
93-
return f"ST_GEOGFROMTEXT({simple_literal(to_wkt(value))})"
94-
elif isinstance(value, decimal.Decimal):
95-
# TODO: disambiguate BIGNUMERIC based on scale and/or precision
96-
return f"CAST('{str(value)}' AS NUMERIC)"
97-
elif isinstance(value, list):
98-
simple_literals = [simple_literal(i) for i in value]
99-
return f"[{', '.join(simple_literals)}]"
100-
101-
else:
102-
raise ValueError(f"Cannot produce literal for {value}")
103-
104-
10533
def multi_literal(*values: str):
10634
literal_strings = [simple_literal(i) for i in values]
10735
return "(" + ", ".join(literal_strings) + ")"

bigframes/core/sql/literals.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import datetime
18+
import decimal
19+
import math
20+
from typing import Mapping, Union
21+
22+
import shapely.geometry.base # type: ignore
23+
24+
import bigframes.core.compile.googlesql as googlesql
25+
26+
# shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0.
27+
try:
28+
from shapely.io import to_wkt # type: ignore
29+
except ImportError:
30+
from shapely.wkt import dumps # type: ignore
31+
32+
to_wkt = dumps
33+
34+
35+
SIMPLE_LITERAL_TYPES = Union[
36+
bytes,
37+
str,
38+
int,
39+
bool,
40+
float,
41+
datetime.datetime,
42+
datetime.date,
43+
datetime.time,
44+
decimal.Decimal,
45+
list,
46+
]
47+
48+
49+
def simple_literal(value: Union[SIMPLE_LITERAL_TYPES, None]) -> str:
    """Render a Python value as the text of a SQL literal expression.

    Args:
        value:
            The value to render. ``None``, ``str``, ``bytes``, ``bool``,
            ``int``, ``float``, ``datetime.datetime``, ``datetime.date``,
            ``datetime.time``, shapely geometries, ``decimal.Decimal`` and
            ``list`` (rendered element by element) are handled.

    Returns:
        str: The SQL text for ``value``; ``None`` becomes ``NULL``.

    Raises:
        ValueError: If ``value`` is not one of the handled types.
    """
    # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#literals
    if value is None:
        return "NULL"

    if isinstance(value, str):
        # Single quoting seems to work nicer with ibis than double quoting
        return f"'{googlesql._escape_chars(value)}'"

    if isinstance(value, bytes):
        return repr(value)

    if isinstance(value, (bool, int)):
        return str(value)

    if isinstance(value, float):
        # Non-finite floats are rendered through a CAST from a string.
        # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#floating_point_literals
        if math.isnan(value):
            return 'CAST("nan" as FLOAT)'
        if value == math.inf:
            return 'CAST("+inf" as FLOAT)'
        if value == -math.inf:
            return 'CAST("-inf" as FLOAT)'
        return str(value)

    # datetime.datetime is a subclass of datetime.date, so it must be tested
    # before the plain date check below.
    if isinstance(value, datetime.datetime):
        is_naive = value.tzinfo is None
        if is_naive:
            return f"DATETIME('{value.isoformat()}')"
        return f"TIMESTAMP('{value.isoformat()}')"

    if isinstance(value, datetime.date):
        return f"DATE('{value.isoformat()}')"

    if isinstance(value, datetime.time):
        return f"TIME(DATETIME('1970-01-01 {value.isoformat()}'))"

    if isinstance(value, shapely.geometry.base.BaseGeometry):
        # The WKT text is itself rendered as a string literal.
        return f"ST_GEOGFROMTEXT({simple_literal(to_wkt(value))})"

    if isinstance(value, decimal.Decimal):
        # TODO: disambiguate BIGNUMERIC based on scale and/or precision
        return f"CAST('{str(value)}' AS NUMERIC)"

    if isinstance(value, list):
        rendered_items = [simple_literal(item) for item in value]
        return f"[{', '.join(rendered_items)}]"

    raise ValueError(f"Cannot produce literal for {value}")
92+
93+
94+
def struct_literal(struct_options: Mapping[str, SIMPLE_LITERAL_TYPES]) -> str:
    """Build a ``STRUCT(...)`` SQL expression from named option values.

    Each entry is rendered as ``<literal> AS <name>``, in the mapping's
    iteration order, with the value converted by ``simple_literal``. Option
    names are inserted as given, without quoting or escaping.

    Args:
        struct_options:
            Mapping from field name to the Python value for that field.

    Returns:
        str: The ``STRUCT(...)`` expression; an empty mapping yields
        ``STRUCT()``.
    """
    fields = ", ".join(
        f"{simple_literal(field_value)} AS {field_name}"
        for field_name, field_value in struct_options.items()
    )
    return f"STRUCT({fields})"

bigframes/core/sql/ml.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import bigframes.core.compile.googlesql as googlesql
2020
import bigframes.core.sql
21+
import bigframes.core.sql.literals
2122

2223

2324
def create_model_ddl(
@@ -105,11 +106,7 @@ def _build_struct_sql(
105106
if not struct_options:
106107
return ""
107108

108-
rendered_options = []
109-
for option_name, option_value in struct_options.items():
110-
rendered_val = bigframes.core.sql.simple_literal(option_value)
111-
rendered_options.append(f"{rendered_val} AS {option_name}")
112-
return f", STRUCT({', '.join(rendered_options)})"
109+
return f", {bigframes.core.sql.literals.struct_literal}"
113110

114111

115112
def evaluate(

tests/unit/bigquery/test_ai.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,7 @@ def test_generate_embedding_with_series(mock_series, mock_session):
7878
model_name = "project.dataset.model"
7979

8080
ai_ops.generate_embedding(
81-
model_name,
82-
mock_series,
83-
start_second=0.0,
84-
end_second=10.0,
85-
interval_seconds=5.0
81+
model_name, mock_series, start_second=0.0, end_second=10.0, interval_seconds=5.0
8682
)
8783

8884
mock_series.rename.assert_called_with("content")
@@ -94,7 +90,10 @@ def test_generate_embedding_with_series(mock_series, mock_session):
9490

9591
assert f"MODEL `{model_name}`" in query
9692
assert "(SELECT my_col AS content FROM my_table)" in query
97-
assert "STRUCT(0.0 AS start_second, 10.0 AS end_second, 5.0 AS interval_seconds)" in query
93+
assert (
94+
"STRUCT(0.0 AS start_second, 10.0 AS end_second, 5.0 AS interval_seconds)"
95+
in query
96+
)
9897

9998

10099
def test_generate_embedding_defaults(mock_dataframe, mock_session):
@@ -114,7 +113,9 @@ def test_generate_embedding_defaults(mock_dataframe, mock_session):
114113

115114

116115
@mock.patch("bigframes.pandas.read_pandas")
117-
def test_generate_embedding_with_pandas_dataframe(read_pandas_mock, mock_dataframe, mock_session):
116+
def test_generate_embedding_with_pandas_dataframe(
117+
read_pandas_mock, mock_dataframe, mock_session
118+
):
118119
# This tests that pandas input path works and calls read_pandas
119120
model_name = "project.dataset.model"
120121

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(False AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level))
1+
SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(False AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(5 AS top_k_features))
1+
SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(5 AS top_k_features))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(True AS class_level_explain))
1+
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(True AS class_level_explain))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(True AS keep_original_columns))
1+
SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(True AS keep_original_columns))

0 commit comments

Comments
 (0)