Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 195 additions & 76 deletions sql_redis/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,30 @@ class ScoringSpec:
scorer: str = "BM25" # Scorer algorithm (BM25, TFIDF, DISMAX, etc.)


@dataclass
class BoolLeaf:
    """A leaf of the WHERE-clause boolean tree.

    Wraps exactly one ``Condition`` so that leaves and groups can be told
    apart with ``isinstance`` while walking the tree.
    """

    # The single parsed predicate this leaf stands for.
    condition: Condition


@dataclass
class BoolGroup:
    """An internal node in the WHERE-clause boolean tree.

    Preserves the SQL operator precedence and parenthesization so that
    mixed expressions like ``A AND (B OR C)`` are not flattened into a
    single boolean operator.
    """

    # Boolean connective joining the children: "AND" or "OR".
    operator: str
    # Child nodes in source order. A default_factory is used so that each
    # instance gets its own list instead of sharing one mutable default.
    children: list["BoolLeaf | BoolGroup"] = dataclasses.field(
        default_factory=list
    )


# Alias naming the two boolean tree node kinds (BoolLeaf or BoolGroup).
# NOTE: the value is a plain string, not a typing.Union object, so
# isinstance() checks must name BoolLeaf / BoolGroup directly.
BoolNode = "BoolLeaf | BoolGroup"


@dataclass
class ParsedQuery:
"""Result of parsing a SQL query."""
Expand All @@ -222,6 +246,12 @@ class ParsedQuery:
default_factory=list
)
boolean_operator: str = "AND"
condition_tree: object | None = None # BoolLeaf | BoolGroup | None
# True iff the WHERE clause contains an OR anywhere in the original SQL.
# Set during parsing — independent of the boolean tree, which may collapse
# an OR group when one side is a side-channel predicate (geo_distance)
# that produces no tree leaf.
has_or_in_where: bool = False
aggregations: list[AggregationSpec] = dataclasses.field(default_factory=list)
computed_fields: list[ComputedField] = dataclasses.field(default_factory=list)
date_functions: list[DateFunctionSpec] = dataclasses.field(default_factory=list)
Expand Down Expand Up @@ -267,7 +297,12 @@ def parse(self, sql: str) -> ParsedQuery:
# Extract WHERE clause conditions
where = ast.find(exp.Where)
if where:
self._process_where_clause(where.this, result)
tree = self._process_where_clause(where.this, result)
result.condition_tree = tree
# Set legacy boolean_operator from the tree root for backward
# compatibility with callers that still consult this field.
if isinstance(tree, BoolGroup):
result.boolean_operator = tree.operator

# Extract GROUP BY clause
group = ast.find(exp.Group)
Expand Down Expand Up @@ -352,11 +387,35 @@ def _process_select_expression_inner(
func_name = redis_func_map.get(func_name, func_name)
field_name = None
# Get the field being aggregated (if any)
if expression.this:
if isinstance(expression.this, exp.Column):
field_name = expression.this.name
elif isinstance(expression.this, exp.Star):
field_name = None # COUNT(*)
inner = expression.this
if isinstance(inner, exp.Distinct):
# AGG(DISTINCT col) — only COUNT has a native RediSearch
# equivalent (COUNT_DISTINCT). Other aggregates can't be
# silently translated to a non-distinct form, so raise.
distinct_cols = inner.expressions or (
[inner.this] if inner.this is not None else []
)
if len(distinct_cols) != 1 or not isinstance(
distinct_cols[0], exp.Column
):
raise ValueError(
f"{func_name}(DISTINCT ...) expects a single column "
"reference; multi-column or expression DISTINCT is "
"not supported by RediSearch."
)
if func_name != "COUNT":
raise ValueError(
f"{func_name}(DISTINCT ...) is not supported by "
"RediSearch. Only COUNT(DISTINCT x) maps to a native "
"reducer (COUNT_DISTINCT); pre-deduplicate the data "
"or use COUNT_DISTINCT for cardinality."
)
func_name = "COUNT_DISTINCT"
field_name = distinct_cols[0].name
elif isinstance(inner, exp.Column):
field_name = inner.name
elif isinstance(inner, exp.Star):
field_name = None # COUNT(*)
result.aggregations.append(
AggregationSpec(function=func_name, field=field_name, alias=alias)
)
Expand Down Expand Up @@ -688,54 +747,63 @@ def _process_date_expression(

def _process_where_clause(
self, expression, result: ParsedQuery, negated: bool = False
) -> None:
"""Process WHERE clause expression recursively."""
):
"""Process WHERE clause expression recursively.

Returns a boolean tree (BoolLeaf or BoolGroup) preserving the original
SQL operator precedence and grouping, or None when the expression
contributes no boolean clause to the RediSearch query string (e.g.,
geo_distance comparisons stored separately on result.geo_conditions).
"""
if isinstance(expression, exp.EQ):
self._add_condition(expression, "=", result, negated)
return self._leaf(self._add_condition(expression, "=", result, negated))
elif isinstance(expression, exp.GT):
self._add_condition(expression, ">", result, negated)
return self._leaf(self._add_condition(expression, ">", result, negated))
elif isinstance(expression, exp.GTE):
self._add_condition(expression, ">=", result, negated)
return self._leaf(self._add_condition(expression, ">=", result, negated))
elif isinstance(expression, exp.LT):
self._add_condition(expression, "<", result, negated)
return self._leaf(self._add_condition(expression, "<", result, negated))
elif isinstance(expression, exp.LTE):
self._add_condition(expression, "<=", result, negated)
return self._leaf(self._add_condition(expression, "<=", result, negated))
elif isinstance(expression, exp.NEQ):
self._add_condition(expression, "!=", result, negated)
return self._leaf(self._add_condition(expression, "!=", result, negated))
elif isinstance(expression, exp.Between):
self._add_between_condition(expression, result, negated)
return self._leaf(self._add_between_condition(expression, result, negated))
elif isinstance(expression, exp.In):
self._add_in_condition(expression, result, negated)
return self._leaf(self._add_in_condition(expression, result, negated))
elif isinstance(expression, exp.Like):
# LIKE 'pattern%' / '%pattern' / '%pattern%'
self._add_condition(expression, "LIKE", result, negated)
return self._leaf(self._add_condition(expression, "LIKE", result, negated))
elif isinstance(expression, exp.And):
result.boolean_operator = "AND"
self._process_where_clause(expression.this, result, negated)
self._process_where_clause(expression.expression, result, negated)
left = self._process_where_clause(expression.this, result, negated)
right = self._process_where_clause(expression.expression, result, negated)
return self._combine("AND", left, right)
elif isinstance(expression, exp.Or):
result.boolean_operator = "OR"
self._process_where_clause(expression.this, result, negated)
self._process_where_clause(expression.expression, result, negated)
result.has_or_in_where = True
left = self._process_where_clause(expression.this, result, negated)
right = self._process_where_clause(expression.expression, result, negated)
return self._combine("OR", left, right)
elif isinstance(expression, exp.Not):
self._process_where_clause(expression.this, result, negated=not negated)
return self._process_where_clause(
expression.this, result, negated=not negated
)
elif isinstance(expression, exp.Paren):
self._process_where_clause(expression.this, result, negated=negated)
return self._process_where_clause(expression.this, result, negated=negated)
elif isinstance(expression, exp.Is):
# IS NULL: exp.Is(this=Column, expression=Null())
# IS NOT NULL arrives here with negated=True via the exp.Not handler above
if isinstance(expression.this, exp.Column) and isinstance(
expression.expression, exp.Null
):
operator = "IS_NOT_NULL" if negated else "IS_NULL"
result.conditions.append(
Condition(
field=expression.this.name,
operator=operator,
value=None,
negated=False,
)
cond = Condition(
field=expression.this.name,
operator=operator,
value=None,
negated=False,
)
result.conditions.append(cond)
return BoolLeaf(cond)
else:
raise ValueError(
"Unsupported IS expression in WHERE clause; only "
Expand All @@ -752,9 +820,39 @@ def _process_where_clause(
"for post-aggregate filtering."
)
# EXISTS (SELECT ...) — SQL subquery, silently ignored (not supported)
return None
elif isinstance(expression, exp.Anonymous):
# Custom function like MATCH(field, value)
self._add_function_condition(expression, result, negated)
return self._leaf(self._add_function_condition(expression, result, negated))
return None

@staticmethod
def _leaf(condition: Condition | None) -> "BoolLeaf | None":
    """Wrap a Condition in a BoolLeaf.

    Returns None unchanged when *condition* is None, which is how the
    ``_add_*`` helpers signal that no Condition was produced for the
    expression (e.g. it was routed to ``result.geo_conditions``).
    """
    if condition is None:
        return None
    return BoolLeaf(condition)

@staticmethod
def _combine(
    operator: str,
    left: "BoolLeaf | BoolGroup | None",
    right: "BoolLeaf | BoolGroup | None",
) -> "BoolLeaf | BoolGroup | None":
    """Combine two child nodes under a boolean operator.

    Drops None children and flattens same-operator subtrees so that
    ``A AND B AND C`` produces a single AND group with three children.

    Returns None when both children are None, the surviving child itself
    when only one remains (no single-child group is created), and a new
    BoolGroup otherwise.
    """
    children: list = []
    for child in (left, right):
        if child is None:
            # Side produced no tree node; nothing to combine.
            continue
        if isinstance(child, BoolGroup) and child.operator == operator:
            # Same connective: splice the grandchildren in directly.
            children.extend(child.children)
        else:
            children.append(child)
    if not children:
        return None
    if len(children) == 1:
        return children[0]
    return BoolGroup(operator=operator, children=children)

def _process_having_clause(self, expression, result: ParsedQuery) -> None:
"""Process HAVING clause — routes exists() to filters."""
Expand All @@ -780,8 +878,12 @@ def _process_having_clause(self, expression, result: ParsedQuery) -> None:

def _add_condition(
self, expression, operator: str, result: ParsedQuery, negated: bool
) -> None:
"""Add a condition from a comparison expression."""
) -> Condition | None:
"""Add a condition from a comparison expression.

Returns the appended Condition for inclusion in the boolean tree, or
None when the expression was routed to result.geo_conditions instead.
"""
field_name = None
value = None
is_geo_distance = False
Expand Down Expand Up @@ -875,20 +977,26 @@ def _add_condition(
unit=geo_unit,
)
)
return None
else:
result.conditions.append(
Condition(
field=field_name,
operator=operator,
value=value,
negated=negated,
)
cond = Condition(
field=field_name,
operator=operator,
value=value,
negated=negated,
)
result.conditions.append(cond)
return cond
return None

def _add_between_condition(
self, expression, result: ParsedQuery, negated: bool
) -> None:
"""Add a BETWEEN condition."""
) -> Condition | None:
"""Add a BETWEEN condition.

Returns the appended Condition for inclusion in the boolean tree, or
None when the expression was routed to result.geo_conditions instead.
"""
field_name = None
is_geo_distance = False
geo_lon = None
Expand Down Expand Up @@ -963,35 +1071,44 @@ def _add_between_condition(
unit=geo_unit,
)
)
return None
else:
result.conditions.append(
Condition(
field=field_name,
operator="BETWEEN",
value=(low_val, high_val),
negated=negated,
)
cond = Condition(
field=field_name,
operator="BETWEEN",
value=(low_val, high_val),
negated=negated,
)
result.conditions.append(cond)
return cond
return None

def _add_in_condition(self, expression, result: ParsedQuery, negated: bool) -> None:
"""Add an IN condition."""
def _add_in_condition(
self, expression, result: ParsedQuery, negated: bool
) -> Condition | None:
"""Add an IN condition. Returns the appended Condition or None."""
field_name = None
if isinstance(expression.this, exp.Column):
field_name = expression.this.name

values = [self._extract_literal_value(e) for e in expression.expressions]

if field_name is not None:
result.conditions.append(
Condition(
field=field_name, operator="IN", value=values, negated=negated
)
cond = Condition(
field=field_name, operator="IN", value=values, negated=negated
)
result.conditions.append(cond)
return cond
return None

def _add_function_condition(
self, expression, result: ParsedQuery, negated: bool
) -> None:
"""Add a condition from a function call like fulltext(field, value) or fuzzy(field, value, level)."""
) -> Condition | None:
"""Add a condition from a function call like fulltext(field, value) or fuzzy(field, value, level).

Returns the appended Condition for inclusion in the boolean tree, or
None when no condition was added.
"""
func_name = expression.name.upper()
args = expression.expressions

Expand Down Expand Up @@ -1071,16 +1188,16 @@ def _add_function_condition(
"fulltext() first argument must be a column name, "
f"got {args[0]}. Usage: fulltext(field, 'search terms')"
)
result.conditions.append(
Condition(
field=field_name,
operator="FULLTEXT",
value=value,
negated=negated,
slop=slop,
inorder=inorder,
)
cond = Condition(
field=field_name,
operator="FULLTEXT",
value=value,
negated=negated,
slop=slop,
inorder=inorder,
)
result.conditions.append(cond)
return cond

elif func_name == "FUZZY" and len(args) >= 2:
field_name = args[0].name if isinstance(args[0], exp.Column) else None
Expand Down Expand Up @@ -1121,15 +1238,17 @@ def _add_function_condition(
"fuzzy() first argument must be a column name, "
f"got {args[0]}. Usage: fuzzy(field, 'search term')"
)
result.conditions.append(
Condition(
field=field_name,
operator="FUZZY",
value=value,
negated=negated,
fuzzy_level=fuzzy_level,
)
cond = Condition(
field=field_name,
operator="FUZZY",
value=value,
negated=negated,
fuzzy_level=fuzzy_level,
)
result.conditions.append(cond)
return cond

return None

def _extract_literal_value(self, expression, convert_dates: bool = False):
"""Extract a Python value from a sqlglot Literal or Neg expression.
Expand Down
Loading
Loading