Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 212 additions & 9 deletions datajunction-server/datajunction_server/construction/build_v3/cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

from copy import deepcopy
from typing import Optional
from typing import Optional, cast

from datajunction_server.construction.build_v3.filters import (
extract_subscript_role,
Expand Down Expand Up @@ -904,18 +904,208 @@ def flatten_inner_ctes(
# ---------------------------------------------------------------------------


# Join-type spellings grouped by which side of the join gets NULL-filled.
# Lookups compare against the upper-cased, stripped ``join_type`` string.
# Frozen so these module-level tables cannot be mutated by accident.
_OUTER_JOIN_KINDS_LEFT = frozenset({"LEFT", "LEFT OUTER"})
_OUTER_JOIN_KINDS_RIGHT = frozenset({"RIGHT", "RIGHT OUTER"})
_OUTER_JOIN_KINDS_FULL = frozenset({"FULL", "FULL OUTER", "OUTER"})


def _get_relation_side_id(expr: ast.Expression) -> Optional[str]:
    """Return the identifier that column refs use to qualify this relation side.

    An explicit alias takes precedence; a bare table without an alias
    contributes its own unqualified name. Any other expression (for example
    a subquery with no alias) yields ``None`` — a purely defensive outcome,
    since the v3 builder always emits aliased relations.
    """
    side_alias = getattr(expr, "alias", None)
    if side_alias is None:
        if not isinstance(expr, ast.Table):
            return None  # pragma: no cover
        return expr.name.name
    return side_alias.name


def _filter_namespaces(filter_ast: ast.Expression) -> set[str]:
    """Gather every namespace identifier that qualifies a column in the filter.

    Columns without a namespace contribute nothing to the result.
    """
    qualifiers: set[str] = set()
    for column in filter_ast.find_all(ast.Column):
        owner = column.name.namespace
        if owner is None:
            # Unqualified column: it names no relation side, so skip it.
            continue
        qualifiers.add(owner.name)
    return qualifiers


def _classify_outer_join_target(
    relation: ast.Relation,
    filter_namespaces: set[str],
) -> Optional[ast.Expression]:
    """
    Pick the relation side that must host a filter as a wrapped subquery.

    Returns ``None`` when plain WHERE injection is safe (or when no single
    host can be determined).

    The join chain is walked left to right and each join is asked whether it
    NULL-fills one of its sides:
    - LEFT OUTER: the right side becomes non-preserved
    - RIGHT OUTER: everything accumulated on the left becomes non-preserved,
      and the right side is the new preserved set
    - FULL OUTER: both the accumulated left and the right are non-preserved
    - anything else (e.g. INNER): the right side is preserved

    A target is returned only when every namespace used by the filter maps
    to a non-preserved side and all of them map to the same expression.
    """
    if not (relation.extensions and filter_namespaces):
        return None

    # ``_get_relation_side_id`` yields None only for unaliased subqueries,
    # which the v3 builder never emits; the ``is not None`` guards below are
    # defensive, hence the no-branch pragmas.
    kept: dict[str, ast.Expression] = {}
    null_filled: dict[str, ast.Expression] = {}

    root_id = _get_relation_side_id(relation.primary)
    if root_id is not None:  # pragma: no branch
        kept[root_id] = relation.primary

    for join in relation.extensions:
        joined_id = _get_relation_side_id(join.right)
        join_kind = (join.join_type or "").upper().strip()

        # Decide which bucket the joined (right) side lands in, demoting the
        # accumulated left side first when the join kind requires it.
        if join_kind in _OUTER_JOIN_KINDS_LEFT:
            bucket = null_filled
        elif join_kind in _OUTER_JOIN_KINDS_RIGHT:
            null_filled.update(kept)
            kept = {}
            bucket = kept
        elif join_kind in _OUTER_JOIN_KINDS_FULL:
            null_filled.update(kept)
            bucket = null_filled
        else:
            bucket = kept

        if joined_id is not None:  # pragma: no branch
            bucket[joined_id] = join.right

    if not null_filled:
        return None
    if not filter_namespaces.issubset(null_filled):
        return None

    hosts = {null_filled[ns] for ns in filter_namespaces}
    if len(hosts) != 1:
        return None
    (host,) = hosts
    return host


def _wrap_relation_side_with_filter(
    target_expr: ast.Expression,
    filter_ast: ast.Expression,
) -> ast.Query:
    """Build ``(SELECT * FROM <expr> WHERE filter) <alias>`` around a relation side.

    Both the relation side and the filter are deep-copied into the new
    query. The wrapping query takes over the side's original identifier as
    its alias, so column refs qualified by that identifier keep resolving.
    """
    source_copy = deepcopy(target_expr)
    predicate_copy = deepcopy(filter_ast)
    wrapped = ast.Query(
        select=ast.Select(
            projection=[ast.Wildcard()],
            from_=ast.From(relations=[ast.Relation(primary=source_copy)]),
            where=predicate_copy,
        ),
    )
    wrapped.parenthesized = True
    # Tag the query as an outer-join-safety wrap. Later filters aimed at the
    # same side can then be ANDed into this WHERE instead of adding another
    # nesting level. The inner FROM still carries the original alias, so no
    # re-qualification of column refs is needed inside the wrap.
    wrapped._outer_join_filter_wrap = True  # type: ignore[attr-defined]

    outer_alias = _get_relation_side_id(target_expr)
    if outer_alias is not None:  # pragma: no branch
        wrapped.alias = ast.Name(outer_alias)
        wrapped.as_ = False
    return wrapped


def _try_push_filter_into_outer_join_side(
    select: ast.Select,
    filter_expr: ast.Expression,
) -> bool:
    """Apply a filter inside an outer join's inner side when WHERE would be unsafe.

    Returns:
        ``True`` when the filter was applied by wrapping (or extending an
        existing wrap of) a relation side — the caller must NOT also AND it
        into WHERE. ``False`` when no outer-join hazard was detected and the
        caller should fall back to ordinary WHERE injection.

    NOTE(review): classification only looks at namespace-qualified columns;
    columns without a namespace are ignored by ``_filter_namespaces``. A
    predicate mixing qualified and unqualified columns is therefore routed by
    its qualified columns alone — confirm callers pass fully-qualified
    predicates.
    """
    from_clause = select.from_
    if from_clause is None:
        return False
    qualifiers = _filter_namespaces(filter_expr)
    if not qualifiers:
        return False

    for relation in from_clause.relations:
        host = _classify_outer_join_target(relation, qualifiers)
        if host is None:
            continue

        # A host that is already an outer-join-safety wrap from an earlier
        # call gets the new atom ANDed into its inner WHERE rather than being
        # nested inside yet another wrap. The inner FROM keeps the original
        # alias, so refs qualified by it resolve correctly in there.
        already_wrapped = isinstance(host, ast.Query) and getattr(
            host,
            "_outer_join_filter_wrap",
            False,
        )
        if already_wrapped:
            wrap_select = cast(ast.Select, host.select)
            # Wraps are always created with a WHERE; the cast only satisfies
            # mypy and avoids a runtime branch.
            current_where = cast(ast.Expression, wrap_select.where)
            wrap_select.where = ast.BinaryOp.And(
                current_where,
                deepcopy(filter_expr),
            )
            return True

        replacement = _wrap_relation_side_with_filter(host, filter_expr)
        if relation.primary is host:
            relation.primary = replacement
            return True
        # Otherwise the host is the right side of one of the joins; swap it
        # in place on the first join that references it.
        for join in relation.extensions:  # pragma: no branch
            if join.right is host:  # pragma: no branch
                join.right = replacement
                break
        return True
    return False


def _split_and_atoms(expr: ast.Expression) -> list[ast.Expression]:
    """Return the leaf predicates of an AND-tree, in left-to-right order.

    An expression that is not an AND node is returned as a one-element list.
    """
    leaves: list[ast.Expression] = []
    pending = [expr]
    while pending:
        node = pending.pop()
        if (
            isinstance(node, ast.BinaryOp) and node.op == ast.BinaryOpKind.And  # type: ignore[attr-defined]
        ):
            # Push right before left so the left operand is popped (and thus
            # emitted) first.
            pending.append(node.right)
            pending.append(node.left)
        else:
            leaves.append(node)
    return leaves


def inject_filter_into_select(
    select: ast.Select,
    filter_expr: ast.Expression,
) -> None:
    """Add a filter to a SELECT without defeating its OUTER JOINs.

    The filter is first split into its AND-ed atoms, and each atom is routed
    on its own:

    - if the SELECT's FROM has an OUTER JOIN whose non-preserved side is
      referenced by the atom, the atom is applied inside a subquery wrapped
      around that side, so the WHERE cannot silently turn the OUTER JOIN
      into an INNER JOIN;
    - otherwise the atom is AND-ed into the SELECT's WHERE clause.

    Splitting matters because one bundled filter may mix atoms that are safe
    for WHERE with atoms that need inner-side wrapping. ``select`` is
    modified in place.
    """
    for atom in _split_and_atoms(filter_expr):
        pushed = _try_push_filter_into_outer_join_side(select, atom)
        if not pushed:
            select.where = (
                ast.BinaryOp.And(select.where, atom) if select.where else atom
            )


def _inject_filter_into_where(
    query_ast: ast.Query,
    filter_expr: ast.Expression,
) -> None:
    """AND a filter into a Query's WHERE — wrapper around :func:`inject_filter_into_select`.

    All routing is delegated to ``inject_filter_into_select``. The former
    inline "AND into WHERE" body is removed: keeping it next to the
    delegation would apply the filter twice.

    Args:
        query_ast: Query whose SELECT receives the filter; modified in place.
        filter_expr: Predicate to apply.
    """
    # ``cast`` only narrows the type for mypy; it has no runtime effect.
    inject_filter_into_select(cast(ast.Select, query_ast.select), filter_expr)


def _resolve_pushdown_filters_for_cte(
Expand Down Expand Up @@ -1133,6 +1323,8 @@ def _build_cte_projection_map(cte_query: ast.Query) -> dict[str, str | None]:
becomes {"order_date": "o.placed_on"}
SELECT T.test_id
becomes {"test_id": "T.test_id"}
SELECT CAST(x AS INT) AS x
becomes {"x": "x"} -- CAST is a transparent passthrough
SELECT SUM(x) AS total
becomes {"total": None}
"""
Expand All @@ -1142,6 +1334,17 @@ def _build_cte_projection_map(cte_query: ast.Query) -> dict[str, str | None]:
for expr in cte_query.select.projection:
inner = getattr(expr, "child", expr)
alias = getattr(expr, "alias", None)
# Unwrap CAST around a column — CAST(col AS T) is a transparent
# passthrough for filter pushdown. Equality, IN, and range
# comparisons against a literal of the cast target type are
# value-equivalent to the same comparison against the unwrapped
# column (modulo lossy casts, which are rare and generally
# intentional in the projection). Recognising this lets filters
# land in CTEs whose projection wraps a column in CAST for type
# normalization (e.g. ``CAST(x AS INT) AS x`` from a LATERAL
# VIEW EXPLODE'd column).
if isinstance(inner, ast.Cast) and isinstance(inner.expression, ast.Column):
inner = inner.expression
if isinstance(inner, ast.Column) and inner.name:
bare = inner.name.name
underlying = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def resolve_dimensions(
role=dim_ref.role,
join_path=None,
is_local=True,
pre_skip_join_path=join_path,
),
)
elif can_skip and local_col and hops_skipped > 0:
Expand Down
Loading
Loading