Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion datajunction-server/datajunction_server/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class Abs(Function):
Returns the absolute value of the numeric or interval value.
"""

is_aggregation = True
is_aggregation = False
dialects = [Dialect.SPARK, Dialect.DRUID]


Expand Down Expand Up @@ -1398,6 +1398,8 @@ class CountMinSketch(Function):
count_min_sketch(col, eps, confidence, seed) - Creates a Count-Min sketch of col.
"""

is_aggregation = True


@CountMinSketch.register # type: ignore
def infer_type(
Expand Down Expand Up @@ -2503,6 +2505,8 @@ class HistogramNumeric(Function):
defined by equally spaced width intervals.
"""

is_aggregation = True


@HistogramNumeric.register # type: ignore
def infer_type(arg1: ct.ColumnType, arg2: ct.IntegerType) -> ct.ColumnType:
Expand Down Expand Up @@ -2754,6 +2758,8 @@ class Kurtosis(Function):
kurtosis(expr) - Returns the kurtosis of the values in a group.
"""

is_aggregation = True


@Kurtosis.register # type: ignore
def infer_type(arg: ct.NumberType) -> ct.DoubleType:
Expand Down Expand Up @@ -3463,6 +3469,8 @@ class Mean(Function):
mean(expr) - Returns the average of the values in the group.
"""

is_aggregation = True


@Mean.register # type: ignore
def infer_type(arg: ct.ColumnType) -> ct.ColumnType:
Expand Down Expand Up @@ -3577,6 +3585,8 @@ class Mode(Function):
mode(col) - Returns the most frequent value for the values within col.
"""

is_aggregation = True


@Mode.register # type: ignore
def infer_type(arg: ct.ColumnType) -> ct.ColumnType:
Expand Down
38 changes: 38 additions & 0 deletions datajunction-server/tests/sql/decompose_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MetricComponent,
)
from datajunction_server.models.node_type import NodeType
from datajunction_server.sql import functions as dj_functions
from datajunction_server.sql.decompose import (
MetricComponentExtractor,
safe_denominator,
Expand Down Expand Up @@ -2232,3 +2233,40 @@ async def test_user_authored_division_not_double_wrapped(
str(derived_sql),
"SELECT SUM(x_sum_b5c12ce5) / NULLIF(SUM(y_sum_898a9389), 0) FROM parent_node",
)


class TestFunctionAggregationClassification:
"""Lock in is_aggregation for previously-misclassified functions."""

def test_scalar_functions_not_marked_as_aggregation(self):
assert dj_functions.Abs.is_aggregation is False

def test_real_aggregations_marked_as_aggregation(self):
for cls in (
dj_functions.Mean,
dj_functions.Median,
dj_functions.Mode,
dj_functions.Kurtosis,
dj_functions.HistogramNumeric,
dj_functions.CountMinSketch,
):
assert cls.is_aggregation is True, (
f"{cls.__name__} is a real aggregation but is_aggregation is False"
)


@pytest.mark.asyncio
async def test_sum_abs_decomposes(session: AsyncSession, create_metric):
"""SUM(ABS(x)) decomposes into one SUM component over ABS(x)."""
metric_rev = await create_metric(
"SELECT SUM(ABS(amount - 100)) FROM parent_node",
)
extractor = MetricComponentExtractor(metric_rev.id)
measures, derived_sql = await extractor.extract(session)
assert len(measures) == 1
comp = measures[0]
assert comp.aggregation == "SUM"
assert comp.merge == "SUM"
assert comp.rule.type == Aggregability.FULL
assert "ABS" in comp.expression
assert_sql_equal(str(derived_sql), f"SELECT SUM({comp.name}) FROM parent_node")
Loading