Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 67 additions & 22 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import typing as t
import datetime
from collections import deque
from sqlglot import exp, generator, parser, tokens
from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Expand Down Expand Up @@ -530,23 +531,29 @@ class Parser(parser.Parser):
"exponentialTimeDecayedAvg",
}

AGG_FUNCTIONS_SUFFIXES = [
"If",
"Array",
"ArrayIf",
"Map",
"SimpleState",
"State",
"Merge",
"MergeState",
"ForEach",
"Distinct",
"OrDefault",
"OrNull",
"Resample",
"ArgMin",
"ArgMax",
]
# Sorted longest-first so that compound suffixes (e.g. "SimpleState") are matched
# before their sub-suffixes (e.g. "State") when resolving multi-combinator functions.
AGG_FUNCTIONS_SUFFIXES = sorted(
[
"If",
"Array",
"ArrayIf",
"Map",
"SimpleState",
"State",
"Merge",
"MergeState",
"ForEach",
"Distinct",
"OrDefault",
"OrNull",
"Resample",
"ArgMin",
"ArgMax",
],
key=len,
reverse=True,
)

FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
Expand All @@ -562,12 +569,50 @@ class Parser(parser.Parser):
TokenType.LIKE,
}

AGG_FUNC_MAPPING = (
# memoized examples of all 0- and 1-suffix aggregate function names
AGG_FUNC_MAPPING: t.Mapping[str, t.Tuple[str, str | None]] = (
lambda functions, suffixes: {
f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
f"{f}{sfx}": (f, sfx) for sfx in suffixes for f in functions
}
| {
# some function names could (but should not) be interpreted as combined
# versions of other function names. For example, `minMap`. To avoid this,
# the 0-suffix function dict must be on the RHS of the | operator, above
f: (f, None)
for f in functions
}
)(AGG_FUNCTIONS, AGG_FUNCTIONS_SUFFIXES)

@classmethod
def _resolve_clickhouse_agg(cls, name: str) -> t.Optional[t.Tuple[str, t.Sequence[str]]]:
# ClickHouse allows chaining multiple combinators on aggregate functions.
# See https://clickhouse.com/docs/sql-reference/aggregate-functions/combinators
# N.B. this resolution allows any suffix stack, including ones that ClickHouse rejects
# syntactically such as sumMergeMerge (due to repeated adjacent suffixes)

# Until we are able to identify a 1- or 0-suffix aggregate function by name,
# repeatedly strip and queue suffixes (checking longer suffixes first, see comment on
# AGG_FUNCTIONS_SUFFIXES_SORTED). This loop only runs for 2 or more suffixes,
# as AGG_FUNC_MAPPING memoizes all 0- and 1-suffix
accumulated_suffixes: t.Deque[str] = deque()
while (parts := cls.AGG_FUNC_MAPPING.get(name)) is None:
for suffix in cls.AGG_FUNCTIONS_SUFFIXES:
if name.endswith(suffix) and len(name) != len(suffix):
accumulated_suffixes.appendleft(suffix)
name = name[: -len(suffix)]
break
else:
return None

# We now have a 0- or 1-suffix aggregate
agg_func_name, inner_suffix = parts
if inner_suffix:
# this is a 1-suffix aggregate (either naturally or via repeated suffix
# stripping). prepend the innermost suffix.
accumulated_suffixes.appendleft(inner_suffix)

return (agg_func_name, accumulated_suffixes)

FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"ARRAYJOIN": lambda self: self.expression(exp.Explode, this=self._parse_expression()),
Expand Down Expand Up @@ -863,9 +908,9 @@ def _parse_function(

func = expr.this if isinstance(expr, exp.Window) else expr

# Aggregate functions can be split in 2 parts: <func_name><suffix>
# Aggregate functions can be split in 2 parts: <func_name><suffix[es]>
parts = (
self.AGG_FUNC_MAPPING.get(func.this) if isinstance(func, exp.Anonymous) else None
self._resolve_clickhouse_agg(func.this) if isinstance(func, exp.Anonymous) else None
)

if parts:
Expand All @@ -876,7 +921,7 @@ def _parse_function(
"this": anon_func.this,
"expressions": anon_func.expressions,
}
if parts[1]:
if len(parts[1]) > 0:
exp_class: t.Type[exp.Expression] = (
exp.CombinedParameterizedAgg if params else exp.CombinedAggFunc
)
Expand Down
66 changes: 66 additions & 0 deletions tests/dialects/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,72 @@ def extract_agg_func(query):
0
].assert_is(exp.ParameterizedAgg)

def test_agg_functions_multiple_suffixes(self):
# Regression test: single-suffix
self.validate_identity("SELECT uniqExactIf(x, y) FROM t").selects[0].assert_is(
exp.CombinedAggFunc
)

# Double suffix: If + Merge
self.validate_identity("SELECT countIfMerge(state) FROM t").selects[0].assert_is(
exp.CombinedAggFunc
)
self.validate_identity("SELECT uniqExactIfMerge(state) FROM t").selects[0].assert_is(
exp.CombinedAggFunc
)

# Triple suffix: ArgMin + If + State (#4814)
self.validate_identity("SELECT avgArgMinIfState(x, y) FROM t").selects[0].assert_is(
exp.CombinedAggFunc
)

# Double suffix + parameters: If + State with quantile parameter
self.validate_identity("SELECT quantileIfState(0.5)(col, cond) FROM t").selects[
0
].assert_is(exp.CombinedParameterizedAgg)

# Collision-prone bases: "Map" is both a valid suffix and part of the function name.
# These must parse as the base function (AnonymousAggFunc), not as sum/min/max + Map suffix.
self.validate_identity("SELECT sumMap(k, v) FROM t").selects[0].assert_is(
exp.AnonymousAggFunc
)
self.validate_identity("SELECT minMap(k, v) FROM t").selects[0].assert_is(
exp.AnonymousAggFunc
)
self.validate_identity("SELECT maxMap(k, v) FROM t").selects[0].assert_is(
exp.AnonymousAggFunc
)

# Single-suffix chains on collision-prone bases
self.validate_identity("SELECT sumMapIf(k, v, cond) FROM t").selects[0].assert_is(
exp.CombinedAggFunc
)
self.validate_identity("SELECT minMapIf(k, v, cond) FROM t").selects[0].assert_is(
exp.CombinedAggFunc
)
self.validate_identity("SELECT maxMapIf(k, v, cond) FROM t").selects[0].assert_is(
exp.CombinedAggFunc
)
self.validate_identity("SELECT sumMapState(k, v) FROM t").selects[0].assert_is(
exp.CombinedAggFunc
)

# Multi-suffix chain on a collision-prone base
self.validate_identity("SELECT sumMapIfState(k, v, cond) FROM t").selects[0].assert_is(
exp.CombinedAggFunc
)

# example of a nontrivial query:
sum_merge_if_merge = (
self.validate_identity(
"SELECT sumMergeIfMerge(s) FROM (SELECT sumMergeIfState(agg, 1 = 1) AS s "
"FROM (SELECT sumState(toFloat64(number)) AS agg FROM numbers(10)))"
)
.selects[0]
.assert_is(exp.CombinedAggFunc)
)
assert sum_merge_if_merge.name == "sumMergeIfMerge"

def test_drop_on_cluster(self):
for creatable in ("DATABASE", "TABLE", "VIEW", "DICTIONARY", "FUNCTION"):
with self.subTest(f"Test DROP {creatable} ON CLUSTER"):
Expand Down