Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2497,36 +2497,37 @@ public PlanFragment visitPhysicalRepeat(PhysicalRepeat<? extends Plan> repeat, P
PlanFragment inputPlanFragment = repeat.child(0).accept(this, context);
List<List<Expr>> distributeExprLists = getDistributeExprs(repeat.child(0));

ImmutableSet<Expression> flattenGroupingSetExprs = ImmutableSet.copyOf(
ExpressionUtils.flatExpressions(repeat.getGroupingSets()));
List<Expression> flattenGroupingExpressions = repeat.getGroupByExpressions();
Set<Slot> preRepeatExpressions = Sets.newLinkedHashSet();
// keep group by expression coming first
for (Expression groupByExpr : flattenGroupingExpressions) {
// NormalizeRepeat had converted group by expression to slot
preRepeatExpressions.add((Slot) groupByExpr);
}

List<Slot> aggregateFunctionUsedSlots = repeat.getOutputExpressions()
.stream()
.filter(output -> !flattenGroupingSetExprs.contains(output))
.filter(output -> !output.containsType(GroupingScalarFunction.class))
.distinct()
.map(NamedExpression::toSlot)
// add aggregate function used expressions
for (NamedExpression outputExpr : repeat.getOutputExpressions()) {
if (!outputExpr.containsType(GroupingScalarFunction.class)) {
preRepeatExpressions.add(outputExpr.toSlot());
}
}

List<Expr> preRepeatExprs = preRepeatExpressions.stream()
.map(expr -> ExpressionTranslator.translate(expr, context))
.collect(ImmutableList.toImmutableList());

// keep flattenGroupingSetExprs comes first
List<Expr> preRepeatExprs = Stream.concat(flattenGroupingSetExprs.stream(), aggregateFunctionUsedSlots.stream())
.map(expr -> ExpressionTranslator.translate(expr, context)).collect(ImmutableList.toImmutableList());

// outputSlots's order need same with preRepeatExprs
List<Slot> outputSlots = Stream.concat(Stream
.concat(repeat.getOutputExpressions().stream()
.filter(output -> flattenGroupingSetExprs.contains(output)),
repeat.getOutputExpressions().stream()
.filter(output -> !flattenGroupingSetExprs.contains(output))
.filter(output -> !output.containsType(GroupingScalarFunction.class))
.distinct()
),
Stream.concat(Stream.of(repeat.getGroupingId().toSlot()),
repeat.getOutputExpressions().stream()
.filter(output -> output.containsType(GroupingScalarFunction.class)))
)
.map(NamedExpression::toSlot).collect(ImmutableList.toImmutableList());
// outputSlots's order must match preRepeatExprs, then grouping id, then grouping function slots
ImmutableList.Builder<Slot> outputSlotsBuilder
= ImmutableList.builderWithExpectedSize(repeat.getOutputExpressions().size() + 1);
outputSlotsBuilder.addAll(preRepeatExpressions);
outputSlotsBuilder.add(repeat.getGroupingId().toSlot());
for (NamedExpression outputExpr : repeat.getOutputExpressions()) {
if (outputExpr.containsType(GroupingScalarFunction.class)) {
outputSlotsBuilder.add(outputExpr.toSlot());
}
}

List<Slot> outputSlots = outputSlotsBuilder.build();
// NOTE: we should first translate preRepeatExprs, then generate output tuple,
// or else the preRepeatExprs can not find the bottom slotRef and throw
// exception: invalid slot id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,8 @@ public String getNodeExplainString(String detailPrefix, TExplainLevel detailLeve
public boolean isSerialOperator() {
// This node has no serial/parallel preference of its own: it simply reports
// whatever its first child (index 0) reports.
return children.get(0).isSerialOperator();
}

/**
 * Returns the {@link GroupingInfo} held by this node.
 *
 * @return the node's grouping info, exactly as stored (no copy is made)
 */
public GroupingInfo getGroupingInfo() {
    return this.groupingInfo;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

package org.apache.doris.nereids.glue.translator;

import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.GroupingInfo;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.properties.DataTrait;
Expand All @@ -34,16 +39,19 @@
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.planner.AggregationNode;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.planner.PlanNode;
import org.apache.doris.planner.Planner;
import org.apache.doris.planner.RepeatNode;
import org.apache.doris.utframe.TestWithFeService;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import mockit.Injectable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
Expand All @@ -53,9 +61,18 @@
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;

public class PhysicalPlanTranslatorTest extends TestWithFeService {

/**
 * Prepares the fixtures shared by the tests in this class: creates database
 * {@code test_db}, creates table {@code test_db.t(a int, b int)}, and disables the
 * {@code prune_empty_partition} Nereids rule on the session
 * (presumably so that the still-empty table is not pruned away during planning).
 */
@Override
protected void runBeforeAll() throws Exception {
    final String createTableSql = "create table test_db.t(a int, b int) distributed by hash(a) buckets 3 "
            + "properties('replication_num' = '1');";

    createDatabase("test_db");
    createTable(createTableSql);
    connectContext.getSessionVariable().setDisableNereidsRules("prune_empty_partition");
}

@Test
public void testOlapPrune(@Injectable LogicalProperties placeHolder) throws Exception {
OlapTable t1 = PlanConstructor.newOlapTable(0, "t1", 0, KeysType.AGG_KEYS);
Expand Down Expand Up @@ -93,10 +110,6 @@ public void testOlapPrune(@Injectable LogicalProperties placeHolder) throws Exce

@Test
public void testAggNeedsFinalize() throws Exception {
createDatabase("test_db");
createTable("create table test_db.t(a int, b int) distributed by hash(a) buckets 3 "
+ "properties('replication_num' = '1');");
connectContext.getSessionVariable().setDisableNereidsRules("prune_empty_partition");
String querySql = "select b from test_db.t group by b";
Planner planner = getSQLPlanner(querySql);
Assertions.assertNotNull(planner);
Expand Down Expand Up @@ -125,4 +138,31 @@ public void testAggNeedsFinalize() throws Exception {
Assertions.assertTrue(upperNeedsFinalize,
"upper AggregationNode needsFinalize should be true");
}

/**
 * Plans a GROUPING SETS query (with duplicated grouping sets, an empty set and
 * expression keys) and verifies that exactly one {@link RepeatNode} is produced and that
 * every pre-repeat input expression is a {@link SlotRef} whose column equals the column
 * of the output tuple slot at the same position.
 */
@Test
public void testRepeatInputOutputOrder() throws Exception {
    String sql = "select grouping(a), grouping(b), grouping_id(a, b), sum(a + 2 * b), sum(a + 3 * b) + grouping_id(b, a, b), b, a, b, a"
            + " from test_db.t"
            + " group by grouping sets((a, b), (), (b), (a, b), (a + b), (a * b))";
    PlanChecker.from(connectContext).checkPlannerResult(sql, planner -> {
        // Gather every RepeatNode reachable from the root of any fragment.
        Set<RepeatNode> foundRepeats = Sets.newHashSet();
        for (PlanFragment fragment : planner.getFragments()) {
            fragment.getPlanRoot().collect(RepeatNode.class, foundRepeats);
        }
        Assertions.assertEquals(1, foundRepeats.size());

        GroupingInfo info = foundRepeats.iterator().next().getGroupingInfo();
        List<Expr> inputs = info.getPreRepeatExprs();
        TupleDescriptor outputTuple = info.getOutputTupleDesc();

        // The i-th input expression must map onto the i-th output slot.
        int position = 0;
        for (Expr input : inputs) {
            SlotRef inputSlot = Assertions.assertInstanceOf(SlotRef.class, input);
            Column inputColumn = inputSlot.getColumn();
            Column outputColumn = outputTuple.getSlots().get(position).getColumn();
            Assertions.assertEquals(inputColumn, outputColumn);
            position++;
        }
    });
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !sql_1_shape --
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----hashAgg[GLOBAL]
------PhysicalProject
--------PhysicalOlapScan[tbl_test_repeat_output_slot]
--PhysicalResultSink
----PhysicalProject
------PhysicalUnion
--------PhysicalProject
----------hashAgg[GLOBAL]
------------hashAgg[LOCAL]
--------------PhysicalRepeat
----------------PhysicalCteConsumer ( cteId=CTEId#0 )
--------PhysicalProject
----------PhysicalCteConsumer ( cteId=CTEId#0 )

-- !sql_1_result --
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000
100000

-- !sql_2_shape --
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalProject
----------PhysicalOlapScan[tbl_test_repeat_output_slot]
--PhysicalResultSink
----PhysicalProject
------PhysicalUnion
--------PhysicalProject
----------filter((GROUPING_PREFIX_col_varchar_50__undef_signed__index_inverted_col_datetime_6__undef_signed_col_varchar_50__undef_signed > 0))
------------hashAgg[GLOBAL]
--------------hashAgg[LOCAL]
----------------PhysicalRepeat
------------------PhysicalCteConsumer ( cteId=CTEId#0 )
--------PhysicalEmptyRelation

-- !sql_2_result --
\N ALL 1 6 \N \N \N
\N ALL 1 6 \N \N \N
2020-01-02T00:00 ALL 1 6 \N 2020-01-02T00:00 \N
2020-01-02T00:00 ALL 1 6 \N 2020-01-02T00:00 \N
2020-01-03T00:00 ALL 1 6 \N 2020-01-03T00:00 \N
2020-01-03T00:00 ALL 1 6 \N 2020-01-03T00:00 \N
2020-01-04T00:00 ALL 1 6 \N 2020-01-04T00:00 \N
2020-01-04T00:00 ALL 1 6 \N 2020-01-04T00:00 \N
2020-01-04T00:00 ALL 1 7 \N \N \N

Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Regression suite for PhysicalRepeat output slots: runs two GROUPING SETS queries
// (both containing an empty grouping set and a duplicated grouping set) through
// explainAndOrderResult, which presumably records the plan shape and the ordered
// result under the "<tag>_shape" / "<tag>_result" entries of the expected-output file.
suite("test_repeat_output_slot") {
// Session setup: force the Nereids planner with no fallback, hide PhysicalDistribute
// from the printed plan shape, disable the PRUNE_EMPTY_PARTITION rule, runtime filters
// and join reorder, then drop any leftover copy of the fixture table.
sql """
SET enable_fallback_to_original_planner=false;
SET enable_nereids_planner=true;
SET ignore_shape_nodes='PhysicalDistribute';
SET disable_nereids_rules='PRUNE_EMPTY_PARTITION';
SET runtime_filter_mode=OFF;
SET disable_join_reorder=true;

DROP TABLE IF EXISTS tbl_test_repeat_output_slot FORCE;

"""

// Fixture table: one datetime(6) column and two varchar(50) columns.
// NOTE: despite the "__index_inverted" suffix, no index is declared on the third column here.
sql """
CREATE TABLE tbl_test_repeat_output_slot (
col_datetime_6__undef_signed datetime(6),
col_varchar_50__undef_signed varchar(50),
col_varchar_50__undef_signed__index_inverted varchar(50)
) engine=olap
distributed by hash(col_datetime_6__undef_signed) buckets 10
properties('replication_num' = '1');
"""

// 14 rows: one all-NULL row, two rows with only the datetime NULL, and the
// remaining rows spread over 2020-01-02 .. 2020-01-04 with values a/b and x/y.
sql """
INSERT INTO tbl_test_repeat_output_slot VALUES
(null, null, null), (null, "a", "x"), (null, "a", "y"),
('2020-01-02', "b", "x"), ('2020-01-02', 'a', 'x'), ('2020-01-02', 'b', 'y'),
('2020-01-03', 'a', 'x'), ('2020-01-03', 'a', 'y'), ('2020-01-03', 'b', 'x'), ('2020-01-03', 'b', 'y'),
('2020-01-04', 'a', 'x'), ('2020-01-04', 'a', 'y'), ('2020-01-04', 'b', 'x'), ('2020-01-04', 'b', 'y');
"""

// sql_1: projects only a constant, so no output column references a grouping key.
// The grouping sets list (datetime, varchar) twice and also include the empty set.
explainAndOrderResult 'sql_1', '''
SELECT 100000
FROM tbl_test_repeat_output_slot
GROUP BY GROUPING SETS (
(col_datetime_6__undef_signed, col_varchar_50__undef_signed)
, ()
, (col_varchar_50__undef_signed)
, (col_datetime_6__undef_signed, col_varchar_50__undef_signed)
);
'''

// sql_2: combines an aggregate (MAX), GROUPING / GROUPING_ID functions and the plain
// grouping columns in one select list, over grouping sets that repeat both the
// three-column set and the single-column set, filtered by HAVING on the GROUPING_ID alias.
explainAndOrderResult 'sql_2', '''
SELECT MAX(col_datetime_6__undef_signed) AS total_col_datetime,
CASE WHEN GROUPING(col_varchar_50__undef_signed__index_inverted) = 1 THEN 'ALL'
ELSE CAST(col_varchar_50__undef_signed__index_inverted AS VARCHAR)
END AS pretty_val,
IF(GROUPING_ID(col_varchar_50__undef_signed__index_inverted,
col_datetime_6__undef_signed,
col_varchar_50__undef_signed) > 0, 1, 0) AS is_agg_row,
GROUPING_ID(col_varchar_50__undef_signed__index_inverted,
col_datetime_6__undef_signed, col_varchar_50__undef_signed) AS having_filter_col,
col_varchar_50__undef_signed__index_inverted,
col_datetime_6__undef_signed,
col_varchar_50__undef_signed
FROM tbl_test_repeat_output_slot
GROUP BY GROUPING SETS (
(col_varchar_50__undef_signed__index_inverted, col_datetime_6__undef_signed, col_varchar_50__undef_signed),
(),
(col_varchar_50__undef_signed),
(col_varchar_50__undef_signed__index_inverted, col_datetime_6__undef_signed, col_varchar_50__undef_signed),
(col_varchar_50__undef_signed))
HAVING having_filter_col > 0;
'''
}