Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 127 additions & 50 deletions datajunction-server/datajunction_server/construction/build_v3/cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,40 +1108,79 @@ def _inject_filter_into_where(
inject_filter_into_select(cast(ast.Select, query_ast.select), filter_expr)


def _setop_arms(cte_query: ast.Query) -> list[ast.Select]:
"""Return every Select arm in a (possibly chained) set-operation CTE body.

A non-set-op CTE returns a single-element list containing its Select.
A UNION/INTERSECT/EXCEPT chain returns each arm in order — the first arm
is ``cte_query.select`` and subsequent arms hang off ``set_op.right``.
"""
arms: list[ast.Select] = []
cur: Optional[ast.Select] = (
cte_query.select if isinstance(cte_query.select, ast.Select) else None
)
while cur is not None:
arms.append(cur)
nxt = cur.set_op.right if cur.set_op else None
# ``set_op.right`` is typed as SelectExpression which may also be a
# parenthesised subquery; in practice DJ always emits a Select here.
cur = nxt if isinstance(nxt, ast.Select) else None
return arms


def _resolve_pushdown_filters_for_cte(
node: "Node",
cte_query: ast.Query,
pushdown_filters: list[str],
filter_column_aliases: dict[str, str],
) -> list[ast.Expression]:
"""Determine which user filters can be pushed into this CTE and return them.
) -> tuple[list[tuple[ast.Select, ast.Expression]], set[str]]:
"""Determine which user filters can be pushed into this CTE.

For each filter, extracts the dimension references, resolves them to bare
column names via filter_column_aliases, and checks whether this CTE outputs
those columns. If all referenced columns are present, the filter is rewritten
using the CTE's internal table-qualified column names and returned.

Returns a list of parsed filter AST expressions ready for injection.
using the CTE's internal table-qualified column names and returned, paired
with the Select node it should be injected into.

For set-operation CTEs (UNION / INTERSECT / EXCEPT), the rewrite is done
independently per arm — each arm's projection may resolve the same output
column differently, and pushing only into the first arm would produce
semantically wrong SQL. The push is atomic per filter: if any arm can't
accept the rewrite, the filter is skipped for the whole CTE.

Returns ``(injections, consumed)`` where ``injections`` is the list of
``(target_select, rewritten_filter)`` pairs to inject and ``consumed``
is the set of original filter strings that were successfully pushed
into the CTE (so the caller can drop them from outer-level WHERE).
"""
node_output_cols = (
{col.name for col in (node.current.columns or [])} if node.current else set()
)

if not node_output_cols: # pragma: no cover
return []
return [], set()

results: list[ast.Expression] = []
arms = _setop_arms(cte_query)
results: list[tuple[ast.Select, ast.Expression]] = []
consumed: set[str] = set()
for filter_str in pushdown_filters:
rewritten = _rewrite_filter_for_cte(
filter_str,
filter_column_aliases,
node_output_cols,
cte_query,
)
if rewritten is None:
continue
results.append(rewritten)
return results
per_arm: list[tuple[ast.Select, ast.Expression]] = []
all_ok = True
for arm in arms:
rewritten = _rewrite_filter_for_select(
filter_str,
filter_column_aliases,
node_output_cols,
arm,
)
if rewritten is None:
all_ok = False
break
per_arm.append((arm, rewritten))
if all_ok:
results.extend(per_arm)
consumed.add(filter_str)
return results, consumed


def _cte_has_set_operation(cte_query: ast.Query) -> bool:
Expand Down Expand Up @@ -1197,32 +1236,51 @@ def _rewrite_filter_for_cte(
) -> ast.Expression | None:
"""Rewrite a dimension filter for injection into a specific CTE.

For non-set-op CTEs, delegates to :func:`_rewrite_filter_for_select`
against the single Select arm. Set-op CTEs are not handled here —
callers that need per-arm rewrites use :func:`_setop_arms` and call
:func:`_rewrite_filter_for_select` per arm.
"""
if _cte_has_set_operation(cte_query):
return None
return _rewrite_filter_for_select(
filter_str,
filter_column_aliases,
cte_output_cols,
cast(ast.Select, cte_query.select),
)


def _rewrite_filter_for_select(
filter_str: str,
filter_column_aliases: dict[str, str],
cte_output_cols: set[str],
cte_select: ast.Select,
) -> ast.Expression | None:
"""Rewrite a dimension filter for injection into a specific Select.

Resolves each dimension reference (e.g., ``v3.product.category``) to the
form that's safe in the CTE's WHERE clause. Three projection cases:

1. CTE projects the column as a simple (possibly aliased) column: replace
with the underlying qualified form (e.g., ``p.category``). This is the
correctness-critical case — emitting a SELECT-list alias in WHERE is
rejected by Spark SQL and standard SQL.
2. CTE doesn't project the column at all (pruned): fall through to the
bare column name. Safe because the CTE's underlying source exposes
the column, even if it's not selected into the outer query.
3. CTE projects the column via a non-column expression (e.g.
form that's safe in the Select's WHERE clause. Three projection cases:

1. Select projects the column as a simple (possibly aliased) column:
replace with the underlying qualified form (e.g., ``p.category``).
This is the correctness-critical case — emitting a SELECT-list alias
in WHERE is rejected by Spark SQL and standard SQL.
2. Select doesn't project the column at all (pruned): fall through to
the bare column name. Safe because the Select's underlying source
still exposes the column, even if it's not selected.
3. Select projects the column via a non-column expression (e.g.
``SUM(x) AS y``): skip — inlining is unsafe.

Multi-predicate handling: a single filter may reference several dim refs
(``a.x = 1 OR b.y = 2``). All matching refs are rewritten, but if ANY
ref's column isn't exposed by this CTE, the whole filter is skipped —
pushing a partial OR-predicate into the wrong CTE produces invalid SQL.
ref's column isn't exposed by this Select, the whole filter is skipped —
pushing a partial OR-predicate into the wrong scope produces invalid SQL.

Returns the rewritten filter AST, or None when the filter can't be
safely pushed into this CTE.
safely pushed into this Select.
"""
# Set-operation CTEs can't be safely pushed into via the first arm alone.
if _cte_has_set_operation(cte_query):
return None

projection_map = _build_cte_projection_map(cte_query)
projection_map = _build_select_projection_map(cte_select)
filter_ast = parse_filter(filter_str)

# First pass: plan the rewrites by walking the AST. Role-qualified refs
Expand Down Expand Up @@ -1303,19 +1361,32 @@ def _rewrite_filter_for_cte(
def _build_cte_projection_map(cte_query: ast.Query) -> dict[str, str | None]:
"""Map a CTE's output column name to its underlying qualified reference.

Thin wrapper around :func:`_build_select_projection_map`; kept for
callers that already hold an ``ast.Query``.
"""
if not cte_query.select: # pragma: no cover
return {}
return _build_select_projection_map(cast(ast.Select, cte_query.select))


def _build_select_projection_map(
cte_select: ast.Select,
) -> dict[str, str | None]:
"""Map a Select's projection output names to underlying qualified refs.

Output name is the SELECT-list alias when present, else the bare column
name. Value is either:

- A string — the form that's safe to reference in a WHERE clause pushed
into this CTE (a table-qualified column when qualified in the
into this Select (a table-qualified column when qualified in the
projection, else the bare column name).
- ``None`` — the projection is a non-column expression under an alias
(e.g., ``SUM(x) AS y``); pushdown should skip this CTE to avoid
inlining an expression that would be semantically wrong in WHERE.
(e.g., ``SUM(x) AS y``); pushdown should skip to avoid inlining an
expression that would be semantically wrong in WHERE.

Columns that the CTE doesn't project at all are absent from the map;
callers should treat that as "fall through to the bare name" since the
CTE's underlying source still exposes them.
Columns that aren't projected at all are absent from the map; callers
should treat that as "fall through to the bare name" since the Select's
underlying source still exposes them.

Examples::

Expand All @@ -1329,9 +1400,7 @@ def _build_cte_projection_map(cte_query: ast.Query) -> dict[str, str | None]:
becomes {"total": None}
"""
result: dict[str, str | None] = {}
if not cte_query.select: # pragma: no cover
return result
for expr in cte_query.select.projection:
for expr in cte_select.projection:
inner = getattr(expr, "child", expr)
alias = getattr(expr, "alias", None)
# Unwrap CAST around a column — CAST(col AS T) is a transparent
Expand Down Expand Up @@ -1365,7 +1434,7 @@ def collect_node_ctes(
needed_columns_by_node: Optional[dict[str, set[str]]] = None,
injected_filters: Optional[dict[str, ast.Expression]] = None,
pushdown: Optional[PushdownFilters] = None,
) -> tuple[list[tuple[str, ast.Query]], list[str]]:
) -> tuple[list[tuple[str, ast.Query]], list[str], dict[str, set[str]]]:
"""
Collect CTEs for all non-source nodes, recursively expanding table references.

Expand All @@ -1389,14 +1458,19 @@ def collect_node_ctes(
them on the outer query after an expensive join.

Returns:
Tuple of (cte_list, scanned_sources):
Tuple of (cte_list, scanned_sources, consumed_by_node):
- cte_list: List of (cte_name, query_ast) tuples in dependency order
- scanned_sources: List of source node names encountered during traversal
- consumed_by_node: Map of node_name -> set of filter strings that were
successfully pushed into that node's CTE WHERE clause. Callers use
this to drop redundant copies from outer-level WHERE.
"""
# Collect all node names that need CTEs (including transitive dependencies)
all_node_names: set[str] = set()
# Track source nodes encountered during traversal
scanned_source_names: set[str] = set()
# Map of node_name -> filter strings successfully pushed into that CTE
consumed_by_node: dict[str, set[str]] = {}
mat_check_time = 0.0
parse_check_time = 0.0
ref_extract_time = 0.0
Expand Down Expand Up @@ -1508,17 +1582,20 @@ def collect_refs(node: Node, visited: set[str]) -> None:
_inject_filter_into_where(query_ast, injected_filters[node.name])

if pushdown:
for filter_ast in _resolve_pushdown_filters_for_cte(
injections, consumed = _resolve_pushdown_filters_for_cte(
node,
query_ast,
pushdown.filters,
pushdown.column_aliases,
):
_inject_filter_into_where(query_ast, filter_ast)
)
for target_select, filter_ast in injections:
inject_filter_into_select(target_select, filter_ast)
if consumed:
consumed_by_node[node.name] = consumed

ctes.append((cte_name, query_ast))

return ctes, list(scanned_source_names)
return ctes, list(scanned_source_names), consumed_by_node


def process_metric_combiner_expression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,10 @@ def _apply_outer_where_atoms(
The same predicate has already been pushed into the parent CTE's
WHERE, so the outer-level copy is redundant *and* unsafe (it
would defeat downstream RIGHT/FULL OUTER joins to dims).
``parent_pushdown_active`` should be set only when the caller
knows pushdown actually succeeded for the parent CTE — callers
that aren't sure should pre-strip un-pushed filter strings from
``where_clause`` instead of relying on this flag.

- **Parent-alias atoms** when parent pushdown wasn't applied (e.g.
parent is materialized or has a set-op body that blocks
Expand Down Expand Up @@ -1156,17 +1160,32 @@ def build_dimension_joins(
) -> tuple[dict[tuple[str, Optional[str]], str], list[ast.Join]]:
"""Build JOIN clauses for non-local dimensions.

Walks each resolved dimension's join path, deduplicating joins when two
dimensions share a common prefix (same dimension node + accumulated role).
Chains whose first link is a null-padding join (RIGHT / FULL OUTER) are
processed after chains whose first link is source-preserving (INNER /
LEFT / CROSS). This keeps each dim join's ON clause evaluating against
live source columns: an intervening RIGHT/FULL OUTER off the same source
won't run first and null-pad the columns a sibling chain depends on.
Chain locality is preserved — each chain still emits its links
contiguously.

Returns:
(dim_aliases, joins) where dim_aliases maps (node_name, accumulated_role)
to the table alias used in the JOIN.
"""
from datajunction_server.models.dimensionlink import JoinType

def _chain_bucket(rdim: ResolvedDimension) -> int:
if rdim.is_local or not rdim.join_path or not rdim.join_path.links:
return 0
jt = rdim.join_path.links[0].join_type
return 1 if jt in (JoinType.RIGHT, JoinType.FULL) else 0

ordered_dimensions = sorted(resolved_dimensions, key=_chain_bucket)

dim_aliases: dict[tuple[str, Optional[str]], str] = {}
joins: list[ast.Join] = []

for resolved_dim in resolved_dimensions:
for resolved_dim in ordered_dimensions:
if not resolved_dim.is_local and resolved_dim.join_path:
current_left_alias = main_alias
accumulated_role_parts: list[str] = []
Expand All @@ -1183,7 +1202,7 @@ def build_dimension_joins(

dim_key = (dim_node_name, accumulated_role)

if dim_key not in dim_aliases: # pragma: no branch
if dim_key not in dim_aliases:
if accumulated_role:
alias_base = accumulated_role.replace("->", "_")
else:
Expand Down Expand Up @@ -1404,7 +1423,7 @@ def build_select_ast(
# Build CTEs for all non-source nodes. Requested dimension filters and their
# resolution map are passed through so each CTE can independently decide
# whether to push a filter into its own WHERE clause.
ctes, scanned_sources = collect_node_ctes(
ctes, scanned_sources, consumed_by_node = collect_node_ctes(
ctx,
nodes_for_ctes,
needed_columns_by_node,
Expand All @@ -1416,6 +1435,10 @@ def build_select_ast(
if all_filters
else None,
)
# Surface all CTE-consumed filters so the metrics layer's outer WHERE
# can skip re-applying them on top of the aggregation CTEs.
for _consumed_set in consumed_by_node.values():
ctx.cte_consumed_filters.update(_consumed_set)

# Build SELECT.
# For non-decomposable metrics, skip GROUP BY to pass through raw rows.
Expand All @@ -1427,16 +1450,39 @@ def build_select_ast(
hints=spark_hints if spark_hints else None,
)

# Apply outer WHERE atoms. Parent-alias atoms are dropped from outer
# WHERE when the parent CTE pushdown is active (filter is already in
# the parent CTE; the outer copy is redundant and would defeat
# downstream OUTER joins). Dim-alias atoms continue to land in
# outer WHERE for standard dim-filter semantics. See
# ``_apply_outer_where_atoms`` for the full rule.
parent_pushdown_active = bool(all_filters) and not should_use_materialized_table(
# Determine which filters were actually consumed by the parent CTE's
# pushdown. Parent-alias atoms from those filters are dropped from the
# outer WHERE (the parent CTE already has them; double-application is
# redundant and can defeat downstream RIGHT/FULL OUTER joins).
# Filters that weren't consumed (couldn't be pushed — e.g. their column
# is projected as a non-column expression) still need to land at the
# outer WHERE so they aren't silently dropped.
parent_consumed = consumed_by_node.get(parent_node.name, set())
parent_pushdown_active = bool(
parent_consumed,
) and not should_use_materialized_table(
ctx,
parent_node,
)
# If pushdown was partial, strip the consumed filters from
# ``where_clause`` before applying the outer atoms — keeps unconsumed
# atoms while dropping the redundant copies of consumed ones.
if parent_consumed and parent_consumed != set(all_filters):
remaining = [f for f in all_filters if f not in parent_consumed]
where_clause = (
build_outer_where(
remaining,
filter_column_aliases,
resolved_dimensions,
main_alias,
dim_aliases,
parent_node,
nodes=ctx.nodes,
)
if remaining
else None
)
parent_pushdown_active = False
# Absorb LEFT/INNER-joined dims whose filter would defeat a
# downstream RIGHT/FULL OUTER JOIN into a filtered CTE on the
# parent side. Mutates ``select.from_`` (re-pointing the parent
Expand Down
Loading
Loading