Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
import org.opensearch.sql.calcite.utils.BinUtils;
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
import org.opensearch.sql.calcite.utils.PPLHintUtils;
import org.opensearch.sql.calcite.utils.PlanUtils;
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
import org.opensearch.sql.calcite.utils.WildcardUtils;
Expand Down Expand Up @@ -949,13 +950,14 @@ private boolean isCountField(RexCall call) {
* @param groupExprList group by expression list
* @param aggExprList aggregate expression list
* @param context CalcitePlanContext
* @param hintIgnoreNullBucket true if bucket_nullable=false
* @return Pair of (group-by list, aggregate list)
*/
private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
List<UnresolvedExpression> groupExprList,
List<UnresolvedExpression> aggExprList,
CalcitePlanContext context,
boolean hintBucketNonNull) {
boolean hintIgnoreNullBucket) {
Pair<List<RexNode>, List<AggCall>> resolved =
resolveAttributesForAggregation(groupExprList, aggExprList, context);
List<RexNode> resolvedGroupByList = resolved.getLeft();
Expand Down Expand Up @@ -1047,7 +1049,9 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
// \- Scan t
List<RexInputRef> trimmedRefs = new ArrayList<>();
trimmedRefs.addAll(PlanUtils.getInputRefs(resolvedGroupByList)); // group-by keys first
trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolvedAggCallList));
List<RexInputRef> aggCallRefs = PlanUtils.getInputRefsFromAggCall(resolvedAggCallList);
boolean hintNestedAgg = containsNestedAggregator(context.relBuilder, aggCallRefs);
trimmedRefs.addAll(aggCallRefs);
context.relBuilder.project(trimmedRefs);

// Re-resolve all attributes based on adding trimmed Project.
Expand All @@ -1059,7 +1063,8 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
List<String> intendedGroupKeyAliases = getGroupKeyNamesAfterAggregation(reResolved.getLeft());
context.relBuilder.aggregate(
context.relBuilder.groupKey(reResolved.getLeft()), reResolved.getRight());
if (hintBucketNonNull) PlanUtils.addIgnoreNullBucketHintToAggregate(context.relBuilder);
if (hintIgnoreNullBucket) PPLHintUtils.addIgnoreNullBucketHintToAggregate(context.relBuilder);
if (hintNestedAgg) PPLHintUtils.addNestedAggCallHintToAggregate(context.relBuilder);
// During aggregation, Calcite projects both input dependencies and output group-by fields.
// When names conflict, Calcite adds numeric suffixes (e.g., "value0").
// Apply explicit renaming to restore the intended aliases.
Expand All @@ -1068,6 +1073,17 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
return Pair.of(reResolved.getLeft(), reResolved.getRight());
}

/**
 * Tells whether at least one aggregate-call argument lives under an ARRAY-typed root field.
 *
 * <p>Each input ref is translated to its field name in the row type of the node currently on top
 * of the builder; the part of the name before the first {@code "."} is taken as the root field
 * and its type is looked up through the builder. For example, aggCalls {@code [count(),
 * count(a.b)]} yields true when {@code a} is an ARRAY field.
 *
 * @param relBuilder builder whose top node supplies the field names and types
 * @param aggCallRefs input refs referenced by the aggregate calls
 * @return true as soon as one ref resolves to an ARRAY-typed root, false otherwise
 */
private boolean containsNestedAggregator(RelBuilder relBuilder, List<RexInputRef> aggCallRefs) {
  for (RexInputRef ref : aggCallRefs) {
    String fieldName = relBuilder.peek().getRowType().getFieldNames().get(ref.getIndex());
    // "a.b.c" -> "a"; a name without a dot is its own root.
    String rootName = org.apache.commons.lang3.StringUtils.substringBefore(fieldName, ".");
    if (relBuilder.field(rootName).getType().getSqlTypeName() == SqlTypeName.ARRAY) {
      return true;
    }
  }
  return false;
}

/**
* Imitates {@code Registrar.registerExpression} of {@link RelBuilder} to derive the output order
* of group-by keys after aggregation.
Expand Down Expand Up @@ -1173,8 +1189,8 @@ private void visitAggregation(
}
groupExprList.addAll(node.getGroupExprList());

// Add stats hint to LogicalAggregation.
boolean toAddHintsOnAggregate =
// Add a hint to LogicalAggregation when bucket_nullable=false.
boolean hintIgnoreNullBucket =
!groupExprList.isEmpty()
// This checks if all group-bys should be nonnull
&& nonNullGroupMask.nextClearBit(0) >= groupExprList.size();
Expand All @@ -1194,14 +1210,16 @@ private void visitAggregation(
.filter(nonNullGroupMask::get)
.mapToObj(nonNullCandidates::get)
.toList();
context.relBuilder.filter(
PlanUtils.getSelectColumns(nonNullFields).stream()
.map(context.relBuilder::field)
.map(context.relBuilder::isNotNull)
.toList());
if (!nonNullFields.isEmpty()) {
context.relBuilder.filter(
PlanUtils.getSelectColumns(nonNullFields).stream()
.map(context.relBuilder::field)
.map(context.relBuilder::isNotNull)
.toList());
}

Pair<List<RexNode>, List<AggCall>> aggregationAttributes =
aggregateWithTrimming(groupExprList, aggExprList, context, toAddHintsOnAggregate);
aggregateWithTrimming(groupExprList, aggExprList, context, hintIgnoreNullBucket);

// schema reordering
List<RexNode> outputFields = context.relBuilder.fields();
Expand Down Expand Up @@ -2329,9 +2347,9 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {

// if usenull=false, add a isNotNull before Aggregate and the hint to this Aggregate
Boolean bucketNullable = (Boolean) argumentMap.get(RareTopN.Option.useNull.name()).getValue();
boolean toAddHintsOnAggregate = false;
boolean hintIgnoreNullBucket = false;
if (!bucketNullable && !groupExprList.isEmpty()) {
toAddHintsOnAggregate = true;
hintIgnoreNullBucket = true;
// add isNotNull filter before aggregation to filter out null bucket
List<RexNode> groupByList =
groupExprList.stream().map(expr -> rexVisitor.analyze(expr, context)).toList();
Expand All @@ -2341,7 +2359,7 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
.map(context.relBuilder::isNotNull)
.toList());
}
aggregateWithTrimming(groupExprList, aggExprList, context, toAddHintsOnAggregate);
aggregateWithTrimming(groupExprList, aggExprList, context, hintIgnoreNullBucket);

// 2. add count() column with sort direction
List<RexNode> partitionKeys = rexVisitor.analyze(node.getGroupExprList(), context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import org.apache.calcite.jdbc.CalcitePrepare;
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.jdbc.Driver;
import org.apache.calcite.linq4j.function.Function0;
import org.apache.calcite.plan.Context;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.Convention;
Expand Down Expand Up @@ -175,8 +174,11 @@ public Connection connect(
}

@Override
protected Function0<CalcitePrepare> createPrepareFactory() {
return OpenSearchPrepareImpl::new;
Comment on lines -178 to -179
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After the upgrade to 1.41, this method is no longer called, so it is replaced with createPrepare().

/**
 * Creates the {@code CalcitePrepare} used by this driver.
 *
 * @return the instance produced by {@code prepareFactory} when one is set, otherwise a new
 *     {@code OpenSearchPrepareImpl}
 */
public CalcitePrepare createPrepare() {
  if (prepareFactory == null) {
    // No factory configured: fall back to the OpenSearch-specific implementation.
    return new OpenSearchPrepareImpl();
  }
  return prepareFactory.get();
}
}

Expand Down Expand Up @@ -298,10 +300,10 @@ public OpenSearchCalcitePreparingStmt(

@Override
protected PreparedResult implement(RelRoot root) {
Hook.PLAN_BEFORE_IMPLEMENTATION.run(root);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hook was called twice for plans that are not fully scannable.

RelDataType resultType = root.rel.getRowType();
boolean isDml = root.kind.belongsTo(SqlKind.DML);
if (root.rel instanceof Scannable scannable) {
Hook.PLAN_BEFORE_IMPLEMENTATION.run(root);
RelDataType resultType = root.rel.getRowType();
boolean isDml = root.kind.belongsTo(SqlKind.DML);
final Bindable bindable = dataContext -> scannable.scan();

return new PreparedResultImpl(
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.utils;

import com.google.common.base.Suppliers;
import java.util.function.Supplier;
import lombok.experimental.UtilityClass;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.hint.HintStrategyTable;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.tools.RelBuilder;

/**
 * Helpers for tagging a {@link LogicalAggregate} with PPL-specific flags and reading them back.
 *
 * <p>All flags are stored as key/value options of a hint named {@code AGG_ARGS}; a flag counts as
 * set when its option value is the string {@code "true"}.
 */
@UtilityClass
public class PPLHintUtils {
  private static final String HINT_AGG_ARGUMENTS = "AGG_ARGS";
  private static final String KEY_IGNORE_NULL_BUCKET = "ignoreNullBucket";
  private static final String KEY_HAS_NESTED_AGG_CALL = "hasNestedAggCall";

  /** Memoized strategy table that accepts {@code AGG_ARGS} hints on LogicalAggregate nodes. */
  private static final Supplier<HintStrategyTable> HINT_STRATEGY_TABLE =
      Suppliers.memoize(
          () ->
              HintStrategyTable.builder()
                  .hintStrategy(HINT_AGG_ARGUMENTS, (hint, rel) -> rel instanceof LogicalAggregate)
                  // add more here
                  .build());

  /**
   * Marks the aggregate on top of the builder as ignoring the null-value bucket.
   *
   * @param relBuilder builder whose current top node is expected to be a LogicalAggregate
   */
  public static void addIgnoreNullBucketHintToAggregate(RelBuilder relBuilder) {
    attachAggFlag(relBuilder, KEY_IGNORE_NULL_BUCKET);
  }

  /**
   * Marks the aggregate on top of the builder as containing a nested aggregate call.
   *
   * @param relBuilder builder whose current top node is expected to be a LogicalAggregate
   */
  public static void addNestedAggCallHintToAggregate(RelBuilder relBuilder) {
    attachAggFlag(relBuilder, KEY_HAS_NESTED_AGG_CALL);
  }

  /** Return true if the aggregate will ignore null value bucket. */
  public static boolean ignoreNullBucket(Aggregate aggregate) {
    return isAggFlagSet(aggregate, KEY_IGNORE_NULL_BUCKET);
  }

  /** Return true if the aggregate has any nested agg call. */
  public static boolean hasNestedAggCall(Aggregate aggregate) {
    return isAggFlagSet(aggregate, KEY_HAS_NESTED_AGG_CALL);
  }

  /**
   * Attaches an {@code AGG_ARGS} hint with {@code optionKey = "true"} to the builder's top node,
   * then installs the strategy table on the cluster if the cluster still has the empty one.
   */
  private static void attachAggFlag(RelBuilder relBuilder, String optionKey) {
    assert relBuilder.peek() instanceof LogicalAggregate
        : "Hint HINT_AGG_ARGUMENTS can be added to LogicalAggregate only";
    relBuilder.hints(RelHint.builder(HINT_AGG_ARGUMENTS).hintOption(optionKey, "true").build());
    // Only replace the default table; a table installed by someone else is left untouched.
    if (relBuilder.getCluster().getHintStrategies() == HintStrategyTable.EMPTY) {
      relBuilder.getCluster().setHintStrategies(HINT_STRATEGY_TABLE.get());
    }
  }

  /** Returns true if any {@code AGG_ARGS} hint on the aggregate maps {@code optionKey} to "true". */
  private static boolean isAggFlagSet(Aggregate aggregate, String optionKey) {
    for (RelHint hint : aggregate.getHints()) {
      if (hint.hintName.equals(PPLHintUtils.HINT_AGG_ARGUMENTS)
          && hint.kvOptions.getOrDefault(optionKey, "false").equals("true")) {
        return true;
      }
    }
    return false;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
Expand All @@ -62,7 +60,6 @@
import org.apache.calcite.util.mapping.Mappings;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.IntervalUnit;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.WindowBound;
Expand Down Expand Up @@ -610,15 +607,6 @@ static void replaceTop(RelBuilder relBuilder, RelNode relNode) {
}
}

/**
 * Attaches a {@code "stats_args"} hint with option {@code Argument.BUCKET_NULLABLE = "false"} to
 * the node on top of the builder, then sets the cluster's hint strategies to the table returned
 * by {@code PPLHintStrategyTable.getHintStrategyTable()}.
 *
 * @param relBuilder builder whose top node is expected to be a LogicalAggregate; this is only
 *     checked by an assert
 */
static void addIgnoreNullBucketHintToAggregate(RelBuilder relBuilder) {
final RelHint statHits =
RelHint.builder("stats_args").hintOption(Argument.BUCKET_NULLABLE, "false").build();
assert relBuilder.peek() instanceof LogicalAggregate
: "Stats hits should be added to LogicalAggregate";
relBuilder.hints(statHits);
// The strategy table is set unconditionally, overwriting whatever the cluster had before.
relBuilder.getCluster().setHintStrategies(PPLHintStrategyTable.getHintStrategyTable());
}

/** Extract the RexLiteral from the aggregate call if the aggregate call is a LITERAL_AGG. */
static @Nullable RexLiteral getObjectFromLiteralAgg(AggregateCall aggCall) {
if (aggCall.getAggregation().kind == SqlKind.LITERAL_AGG) {
Expand Down Expand Up @@ -655,13 +643,7 @@ private static boolean isNotNullOnRef(RexNode rex) {
&& rexCall.getOperands().get(0) instanceof RexInputRef;
}

Predicate<Aggregate> aggIgnoreNullBucket =
agg ->
agg.getHints().stream()
.anyMatch(
hint ->
hint.hintName.equals("stats_args")
&& hint.kvOptions.get(Argument.BUCKET_NULLABLE).equals("false"));
Predicate<Aggregate> aggIgnoreNullBucket = PPLHintUtils::ignoreNullBucket;

Predicate<Aggregate> maybeTimeSpanAgg =
agg ->
Expand Down
59 changes: 59 additions & 0 deletions core/src/main/java/org/opensearch/sql/utils/Utils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.utils;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;

/** Static helper methods shared across the SQL core. */
public interface Utils {
  /**
   * Pairs every element of {@code input} with its zero-based position.
   *
   * @param input the list to index
   * @param <I> element type
   * @return a new list of (element, index) pairs in iteration order
   */
  static <I> List<Pair<I, Integer>> zipWithIndex(List<I> input) {
    // The result size is known up front, so a pre-sized ArrayList avoids per-node allocation.
    List<Pair<I, Integer>> result = new ArrayList<>(input.size());
    int index = 0;
    for (I item : input) {
      result.add(Pair.of(item, index++));
    }
    return result;
  }

  /**
   * Resolve the nested path from the field name.
   *
   * <p>Walks the ancestors of {@code path} from the closest to the farthest (for {@code "a.b.c"}:
   * {@code "a.b"}, then {@code "a"}) and returns the first one whose type is {@link
   * ExprCoreType#ARRAY}. The path itself is never considered.
   *
   * @param path the field name
   * @param fieldTypes the field types
   * @return the nested path if exists, otherwise null
   */
  static @Nullable String resolveNestedPath(String path, Map<String, ExprType> fieldTypes) {
    if (path == null || fieldTypes == null || fieldTypes.isEmpty()) {
      return null;
    }
    String current = path;
    String parent = StringUtils.substringBeforeLast(current, ".");
    // substringBeforeLast returns its input unchanged once no "." is left, which ends the walk.
    while (!parent.equals(current)) {
      // Nested is mapped to ExprCoreType.ARRAY
      if (fieldTypes.get(parent) == ExprCoreType.ARRAY) {
        return parent;
      }
      current = parent;
      parent = StringUtils.substringBeforeLast(current, ".");
    }
    return null;
  }
}
Loading
Loading