Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from copy import deepcopy
from typing import Optional

from datajunction_server.construction.build_v3.filters import extract_subscript_role
from datajunction_server.construction.build_v3.materialization import (
get_table_reference_parts_with_materialization,
should_use_materialized_table,
Expand Down Expand Up @@ -261,15 +262,7 @@ def replace_dimension_refs_in_ast(
if not base_col_name: # pragma: no cover
continue

# Get the role from the index (e.g., "order")
role = None
if isinstance(subscript.index, ast.Column):
role = subscript.index.name.name if subscript.index.name else None
elif isinstance(subscript.index, ast.Name): # pragma: no cover
role = subscript.index.name # pragma: no cover
elif hasattr(subscript.index, "name"): # pragma: no cover
role = str(subscript.index.name) # type: ignore

role = extract_subscript_role(subscript)
if not role: # pragma: no cover
continue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,29 @@ def parse_filter(filter_str: str) -> ast.Expression:
return query.select.where


def extract_subscript_role(subscript: ast.Subscript) -> str | None:
    """
    Extract the role string from a subscript index node.

    Handles the three forms that can appear as a subscript index:
    - ast.Column: simple role like "order" (e.g., "v3.date.year[order]")
    - ast.Name: simple role like "order" (fallback if parser produces Name instead of Column)
    - ast.Lambda: multi-hop role (e.g., "v3.user[customer->home]")

    Returns the role string, or None if the index is not a recognised form.
    """
    index = subscript.index
    # Simple role parsed as a Column (e.g., "dim.attr[order]")
    if isinstance(index, ast.Column):
        return index.name.name if index.name else None
    # Fallback: simple role parsed directly as a Name
    if isinstance(index, ast.Name):  # pragma: no cover
        return index.name
    # Multi-hop role parsed as a Lambda (e.g., "dim.attr[customer->home]");
    # Lambda.__str__ yields the canonical role string.
    if isinstance(index, ast.Lambda):
        return str(index)
    return None  # pragma: no cover


def resolve_filter_references(
filter_ast: ast.Expression,
column_aliases: dict[str, str],
Expand Down Expand Up @@ -81,17 +104,7 @@ def resolve_filter_references(
if not base_col_ref:
continue # pragma: no cover

# Extract the role from the subscript index
role = None
if isinstance(subscript.index, ast.Column):
role = subscript.index.name.name if subscript.index.name else None
elif isinstance(subscript.index, ast.Name): # pragma: no cover
role = subscript.index.name
elif isinstance(subscript.index, ast.Lambda):
# Multi-hop role notation like "customer->home" is parsed as a Lambda node.
# Lambda.__str__ returns the canonical role string (e.g., "customer->home").
role = str(subscript.index)

role = extract_subscript_role(subscript)
if not role:
continue # pragma: no cover

Expand Down
146 changes: 136 additions & 10 deletions datajunction-server/datajunction_server/construction/build_v3/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,19 +195,26 @@ def get_comp_aggregability(comp_name: str) -> Aggregability:
return gg.component_aggregabilities.get(comp_name, Aggregability.FULL)
return decomposed.aggregability

# Handle LIMITED aggregability (COUNT DISTINCT) specially
# This can't be pre-aggregated, so we need COUNT(DISTINCT grain_col)
# Handle LIMITED aggregability (COUNT DISTINCT).
# If the grain group was pre-aggregated (is_pre_aggregated=True), the wrapper CTE
# already computed COUNT(DISTINCT grain_key) and stored it as a named column.
# Emit SUM(pre_agg_col) — a no-op re-aggregation since the wrapper produces
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to aggregate to the right grain before we combine these metrics, to eliminate any chance of double counting.

# exactly 1 row per dimension combination.
# Otherwise fall through to COUNT(DISTINCT grain_col) against the raw CTE.
if len(decomposed.components) == 1:
comp = decomposed.components[0]
orig_agg = get_comp_aggregability(comp.name)

if orig_agg == Aggregability.LIMITED:
_, col_name = comp_mappings[comp.name]
distinct_col = make_column_ref(col_name, cte_alias)
col_ref = make_column_ref(col_name, cte_alias)
if gg.is_pre_aggregated:
# Wrapper CTE already has COUNT(DISTINCT) as a named column; just SUM it.
return ast.Function(ast.Name("SUM"), args=[col_ref])
agg_name = comp.aggregation or "COUNT"
return ast.Function(
ast.Name(agg_name),
args=[distinct_col],
args=[col_ref],
quantifier=ast.SetQuantifier.Distinct,
)

Expand All @@ -220,11 +227,92 @@ def get_comp_aggregability(comp_name: str) -> Aggregability:
return expr_ast


def _build_pre_agg_wrapper_cte(
    alias: str,
    gg: GrainGroupSQL,
) -> tuple[ast.Query, str]:
    """
    Build a pre-aggregation wrapper CTE for a LIMITED grain group.

    A LIMITED grain group CTE outputs N rows per dimension combination (one per distinct
    grain key, e.g. customer_id). When FULL OUTER JOINed with other CTEs that have 1
    row per dimension combination, those rows fan out 1:N, causing SUM() to overcount.

    This wrapper collapses the N rows into 1 by applying COUNT(DISTINCT grain_key) inside
    the CTE instead of in the outer SELECT. The outer SELECT can then use SUM() on the
    already-computed count, which is a no-op when there's exactly 1 row per group.

    Args:
        alias: The raw grain group CTE alias (e.g., "page_views_enriched_0")
        gg: The LIMITED grain group

    Returns:
        (wrapper_cte_ast, wrapper_alias) where wrapper_alias is e.g.
        "page_views_enriched_0_agg"
    """
    wrapper_alias = f"{alias}_agg"

    # Grain keys consumed by COUNT(DISTINCT ...): these are collapsed away, so they
    # must be excluded from both the projection and the GROUP BY.
    # gg.grain = [dim_col_aliases..., grain_key, ...]
    collapsed_keys = {
        comp.rule.level[0]
        for comp in gg.components
        if comp.rule and comp.rule.type == Aggregability.LIMITED and comp.rule.level
    }
    # Only the user-requested dimension columns (e.g., "category") survive.
    dimension_cols = [col for col in gg.grain if col not in collapsed_keys]

    # SELECT projection: dimension columns followed by one
    # COUNT(DISTINCT grain_key) AS <component_name> per LIMITED component.
    select_items: list[Any] = [
        ast.Column(name=ast.Name(col)) for col in dimension_cols
    ]
    for comp in gg.components:
        if not (comp.rule and comp.rule.type == Aggregability.LIMITED):
            continue
        key_col = comp.rule.level[0] if comp.rule.level else None
        if not key_col:
            continue  # pragma: no cover
        count_distinct = ast.Function(
            ast.Name("COUNT"),
            args=[ast.Column(name=ast.Name(key_col))],
            quantifier=ast.SetQuantifier.Distinct,
        )
        select_items.append(ast.Alias(child=count_distinct, alias=ast.Name(comp.name)))

    # GROUP BY only the dimension columns (never the collapsed grain keys).
    group_by_cols: list[ast.Expression] = [
        ast.Column(name=ast.Name(col)) for col in dimension_cols
    ]

    wrapper_query = ast.Query(
        select=ast.Select(
            projection=select_items,
            from_=ast.From(
                relations=[ast.Relation(primary=ast.Table(name=ast.Name(alias)))],
            ),
            group_by=group_by_cols if group_by_cols else [],
        ),
    )
    wrapper_query.to_cte(ast.Name(wrapper_alias), None)
    return wrapper_query, wrapper_alias


def collect_and_build_ctes(
grain_groups: list[GrainGroupSQL],
) -> tuple[list[ast.Query], list[str]]:
"""
Collect shared CTEs and convert grain groups to CTEs.

For LIMITED grain groups (COUNT DISTINCT), also emits a pre-aggregation wrapper
CTE that collapses the N-rows-per-dimension output to 1 row per dimension by
computing COUNT(DISTINCT grain_key) inside the CTE. This prevents fan-out when
FULL OUTER JOINing with FULL grain groups.

Returns (all_cte_asts, cte_aliases).
"""
# Collect all inner CTEs, dedupe by original name
Expand Down Expand Up @@ -258,7 +346,6 @@ def collect_and_build_ctes(
idx = parent_index_counter.get(parent_short, 0)
parent_index_counter[parent_short] = idx + 1
alias = f"{parent_short}_{idx}"
cte_aliases.append(alias)

# gg.query is already an AST - no need to parse!
gg_query = gg.query
Expand All @@ -272,6 +359,29 @@ def collect_and_build_ctes(
gg_main.to_cte(ast.Name(alias), None)
all_cte_asts.append(gg_main)

# For non-merged LIMITED grain groups, add a pre-aggregation wrapper CTE.
# This collapses N rows per dimension (one per distinct grain key) into 1 row
# by computing COUNT(DISTINCT grain_key) inside the CTE, preventing fan-out
# in the FULL OUTER JOIN step.
needs_pre_agg = (
not gg.is_merged
and gg.aggregability == Aggregability.LIMITED
and gg.components
)
if needs_pre_agg:
wrapper_cte, wrapper_alias = _build_pre_agg_wrapper_cte(alias, gg)
all_cte_asts.append(wrapper_cte)
# Record the pre-aggregated column name for each LIMITED component so that
# _build_metric_aggregation() can reference it by name instead of re-applying
# COUNT(DISTINCT).
for comp in gg.components:
if comp.rule and comp.rule.type == Aggregability.LIMITED:
gg.component_aliases[comp.name] = comp.name
gg.is_pre_aggregated = True
cte_aliases.append(wrapper_alias)
else:
cte_aliases.append(alias)

return all_cte_asts, cte_aliases


Expand Down Expand Up @@ -1915,13 +2025,29 @@ def collect_derived_dependencies(metric_name: str, visited: set[str]) -> None:
applicable_dimension_filters = []
for f in dimension_filters_raw:
filter_ast = parse_filter(f)
# Check if any column ref in this filter is a filter-only dimension
# Check if any column ref in this filter is a filter-only dimension.
# Must handle both plain column refs and role-qualified subscript refs
# (e.g., "v3.location.country[customer->home]"), because find_all(ast.Column)
# returns only the base Column inside the Subscript, not the full role string.
refs_filter_only = False
for col in filter_ast.find_all(ast.Column):
full_name = get_column_full_name(col)
if full_name and full_name in ctx.filter_dimensions:
refs_filter_only = True
for subscript in filter_ast.find_all(ast.Subscript):
if not isinstance(subscript.expr, ast.Column):
continue # pragma: no cover
base_ref = get_column_full_name(subscript.expr)
if base_ref:
for fd in ctx.filter_dimensions:
fd_base = fd.split("[")[0] if "[" in fd else fd
if fd_base == base_ref:
refs_filter_only = True
break
if refs_filter_only:
break
if not refs_filter_only:
for col in filter_ast.find_all(ast.Column):
full_name = get_column_full_name(col)
if full_name and full_name in ctx.filter_dimensions:
refs_filter_only = True
break
if not refs_filter_only:
applicable_dimension_filters.append(f)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ class GrainGroupSQL:
# instead of individual grain group CTEs.
is_cross_fact_window: bool = False

# Pre-aggregation: True when collect_and_build_ctes() added a wrapper CTE that
# applies COUNT(DISTINCT grain_key) per requested dimension combination.
# When True, _build_metric_aggregation() should emit SUM(pre_agg_col) instead of
# COUNT(DISTINCT raw_grain_col), since the wrapper CTE already did the DISTINCT work.
is_pre_aggregated: bool = False

# Scan estimation: source tables accessed during SQL generation
# Populated by collect_node_ctes during CTE building
scanned_sources: list[str] = field(default_factory=list)
Expand Down
Loading
Loading