Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Intersect;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
Expand Down Expand Up @@ -117,13 +122,18 @@ protected SetOpToFilterRule(Config config) {

private static void match(RelOptRuleCall call) {
final SetOp setOp = call.rel(0);
final RelMetadataQuery mq = call.getMetadataQuery();
final List<RelNode> inputs = setOp.getInputs();
if (setOp.all || inputs.size() < 2) {
return;
}

final RelBuilder builder = call.builder();
Pair<RelNode, @Nullable RexNode> first = extractSourceAndCond(inputs.get(0).stripped());
final RelNode firstClause = inputs.get(0).stripped();
final List<RelCollation> firstCollations = mq.collations(firstClause);
Pair<RelNode, @Nullable RexNode> first =
extractSourceAndCond(firstClause, firstCollations != null
&& firstCollations.stream().anyMatch(c -> c != RelCollations.EMPTY));

// Groups conditions by their source relational node and input position.
// - Key: Pair of (sourceRelNode, inputPosition)
Expand All @@ -143,7 +153,14 @@ private static void match(RelOptRuleCall call) {

for (int i = 1; i < inputs.size(); i++) {
final RelNode input = inputs.get(i).stripped();
final Pair<RelNode, @Nullable RexNode> pair = extractSourceAndCond(input);
boolean isSorted = false;
final List<RelCollation> inputCollations = mq.collations(input);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@asolimando I use mq to check if the input is sorted. If the input is sorted, a sub-plan that has both limit/offset and sort is safe to be rewritten. Please check if this logic matches your suggestion.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As suggested, let's move the discussion to Jira. I will try to summarize my idea there so we can see whether it makes sense to everyone. Thanks for the code draft, @xiedeyantu!

if (inputCollations != null
&& inputCollations.stream().anyMatch(c -> c != RelCollations.EMPTY)
&& inputCollations.equals(firstCollations)) {
isSorted = true;
}
final Pair<RelNode, @Nullable RexNode> pair = extractSourceAndCond(input, isSorted);
sourceToConds.computeIfAbsent(Pair.of(pair.left, pair.right != null ? null : i),
k -> new ArrayList<>()).add(pair.right);
}
Expand Down Expand Up @@ -197,21 +214,46 @@ private static RelBuilder buildSetOp(RelBuilder builder, int count, RelNode setO
throw new IllegalStateException("unreachable code");
}

/**
 * Splits one input of the set operation into the relational node it reads
 * from and the condition applied on top of that node.
 *
 * <p>The result is one of:
 * <ul>
 *   <li>{@code (input, null)} when the input is a {@link Filter} whose
 *       condition is non-deterministic or contains a sub-query;
 *   <li>{@code (input, null)} when
 *       {@code containsBlockingSortInProjectFilterChain} reports a blocking
 *       sort for the node being examined;
 *   <li>{@code (stripped filter input, filter condition)} for any other
 *       {@link Filter};
 *   <li>{@code (stripped input, TRUE)} for any other non-filter input.
 * </ul>
 *
 * @param input one input of the set operation
 * @param isSorted forwarded unchanged to
 *     {@code containsBlockingSortInProjectFilterChain}
 * @return pair of source node and condition; the condition is null when the
 *     input must be kept as it is
 */
private static Pair<RelNode, @Nullable RexNode> extractSourceAndCond(RelNode input) {
private static Pair<RelNode, @Nullable RexNode> extractSourceAndCond(RelNode input,
boolean isSorted) {
if (input instanceof Filter) {
Filter filter = (Filter) input;
if (!RexUtil.isDeterministic(filter.getCondition())
|| RexUtil.SubQueryFinder.containsSubQuery(filter)) {
// Skip non-deterministic conditions or those containing subqueries
return Pair.of(input, null);
}
return Pair.of(filter.getInput().stripped(), filter.getCondition());
final RelNode source = filter.getInput().stripped();
// A Sort with fetch/offset below the filter keeps the whole input as its
// own source unless the caller flagged the input as sorted.
if (containsBlockingSortInProjectFilterChain(source, isSorted)) {
return Pair.of(input, null);
}
return Pair.of(source, filter.getCondition());
}
// Same check for an input that is not a Filter.
if (containsBlockingSortInProjectFilterChain(input, isSorted)) {
return Pair.of(input, null);
}
// For non-filter inputs, use TRUE literal as default condition.
return Pair.of(input.stripped(),
input.getCluster().getRexBuilder().makeLiteral(true));
}

/**
 * Reports whether a {@link Sort} with a {@code fetch} or {@code offset} is
 * reached by walking down from {@code input} through {@link Project} and
 * {@link Filter} nodes only, while {@code isSorted} is false.
 *
 * <p>The walk follows input 0 of each {@link Project} or {@link Filter} and
 * stops at the first node of any other kind; only that node is inspected.
 *
 * @param input node at which the walk starts (stripped before inspection)
 * @param isSorted when true, the method always returns false
 * @return true if the walk ends on a {@link Sort} that has a fetch or an
 *     offset and {@code isSorted} is false; false otherwise
 */
private static boolean containsBlockingSortInProjectFilterChain(RelNode input,
    boolean isSorted) {
  // Descend through the Project/Filter chain first.
  RelNode node = input.stripped();
  while (node instanceof Project || node instanceof Filter) {
    node = node.getInput(0).stripped();
  }

  // Anything other than a Sort at the end of the chain never blocks.
  if (!(node instanceof Sort)) {
    return false;
  }

  final Sort sort = (Sort) node;
  final boolean hasLimitOrOffset = sort.fetch != null || sort.offset != null;
  return hasLimitOrOffset && !isSorted;
}

/**
* Creates a combined condition where the first condition
* is kept as-is and all subsequent conditions are negated,
Expand Down
113 changes: 113 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -11435,6 +11435,119 @@ private void checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) {
.check();
}

/** Test case of
* <a href="https://issues.apache.org/jira/browse/CALCITE-7463">[CALCITE-7463]
* UnionToFilterRule incorrectly rewrites UNION with LIMIT</a>. */
@Test void testUnionToFilterRuleWithLimit() {
Comment thread
asolimando marked this conversation as resolved.
final String sql = "(SELECT mgr, comm FROM emp LIMIT 2)\n"
+ "UNION\n"
+ "(SELECT mgr, comm FROM emp LIMIT 2)\n";
sql(sql)
.withRule(CoreRules.UNION_FILTER_TO_FILTER)
.checkUnchanged();
}

/** Test case of
 * <a href="https://issues.apache.org/jira/browse/CALCITE-7463">[CALCITE-7463]
 * UnionToFilterRule incorrectly rewrites UNION with LIMIT</a>.
 *
 * <p>Applies the rule to a UNION ALL of two LIMIT sub-queries and checks
 * with {@code checkUnchanged()}. */
@Test void testUnionAllToFilterRuleWithLimit() {
  sql("(SELECT mgr, comm FROM emp LIMIT 2)\n"
      + "UNION ALL\n"
      + "(SELECT mgr, comm FROM emp LIMIT 2)\n")
      .withRule(CoreRules.UNION_FILTER_TO_FILTER)
      .checkUnchanged();
}

/** Test case of
 * <a href="https://issues.apache.org/jira/browse/CALCITE-7463">[CALCITE-7463]
 * UnionToFilterRule incorrectly rewrites UNION with LIMIT</a>.
 *
 * <p>Each UNION branch filters the output of a LIMIT sub-query; the rule is
 * applied after {@code PROJECT_FILTER_TRANSPOSE} and the result is checked
 * with {@code checkUnchanged()}. */
@Test void testUnionToFilterRuleWithNestedLimit() {
  sql("SELECT comm FROM (SELECT mgr, comm FROM emp LIMIT 2) t\n"
      + "WHERE comm > 5\n"
      + "UNION\n"
      + "SELECT comm FROM (SELECT mgr, comm FROM emp LIMIT 2) t\n"
      + "WHERE comm > 10\n")
      .withPreRule(CoreRules.PROJECT_FILTER_TRANSPOSE)
      .withRule(CoreRules.UNION_FILTER_TO_FILTER)
      .checkUnchanged();
}

/** Test case of
 * <a href="https://issues.apache.org/jira/browse/CALCITE-7463">[CALCITE-7463]
 * UnionToFilterRule incorrectly rewrites UNION with LIMIT</a>.
 *
 * <p>Both UNION branches filter a projection of EMP that is sorted on
 * field 0 without any limit or offset. */
@Test void testUnionToFilterRuleWithSortOnly() {
  final Function<RelBuilder, RelNode> relFn = builder -> {
    // First branch: sorted projection filtered on field 1 > 5.
    final RelNode commOver5 = builder.scan("EMP")
        .project(builder.field("MGR"), builder.field("COMM"))
        .sort(builder.field(0))
        .filter(
            builder.call(SqlStdOperatorTable.GREATER_THAN,
                builder.field(1), builder.literal(5)))
        .build();
    // Second branch: same shape, filtered on field 1 > 10.
    final RelNode commOver10 = builder.scan("EMP")
        .project(builder.field("MGR"), builder.field("COMM"))
        .sort(builder.field(0))
        .filter(
            builder.call(SqlStdOperatorTable.GREATER_THAN,
                builder.field(1), builder.literal(10)))
        .build();
    builder.push(commOver5);
    builder.push(commOver10);
    return builder.union(false, 2).build();
  };
  relFn(relFn)
      .withRule(CoreRules.UNION_FILTER_TO_FILTER)
      .check();
}

/** Test case of
 * <a href="https://issues.apache.org/jira/browse/CALCITE-7463">[CALCITE-7463]
 * UnionToFilterRule incorrectly rewrites UNION with LIMIT</a>.
 *
 * <p>Both UNION branches filter a projection of EMP that goes through the
 * same {@code sortLimit(10, 2, field 0)} call. */
@Test void testUnionToFilterRuleWithSortLimit() {
  final Function<RelBuilder, RelNode> relFn = builder -> {
    // First branch: sort + offset/fetch, then filter on field 1 > 5.
    final RelNode commOver5 = builder.scan("EMP")
        .project(builder.field("MGR"), builder.field("COMM"))
        .sortLimit(10, 2, builder.field(0))
        .filter(
            builder.call(SqlStdOperatorTable.GREATER_THAN,
                builder.field(1), builder.literal(5)))
        .build();
    // Second branch: identical sort + offset/fetch, filter on field 1 > 10.
    final RelNode commOver10 = builder.scan("EMP")
        .project(builder.field("MGR"), builder.field("COMM"))
        .sortLimit(10, 2, builder.field(0))
        .filter(
            builder.call(SqlStdOperatorTable.GREATER_THAN,
                builder.field(1), builder.literal(10)))
        .build();
    builder.push(commOver5);
    builder.push(commOver10);
    return builder.union(false, 2).build();
  };
  relFn(relFn)
      .withRule(CoreRules.UNION_FILTER_TO_FILTER)
      .check();
}

/** Test case of
 * <a href="https://issues.apache.org/jira/browse/CALCITE-7463">[CALCITE-7463]
 * UnionToFilterRule incorrectly rewrites UNION with LIMIT</a>.
 *
 * <p>Only the first UNION branch goes through {@code sortLimit}; the second
 * filters the plain projection. The result is checked with
 * {@code checkUnchanged()}. */
@Test void testUnionToFilterRuleWithUnmergeableFirstInput() {
  final Function<RelBuilder, RelNode> relFn = builder -> {
    // First branch: sort + offset/fetch below the filter.
    final RelNode limitedBranch = builder.scan("EMP")
        .project(builder.field("MGR"), builder.field("COMM"))
        .sortLimit(10, 2, builder.field(0))
        .filter(
            builder.call(SqlStdOperatorTable.GREATER_THAN,
                builder.field(1), builder.literal(5)))
        .build();
    // Second branch: no sort at all.
    final RelNode plainBranch = builder.scan("EMP")
        .project(builder.field("MGR"), builder.field("COMM"))
        .filter(
            builder.call(SqlStdOperatorTable.GREATER_THAN,
                builder.field(1), builder.literal(10)))
        .build();
    builder.push(limitedBranch);
    builder.push(plainBranch);
    return builder.union(false, 2).build();
  };
  relFn(relFn)
      .withRule(CoreRules.UNION_FILTER_TO_FILTER)
      .checkUnchanged();
}

/** Test case of
* <a href="https://issues.apache.org/jira/browse/CALCITE-7002">[CALCITE-7002]
* Create an optimization rule to eliminate UNION
Expand Down
125 changes: 125 additions & 0 deletions core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21336,6 +21336,25 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$
LogicalFilter(condition=[AND(<(+($0, 50), 20), >=($cor0.DEPTNO, $9))])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], $f9=[+(30, $7)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testUnionAllToFilterRuleWithLimit">
<Resource name="sql">
<![CDATA[(SELECT mgr, comm FROM emp LIMIT 2)
UNION ALL
(SELECT mgr, comm FROM emp LIMIT 2)
]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalUnion(all=[true])
LogicalSort(fetch=[2])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalSort(fetch=[2])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
Expand Down Expand Up @@ -21495,6 +21514,50 @@ LogicalUnion(all=[false])
LogicalFilter(condition=[SEARCH($0, Sarg[5, 10])])
LogicalProject(DEPTNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
<TestCase name="testUnionToFilterRuleWithLimit">
<Resource name="sql">
<![CDATA[(SELECT mgr, comm FROM emp LIMIT 2)
UNION
(SELECT mgr, comm FROM emp LIMIT 2)
]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalUnion(all=[false])
LogicalSort(fetch=[2])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalSort(fetch=[2])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testUnionToFilterRuleWithNestedLimit">
<Resource name="sql">
<![CDATA[SELECT comm FROM (SELECT mgr, comm FROM emp LIMIT 2) t
WHERE comm > 5
UNION
SELECT comm FROM (SELECT mgr, comm FROM emp LIMIT 2) t
WHERE comm > 10
]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalUnion(all=[false])
LogicalFilter(condition=[>($0, 5)])
LogicalProject(COMM=[$1])
LogicalSort(fetch=[2])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalFilter(condition=[>($0, 10)])
LogicalProject(COMM=[$1])
LogicalSort(fetch=[2])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
Expand Down Expand Up @@ -21541,6 +21604,54 @@ LogicalUnion(all=[false])
LogicalAggregate(group=[{0, 1}])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testUnionToFilterRuleWithSortLimit">
<Resource name="planBefore">
<![CDATA[
LogicalUnion(all=[false])
LogicalFilter(condition=[>($1, 5)])
LogicalSort(sort0=[$0], dir0=[ASC], offset=[10], fetch=[2])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[scott, EMP]])
LogicalFilter(condition=[>($1, 10)])
LogicalSort(sort0=[$0], dir0=[ASC], offset=[10], fetch=[2])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[scott, EMP]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalAggregate(group=[{0, 1}])
LogicalFilter(condition=[>($1, 5)])
LogicalSort(sort0=[$0], dir0=[ASC], offset=[10], fetch=[2])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[scott, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testUnionToFilterRuleWithSortOnly">
<Resource name="planBefore">
<![CDATA[
LogicalUnion(all=[false])
LogicalFilter(condition=[>($1, 5)])
LogicalSort(sort0=[$0], dir0=[ASC])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[scott, EMP]])
LogicalFilter(condition=[>($1, 10)])
LogicalSort(sort0=[$0], dir0=[ASC])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[scott, EMP]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalAggregate(group=[{0, 1}])
LogicalFilter(condition=[>($1, 5)])
LogicalSort(sort0=[$0], dir0=[ASC])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[scott, EMP]])
]]>
</Resource>
</TestCase>
Expand Down Expand Up @@ -21599,6 +21710,20 @@ LogicalAggregate(group=[{0, 1}])
LogicalFilter(condition=[OR(=($0, 12), =($1, 5))])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testUnionToFilterRuleWithUnmergeableFirstInput">
<Resource name="planBefore">
<![CDATA[
LogicalUnion(all=[false])
LogicalFilter(condition=[>($1, 5)])
LogicalSort(sort0=[$0], dir0=[ASC], offset=[10], fetch=[2])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[scott, EMP]])
LogicalFilter(condition=[>($1, 10)])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[scott, EMP]])
]]>
</Resource>
</TestCase>
Expand Down
Loading