Skip to content

Commit e17ea37

Browse files
committed
refactor: fix test_concat_dataframe by de-deuplicating merged ctes
1 parent a80ac3f commit e17ea37

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@
4141
def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
4242
"""Compiles a BigFrameNode according to the request into SQL using SQLGlot."""
4343

44-
# Generator for unique identifiers.
45-
uid_gen = guid.SequentialUIDGenerator()
4644
output_names = tuple((expression.DerefOp(id), id.sql) for id in request.node.ids)
4745
result_node = nodes.ResultNode(
4846
request.node,
@@ -61,22 +59,16 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
6159
)
6260
if request.sort_rows:
6361
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
64-
result_node = _remap_variables(result_node, uid_gen)
65-
result_node = typing.cast(
66-
nodes.ResultNode, rewrite.defer_selection(result_node)
67-
)
68-
sql = _compile_result_node(result_node, uid_gen)
62+
sql = _compile_result_node(result_node)
6963
return configs.CompileResult(
7064
sql, result_node.schema.to_bigquery(), result_node.order_by
7165
)
7266

7367
ordering: typing.Optional[bf_ordering.RowOrdering] = result_node.order_by
7468
result_node = dataclasses.replace(result_node, order_by=None)
7569
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
70+
sql = _compile_result_node(result_node)
7671

77-
result_node = _remap_variables(result_node, uid_gen)
78-
result_node = typing.cast(nodes.ResultNode, rewrite.defer_selection(result_node))
79-
sql = _compile_result_node(result_node, uid_gen)
8072
# Return the ordering iff no extra columns are needed to define the row order
8173
if ordering is not None:
8274
output_order = (
@@ -97,11 +89,16 @@ def _remap_variables(
9789
return typing.cast(nodes.ResultNode, result_node)
9890

9991

100-
def _compile_result_node(
101-
root: nodes.ResultNode, uid_gen: guid.SequentialUIDGenerator
102-
) -> str:
92+
def _compile_result_node(root: nodes.ResultNode) -> str:
93+
# Create UIDs to standardize variable names and ensure consistent compilation
94+
# of nodes using the same generator.
95+
uid_gen = guid.SequentialUIDGenerator()
96+
root = _remap_variables(root, uid_gen)
97+
root = typing.cast(nodes.ResultNode, rewrite.defer_selection(root))
98+
10399
# Have to bind schema as the final step before compilation.
104100
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
101+
105102
selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
106103
(name, scalar_compiler.scalar_op_compiler.compile_expression(ref))
107104
for ref, name in root.output_cols
@@ -127,7 +124,6 @@ def _compile_result_node(
127124
return sqlglot_ir.sql
128125

129126

130-
@functools.lru_cache(maxsize=5000)
131127
def compile_node(
132128
node: nodes.BigFrameNode, uid_gen: guid.SequentialUIDGenerator
133129
) -> ir.SQLGlotIR:

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def from_union(
206206

207207
select_expr = select.copy()
208208
select_expr, select_ctes = _pop_query_ctes(select_expr)
209-
existing_ctes = [*existing_ctes, *select_ctes]
209+
existing_ctes = _merge_ctes(existing_ctes, select_ctes)
210210
union_selects.append(select_expr)
211211

212212
union_expr: sge.Query = union_selects[0].subquery()
@@ -337,7 +337,7 @@ def join(
337337

338338
left_select, left_ctes = _pop_query_ctes(left_select)
339339
right_select, right_ctes = _pop_query_ctes(right_select)
340-
merged_ctes = [*left_ctes, *right_ctes]
340+
merged_ctes = _merge_ctes(left_ctes, right_ctes)
341341

342342
join_on = _and(
343343
tuple(
@@ -374,7 +374,7 @@ def isin_join(
374374

375375
left_select, left_ctes = _pop_query_ctes(left_select)
376376
right_select, right_ctes = _pop_query_ctes(right_select)
377-
merged_ctes = [*left_ctes, *right_ctes]
377+
merged_ctes = _merge_ctes(left_ctes, right_ctes)
378378

379379
left_condition = typed_expr.TypedExpr(
380380
sge.Column(this=conditions[0].expr, table=left_cte_name),
@@ -827,6 +827,15 @@ def _set_query_ctes(
827827
return new_expr
828828

829829

830+
def _merge_ctes(ctes1: list[sge.CTE], ctes2: list[sge.CTE]) -> list[sge.CTE]:
831+
"""Merges two lists of CTEs, de-duplicating by alias name."""
832+
seen = {cte.alias: cte for cte in ctes1}
833+
for cte in ctes2:
834+
if cte.alias not in seen:
835+
seen[cte.alias] = cte
836+
return list(seen.values())
837+
838+
830839
def _pop_query_ctes(
831840
expr: sge.Select,
832841
) -> tuple[sge.Select, list[sge.CTE]]:

0 commit comments

Comments
 (0)