Merged
18 changes: 18 additions & 0 deletions sqlglot/expressions/functions.py
@@ -321,6 +321,24 @@ class AIClassify(Expression, Func):
    _sql_names = ["AI_CLASSIFY"]


class AIEmbed(Expression, Func):
    arg_types = {"expressions": False}
Collaborator


I think "expressions" needs to be True here, as at least one argument is expected. The same holds for AISimilarity and AIGenerate.
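
To illustrate the difference the flag makes, here is a minimal, library-free sketch. `ToyFunc` and `missing_required` are hypothetical stand-ins, not sqlglot's actual classes or validation code; the only assumption carried over is that `arg_types` maps an argument name to whether it is required.

```python
# Library-free stand-in for illustration; this is not sqlglot's Expression class.
class ToyFunc:
    # Maps each argument name to whether it is required.
    arg_types: dict[str, bool] = {}

    def __init__(self, **args: object) -> None:
        self.args = args

    def missing_required(self) -> list[str]:
        """Return the names of required arguments that are absent or empty."""
        return [
            name
            for name, required in self.arg_types.items()
            if required and not self.args.get(name)
        ]


class OptionalExpressions(ToyFunc):
    # A call with zero arguments passes validation.
    arg_types = {"expressions": False}


class RequiredExpressions(ToyFunc):
    # A call with zero arguments is reported as incomplete.
    arg_types = {"expressions": True}


print(OptionalExpressions().missing_required())
print(RequiredExpressions().missing_required())
print(RequiredExpressions(expressions=["'hello'"]).missing_required())
```

With the flag set to False, an argument-less call goes unnoticed; with True, the empty call is flagged while a call with at least one argument is accepted.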

    is_var_len_args = True
    _sql_names = ["EMBED"]
Collaborator


Why did you add _sql_names here and in the other two AST nodes? I don't think we want these. You should instead override the nodes' generators in BigQuery and map them to different names using rename_func.
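
Roughly, the idea is that the node classes stay name-agnostic and the dialect's generator decides how each node is spelled. The sketch below is library-free: `ToyNode`, `toy_rename_func`, `TOY_BIGQUERY_TRANSFORMS` and `generate` are hypothetical stand-ins that only mimic the shape of the approach, and the fallback naming is a simplification rather than sqlglot's real default.

```python
from typing import Callable


class ToyNode:
    """Stand-in AST node holding already-rendered argument strings."""

    def __init__(self, *expressions: str) -> None:
        self.expressions = list(expressions)


class AIEmbed(ToyNode):
    pass


class AISimilarity(ToyNode):
    pass


def toy_rename_func(name: str) -> Callable[[ToyNode], str]:
    """Return a renderer that emits `name(arg, ...)` for a node."""

    def render(node: ToyNode) -> str:
        return f"{name}({', '.join(node.expressions)})"

    return render


# Dialect-level table: node type -> renderer. The node classes carry no SQL names.
TOY_BIGQUERY_TRANSFORMS: dict[type, Callable[[ToyNode], str]] = {
    AIEmbed: toy_rename_func("EMBED"),
    AISimilarity: toy_rename_func("SIMILARITY"),
}


def generate(node: ToyNode) -> str:
    """Render a node via the dialect table, else fall back to its upper-cased class name."""
    fallback = toy_rename_func(type(node).__name__.upper())
    return TOY_BIGQUERY_TRANSFORMS.get(type(node), fallback)(node)


print(generate(AIEmbed("'hello'")))
print(generate(AISimilarity("'a'", "'b'")))
```

Keeping the spelling in the dialect table means another dialect could render the same node under a different name without touching the AST class.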



class AISimilarity(Expression, Func):
    arg_types = {"expressions": False}
    is_var_len_args = True
    _sql_names = ["SIMILARITY"]


class AIGenerate(Expression, Func):
    arg_types = {"expressions": False}
    is_var_len_args = True
    _sql_names = ["GENERATE"]


class FeaturesAtTime(Expression, Func):
    arg_types = {"this": True, "time": False, "num_rows": False, "ignore_feature_nulls": False}

10 changes: 10 additions & 0 deletions sqlglot/parsers/bigquery.py
@@ -730,6 +730,16 @@ def _parse_column_ops(self, this: exp.Expr | None) -> exp.Expr | None:
self._retreat(func_index)
parsed = self._parse_function(any_token=True)
if parsed:
if prefix == "AI" and isinstance(parsed, exp.Anonymous):
ai_scalars: dict[str, type[exp.Func]] = {
"EMBED": exp.AIEmbed,
"SIMILARITY": exp.AISimilarity,
"GENERATE": exp.AIGenerate,
}
Comment on lines +734 to +738
Collaborator


Is this necessary if you have proper function parsers (e.g., using FUNCTIONS) for EMBED et al? Check if there's an overlap between these and non-AI-prefixed functions in BigQuery and if that's the case, then try my suggestion to see if we can simplify this.

As a side-note, we generally don't define constants like this mapping inline, but instead "bubble them up" in the parser class.
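
As a rough illustration of the "bubble up" convention, here is a library-free sketch. `ToyBigQueryParser`, `AI_SCALAR_FUNCTIONS` and `build_ai_function` are hypothetical names chosen for this example and the node classes are stand-ins; none of them are sqlglot's actual API.

```python
from typing import Optional


class ToyFunc:
    """Stand-in for a typed function node; not a sqlglot class."""

    def __init__(self, expressions: list[str]) -> None:
        self.expressions = expressions


class AIEmbed(ToyFunc):
    pass


class AISimilarity(ToyFunc):
    pass


class AIGenerate(ToyFunc):
    pass


class ToyBigQueryParser:
    # Class-level constant (hypothetical name): built once, visible to readers of
    # the class, and overridable by subclasses, instead of a dict rebuilt per call.
    AI_SCALAR_FUNCTIONS: dict[str, type[ToyFunc]] = {
        "EMBED": AIEmbed,
        "SIMILARITY": AISimilarity,
        "GENERATE": AIGenerate,
    }

    def build_ai_function(self, name: str, expressions: list[str]) -> Optional[ToyFunc]:
        """Return a typed node for a known AI function name, else None."""
        expr_type = self.AI_SCALAR_FUNCTIONS.get(name.upper())
        return expr_type(expressions) if expr_type else None


parser = ToyBigQueryParser()
print(type(parser.build_ai_function("embed", ["'hello'"])).__name__)
print(parser.build_ai_function("unknown", []))
```

The lookup logic is unchanged from the inline version; only the location of the mapping moves.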

expr_type = ai_scalars.get(parsed.name.upper())
if expr_type:
parsed = expr_type(expressions=parsed.expressions)

this = self.expression(exp.Dot(this=this.this, expression=parsed))

return this
12 changes: 12 additions & 0 deletions tests/dialects/test_bigquery.py
Collaborator


We should increase the test coverage a bit by including more validate_identity tests that cover more args. For example, the EMBED function has the following syntax:

AI.EMBED(
  [ content => ] 'content',
  { endpoint => 'endpoint' | model => 'model' }
  [, task_type => 'task_type']
  [, title => 'title']
  [, model_params => model_params]
  [, connection_id => 'connection']
)

Ideally, we want some representative tests with more arguments.
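
One way to enumerate such cases is sketched below. `render_ai_call` is a hypothetical helper, and the endpoint, model, task type and title values are placeholders rather than real resources. Each resulting string would then be passed to validate_identity; whether it round-trips unchanged still has to be checked against the parser.

```python
def render_ai_call(name: str, *positional: str, **named: str) -> str:
    """Render SELECT AI.<name>(...) with positional args first, then `key => value` pairs."""
    parts = list(positional) + [f"{key} => {value}" for key, value in named.items()]
    return f"SELECT AI.{name}({', '.join(parts)})"


# Placeholder values only; they do not refer to real endpoints or models.
CANDIDATE_QUERIES = [
    # Positional content plus the endpoint alternative.
    render_ai_call("EMBED", "'hello'", endpoint="'my-endpoint'"),
    # Named content plus the model alternative and two optional arguments.
    render_ai_call(
        "EMBED",
        content="'hello'",
        model="'my-model'",
        task_type="'my-task-type'",
        title="'my-title'",
    ),
]

for query in CANDIDATE_QUERIES:
    print(query)
```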

@@ -2445,6 +2445,18 @@ def test_ml_functions(self):
            "SELECT AI.GENERATE_BOOL(MODEL `mydataset.gemini_model`, 'Is sky blue?')"
        )

        ast = self.validate_identity("SELECT AI.EMBED('hello')")
        assert isinstance(ast.expressions[0], exp.Dot)
        assert isinstance(ast.expressions[0].expression, exp.AIEmbed)

        ast = self.validate_identity("SELECT AI.SIMILARITY('a', 'b')")
        assert isinstance(ast.expressions[0], exp.Dot)
        assert isinstance(ast.expressions[0].expression, exp.AISimilarity)

        ast = self.validate_identity("SELECT AI.GENERATE('Write a haiku')")
        assert isinstance(ast.expressions[0], exp.Dot)
        assert isinstance(ast.expressions[0].expression, exp.AIGenerate)

    def test_merge(self):
        self.validate_all(
            """