Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,13 @@ object FlinkStreamRuleSets {
// expand grouping sets
DecomposeGroupingSetsRule.INSTANCE,

// rank rules
FlinkLogicalRankRule.CONSTANT_RANGE_ALL_FUNCTIONS_INSTANCE,
// transpose calc past rank to reduce rank input fields
CalcRankTransposeRule.INSTANCE,
// remove output of rank number when it is a constant
ConstantRankNumberColumnRemoveRule.INSTANCE,

// calc rules
FlinkFilterCalcMergeRule.INSTANCE,
FlinkProjectCalcMergeRule.INSTANCE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,70 @@ class FlinkLogicalRankRuleForConstantRange extends FlinkLogicalRankRuleBase {
}
}

/**
* This rule handles all [[SqlRankFunction]] types (ROW_NUMBER, RANK, DENSE_RANK) with both constant
* and variable rank ranges. Unlike [[FlinkLogicalRankRuleForRangeEnd]], it does not throw an
* exception for [[ConstantRankRangeWithoutEnd]], making it safe for use in the Volcano (cost-based)
* optimizer phase where exceptions in `matches()` are wrapped and cause planning failures.
*
* This rule silently rejects [[ConstantRankRangeWithoutEnd]] (rank end not specified) rather than
* throwing, deferring that error to [[FlinkLogicalRankRuleForRangeEnd]] in a later HEP phase where
* exceptions are properly surfaced to the user.
*/
class FlinkLogicalRankRuleForConstantRangeAllFunctions extends FlinkLogicalRankRuleBase {
override def matches(call: RelOptRuleCall): Boolean = {
val calc: FlinkLogicalCalc = call.rel(0)
val window: FlinkLogicalOverAggregate = call.rel(1)

if (window.groups.size > 1) {
// only accept one window
return false
}

val group = window.groups.get(0)
if (group.aggCalls.size > 1) {
// only accept one agg call
return false
}

val agg = group.aggCalls.get(0)
if (!agg.getOperator.isInstanceOf[SqlRankFunction]) {
// only accept SqlRankFunction for Rank
return false
}

if (group.lowerBound.isUnbounded && group.upperBound.isCurrentRow) {
val condition = calc.getProgram.getCondition
if (condition != null) {
val predicate = calc.getProgram.expandLocalRef(condition)
// the rank function is the last field of FlinkLogicalOverAggregate
val rankFieldIndex = window.getRowType.getFieldCount - 1
val tableConfig = unwrapTableConfig(calc)
val (rankRange, remainingPreds) = RankUtil.extractRankRange(
predicate,
rankFieldIndex,
calc.getCluster.getRexBuilder,
tableConfig,
unwrapClassLoader(calc))

// remaining predicate must not access rank field attributes
val remainingPredsAccessRank = remainingPreds.isDefined &&
RankUtil.accessesRankField(remainingPreds.get, rankFieldIndex)

// accept any rank range except ConstantRankRangeWithoutEnd (rank end not specified)
rankRange.exists(!_.isInstanceOf[ConstantRankRangeWithoutEnd]) &&
!remainingPredsAccessRank
} else {
false
}
} else {
false
}
}
}

object FlinkLogicalRankRule {
val INSTANCE = new FlinkLogicalRankRuleForRangeEnd
val CONSTANT_RANGE_INSTANCE = new FlinkLogicalRankRuleForConstantRange
val CONSTANT_RANGE_ALL_FUNCTIONS_INSTANCE = new FlinkLogicalRankRuleForConstantRangeAllFunctions
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ LogicalProject(a=[$0], b=[$1], rk1=[$2], rk2=[$3])
</Resource>
<Resource name="optimized rel plan">
<![CDATA[
FlinkLogicalCalc(select=[a, b, rk2 AS rk1, rk2])
+- FlinkLogicalRank(rankType=[RANK], rankRange=[rankStart=1, rankEnd=9], partitionBy=[b], orderBy=[a ASC], select=[a, b, rk2])
FlinkLogicalCalc(select=[a, b, w0$o0 AS rk1, w0$o0 AS rk2])
+- FlinkLogicalRank(rankType=[RANK], rankRange=[rankStart=1, rankEnd=9], partitionBy=[b], orderBy=[a ASC], select=[a, b, w0$o0])
+- FlinkLogicalCalc(select=[a, b])
+- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c])
]]>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ LogicalProject(a=[$0], rk=[$1], b=[$2], c=[$3])
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[a, rk, b, c])
+- Rank(strategy=[AppendFastStrategy], rankType=[RANK], rankRange=[rankStart=1, rankEnd=9], partitionBy=[a], orderBy=[a ASC], select=[a, b, c, rk])
Calc(select=[a, w0$o0 AS rk, b, c])
+- Rank(strategy=[AppendFastStrategy], rankType=[RANK], rankRange=[rankStart=1, rankEnd=9], partitionBy=[a], orderBy=[a ASC], select=[a, b, c, w0$o0])
+- Exchange(distribution=[hash[a]])
+- Calc(select=[a, b, c])
+- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
Expand Down Expand Up @@ -464,6 +464,42 @@ Calc(select=[CONCAT('http://txmov2.a.yximgs.com', uri) AS url, reqcount AS downl
+- Rank(strategy=[AppendFastStrategy], rankType=[ROW_NUMBER], rankRange=[rankStart=1, rankEnd=100000], partitionBy=[start_time, bucket_id], orderBy=[reqcount DESC], select=[uri, reqcount, start_time, bucket_id])
+- Exchange(distribution=[hash[start_time, bucket_id]])
+- DataStreamScan(table=[[default_catalog, default_database, MyTable1]], fields=[uri, reqcount, start_time, bucket_id])
]]>
</Resource>
</TestCase>
<TestCase name="testRowNumberWithCaseWhenAndWhereClause">
<Resource name="sql">
<![CDATA[
SELECT a, b, category FROM (
SELECT a, b,
CASE WHEN c > 10 THEN 'big' ELSE 'small' END as category,
row_num
FROM (
SELECT a, b, c,
ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) as row_num
FROM MyTable)
WHERE row_num <= 2
)
WHERE b <> 'z'
]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(a=[$0], b=[$1], category=[$2])
+- LogicalFilter(condition=[<>($1, _UTF-16LE'z')])
+- LogicalProject(a=[$0], b=[$1], category=[CASE(>($2, 10), _UTF-16LE'big':VARCHAR(5) CHARACTER SET "UTF-16LE", _UTF-16LE'small':VARCHAR(5) CHARACTER SET "UTF-16LE")], row_num=[$3])
+- LogicalFilter(condition=[<=($3, 2)])
+- LogicalProject(a=[$0], b=[$1], c=[$2], row_num=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY $2 DESC NULLS LAST)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[a, b, CASE((c > 10), 'big', 'small') AS category], where=[(b <> 'z')])
+- Rank(strategy=[AppendFastStrategy], rankType=[ROW_NUMBER], rankRange=[rankStart=1, rankEnd=2], partitionBy=[a], orderBy=[c DESC], select=[a, b, c])
+- Exchange(distribution=[hash[a]])
+- Calc(select=[a, b, c])
+- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
]]>
</Resource>
</TestCase>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1014,5 +1014,24 @@ class RankTest extends TableTestBase {
util.verifyExplainInsert(sql, ExplainDetail.CHANGELOG_MODE)
}

@Test
def testRowNumberWithCaseWhenAndWhereClause(): Unit = {
val sql =
"""
|SELECT a, b, category FROM (
| SELECT a, b,
| CASE WHEN c > 10 THEN 'big' ELSE 'small' END as category,
| row_num
| FROM (
| SELECT a, b, c,
| ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) as row_num
| FROM MyTable)
| WHERE row_num <= 2
|)
|WHERE b <> 'z'
""".stripMargin
util.verifyExecPlan(sql)
}

// TODO add tests about multi-sinks and udf
}