Skip to content

Commit 9073fae

Browse files
committed
Merge remote-tracking branch 'origin/main' into issues/4576
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
2 parents 4a8850a + bf37067 commit 9073fae

File tree

210 files changed

+5249
-3264
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

210 files changed

+5249
-3264
lines changed

DEVELOPER_GUIDE.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ For test cases, you can use the cases in the following checklist in case you mis
318318
- *Explain*
319319

320320
- DSL for simple query
321+
- Script for complex expressions, see details in `intro-scripts <./docs/dev/intro-scripts.md>`_.
321322
- Execution plan for complex query like JOIN
322323

323324
- *Response format*

benchmarks/src/jmh/java/org/opensearch/sql/expression/operator/predicate/ExpressionScriptSerdeBenchmark.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.opensearch.sql.expression.function.PPLFuncImpTable;
3434
import org.opensearch.sql.opensearch.storage.serde.DefaultExpressionSerializer;
3535
import org.opensearch.sql.opensearch.storage.serde.RelJsonSerializer;
36+
import org.opensearch.sql.opensearch.storage.serde.ScriptParameterHelper;
3637

3738
@Warmup(iterations = 1)
3839
@Measurement(iterations = 10)
@@ -74,7 +75,9 @@ public void testRexNodeJsonSerde() {
7475
SqlStdOperatorTable.NOT_EQUALS, rexUpper, rexBuilder.makeLiteral("ABOUT"));
7576
Map<String, ExprType> fieldTypes = Map.of("Referer", ExprCoreType.STRING);
7677

77-
String serializedStr = relJsonSerializer.serialize(rexNotEquals, rowType, fieldTypes);
78+
String serializedStr =
79+
relJsonSerializer.serialize(
80+
rexNotEquals, new ScriptParameterHelper(rowType.getFieldList(), fieldTypes));
7881
relJsonSerializer.deserialize(serializedStr);
7982
}
8083
}

core/src/main/java/org/opensearch/sql/ast/tree/SPath.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
package org.opensearch.sql.ast.tree;
77

8+
import static org.opensearch.sql.common.utils.StringUtils.unquoteText;
9+
810
import com.google.common.collect.ImmutableList;
911
import java.util.List;
1012
import lombok.AllArgsConstructor;
@@ -48,15 +50,16 @@ public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
4850

4951
public Eval rewriteAsEval() {
5052
String outField = this.outField;
53+
String unquotedPath = unquoteText(this.path);
5154
if (outField == null) {
52-
outField = this.path;
55+
outField = unquotedPath;
5356
}
5457

5558
return AstDSL.eval(
5659
this.child,
5760
AstDSL.let(
5861
AstDSL.field(outField),
5962
AstDSL.function(
60-
"json_extract", AstDSL.field(inField), AstDSL.stringLiteral(this.path))));
63+
"json_extract", AstDSL.field(inField), AstDSL.stringLiteral(unquotedPath))));
6164
}
6265
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1703,7 +1703,10 @@ public RelNode visitStreamWindow(StreamWindow node, CalcitePlanContext context)
17031703
new String[] {ROW_NUMBER_COLUMN_FOR_STREAMSTATS});
17041704
}
17051705

1706-
// Default
1706+
// Default: first get rawExpr
1707+
List<RexNode> overExpressions =
1708+
node.getWindowFunctionList().stream().map(w -> rexVisitor.analyze(w, context)).toList();
1709+
17071710
if (hasGroup) {
17081711
// only build sequence when there is by condition
17091712
RexNode streamSeq =
@@ -1714,21 +1717,54 @@ public RelNode visitStreamWindow(StreamWindow node, CalcitePlanContext context)
17141717
.rowsTo(RexWindowBounds.CURRENT_ROW)
17151718
.as(ROW_NUMBER_COLUMN_FOR_STREAMSTATS);
17161719
context.relBuilder.projectPlus(streamSeq);
1717-
}
17181720

1719-
List<RexNode> overExpressions =
1720-
node.getWindowFunctionList().stream().map(w -> rexVisitor.analyze(w, context)).toList();
1721-
context.relBuilder.projectPlus(overExpressions);
1721+
// construct groupNotNull predicate
1722+
List<RexNode> groupByList =
1723+
groupList.stream().map(expr -> rexVisitor.analyze(expr, context)).toList();
1724+
List<RexNode> notNullList =
1725+
PlanUtils.getSelectColumns(groupByList).stream()
1726+
.map(context.relBuilder::field)
1727+
.map(context.relBuilder::isNotNull)
1728+
.toList();
1729+
RexNode groupNotNull = context.relBuilder.and(notNullList);
17221730

1723-
// resort when there is by condition
1724-
if (hasGroup) {
1731+
// wrap each expr: CASE WHEN groupNotNull THEN rawExpr ELSE CAST(NULL AS rawType) END
1732+
List<RexNode> wrappedOverExprs =
1733+
wrapWindowFunctionsWithGroupNotNull(overExpressions, groupNotNull, context);
1734+
context.relBuilder.projectPlus(wrappedOverExprs);
1735+
// resort when there is by condition
17251736
context.relBuilder.sort(context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_STREAMSTATS));
17261737
context.relBuilder.projectExcept(context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_STREAMSTATS));
1738+
} else {
1739+
context.relBuilder.projectPlus(overExpressions);
17271740
}
17281741

17291742
return context.relBuilder.peek();
17301743
}
17311744

1745+
private List<RexNode> wrapWindowFunctionsWithGroupNotNull(
1746+
List<RexNode> overExpressions, RexNode groupNotNull, CalcitePlanContext context) {
1747+
List<RexNode> wrappedOverExprs = new ArrayList<>(overExpressions.size());
1748+
for (RexNode overExpr : overExpressions) {
1749+
RexNode rawExpr = overExpr;
1750+
String aliasName = null;
1751+
if (overExpr instanceof RexCall rc && rc.getOperator() == SqlStdOperatorTable.AS) {
1752+
rawExpr = rc.getOperands().get(0);
1753+
if (rc.getOperands().size() >= 2 && rc.getOperands().get(1) instanceof RexLiteral lit) {
1754+
aliasName = lit.getValueAs(String.class);
1755+
}
1756+
}
1757+
RexNode nullLiteral = context.rexBuilder.makeNullLiteral(rawExpr.getType());
1758+
RexNode caseExpr =
1759+
context.rexBuilder.makeCall(SqlStdOperatorTable.CASE, groupNotNull, rawExpr, nullLiteral);
1760+
if (aliasName != null) {
1761+
caseExpr = context.relBuilder.alias(caseExpr, aliasName);
1762+
}
1763+
wrappedOverExprs.add(caseExpr);
1764+
}
1765+
return wrappedOverExprs;
1766+
}
1767+
17321768
private RelNode buildStreamWindowJoinPlan(
17331769
CalcitePlanContext context,
17341770
RelNode leftWithHelpers,

core/src/main/java/org/opensearch/sql/calcite/plan/LogicalSystemLimit.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,15 @@ private LogicalSystemLimit(
6565
}
6666

6767
public static LogicalSystemLimit create(SystemLimitType type, RelNode input, RexNode fetch) {
68-
return create(type, input, input.getTraitSet().getCollation(), null, fetch);
68+
return create(type, input, null, fetch);
6969
}
7070

7171
public static LogicalSystemLimit create(
72-
SystemLimitType type,
73-
RelNode input,
74-
RelCollation collation,
75-
@Nullable RexNode offset,
76-
@Nullable RexNode fetch) {
72+
SystemLimitType type, RelNode input, @Nullable RexNode offset, @Nullable RexNode fetch) {
7773
RelOptCluster cluster = input.getCluster();
74+
List<RelCollation> collations = input.getTraitSet().getTraits(RelCollationTraitDef.INSTANCE);
75+
// When there exist multiple sets of equivalent collations, we arbitrarily pick the first one
76+
RelCollation collation = collations == null ? null : collations.get(0);
7877
collation = RelCollationTraitDef.INSTANCE.canonize(collation);
7978
RelTraitSet traitSet = input.getTraitSet().replace(Convention.NONE).replace(collation);
8079
return new LogicalSystemLimit(type, cluster, traitSet, input, collation, offset, fetch);

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,23 @@ static boolean sortByFieldsOnly(Sort sort) {
532532
return !sort.getCollation().getFieldCollations().isEmpty() && sort.fetch == null;
533533
}
534534

535+
/**
536+
* Check if the sort collation points to non field project expression.
537+
*
538+
* @param sort the sort operator adding sort order over project
539+
* @param project project operation that may contain non field expressions
540+
* @return flag to indicate whether non field project expression will be sorted
541+
*/
542+
static boolean sortReferencesExpr(Sort sort, Project project) {
543+
if (sort.getCollation().getFieldCollations().isEmpty()) {
544+
return false;
545+
}
546+
return sort.getCollation().getFieldCollations().stream()
547+
.anyMatch(
548+
relFieldCollation ->
549+
project.getProjects().get(relFieldCollation.getFieldIndex()) instanceof RexCall);
550+
}
551+
535552
/**
536553
* Get a string representation of the argument types expressed in ExprType for error messages.
537554
*

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,13 @@ public enum BuiltinFunctionName {
6868
/** Collection functions */
6969
ARRAY(FunctionName.of("array")),
7070
ARRAY_LENGTH(FunctionName.of("array_length")),
71+
ARRAY_SLICE(FunctionName.of("array_slice"), true),
7172
MAP_APPEND(FunctionName.of("map_append"), true),
7273
MAP_CONCAT(FunctionName.of("map_concat"), true),
7374
MAP_REMOVE(FunctionName.of("map_remove"), true),
7475
MVAPPEND(FunctionName.of("mvappend")),
7576
MVJOIN(FunctionName.of("mvjoin")),
77+
MVINDEX(FunctionName.of("mvindex")),
7678
FORALL(FunctionName.of("forall")),
7779
EXISTS(FunctionName.of("exists")),
7880
FILTER(FunctionName.of("filter")),
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.expression.function.CollectionUDF;
7+
8+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDFUNCTION;
9+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_LENGTH;
10+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_SLICE;
11+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IF;
12+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_ITEM;
13+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.LESS;
14+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT;
15+
16+
import java.math.BigDecimal;
17+
import org.apache.calcite.rex.RexBuilder;
18+
import org.apache.calcite.rex.RexNode;
19+
import org.opensearch.sql.expression.function.PPLFuncImpTable;
20+
21+
/**
22+
* MVINDEX function implementation that returns a subset of a multivalue array.
23+
*
24+
* <p>Usage:
25+
*
26+
* <ul>
27+
* <li>mvindex(array, start) - returns single element at index (0-based)
28+
* <li>mvindex(array, start, end) - returns array slice from start to end (inclusive, 0-based)
29+
* </ul>
30+
*
31+
* <p>Supports negative indexing where -1 refers to the last element.
32+
*
33+
* <p>Implementation notes:
34+
*
35+
* <ul>
36+
* <li>Single element access uses Calcite's ITEM operator (1-based indexing)
37+
* <li>Range access uses Calcite's ARRAY_SLICE operator (0-based indexing with length parameter)
38+
* <li>Index conversion handles the difference between PPL's 0-based indexing and Calcite's
39+
* conventions
40+
* </ul>
41+
*/
42+
public class MVIndexFunctionImp implements PPLFuncImpTable.FunctionImp {
43+
44+
@Override
45+
public RexNode resolve(RexBuilder builder, RexNode... args) {
46+
RexNode array = args[0];
47+
RexNode startIdx = args[1];
48+
49+
// Use resolve to get array length instead of direct makeCall
50+
RexNode arrayLen = PPLFuncImpTable.INSTANCE.resolve(builder, ARRAY_LENGTH, array);
51+
52+
if (args.length == 2) {
53+
// Single element access using ITEM (1-based indexing)
54+
return resolveSingleElement(builder, array, startIdx, arrayLen);
55+
} else {
56+
// Range access using ARRAY_SLICE (0-based indexing)
57+
RexNode endIdx = args[2];
58+
return resolveRange(builder, array, startIdx, endIdx, arrayLen);
59+
}
60+
}
61+
62+
/**
63+
* Resolves single element access: mvindex(array, index)
64+
*
65+
* <p>Uses Calcite's ITEM operator which uses 1-based indexing. Converts PPL's 0-based index to
66+
* 1-based by adding 1.
67+
*/
68+
private RexNode resolveSingleElement(
69+
RexBuilder builder, RexNode array, RexNode startIdx, RexNode arrayLen) {
70+
// Convert 0-based PPL index to 1-based Calcite ITEM index
71+
RexNode zero = builder.makeExactLiteral(BigDecimal.ZERO);
72+
RexNode one = builder.makeExactLiteral(BigDecimal.ONE);
73+
74+
RexNode isNegative = PPLFuncImpTable.INSTANCE.resolve(builder, LESS, startIdx, zero);
75+
RexNode sumArrayLenStart =
76+
PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, arrayLen, startIdx);
77+
RexNode negativeCase =
78+
PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, sumArrayLenStart, one);
79+
RexNode positiveCase = PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, startIdx, one);
80+
81+
RexNode normalizedStart =
82+
PPLFuncImpTable.INSTANCE.resolve(builder, IF, isNegative, negativeCase, positiveCase);
83+
84+
return PPLFuncImpTable.INSTANCE.resolve(builder, INTERNAL_ITEM, array, normalizedStart);
85+
}
86+
87+
/**
88+
* Resolves range access: mvindex(array, start, end)
89+
*
90+
* <p>Uses Calcite's ARRAY_SLICE operator which uses 0-based indexing and a length parameter.
91+
* PPL's end index is inclusive, so length = (end - start) + 1.
92+
*/
93+
private RexNode resolveRange(
94+
RexBuilder builder, RexNode array, RexNode startIdx, RexNode endIdx, RexNode arrayLen) {
95+
// Normalize negative indices for ARRAY_SLICE (0-based)
96+
RexNode zero = builder.makeExactLiteral(BigDecimal.ZERO);
97+
RexNode one = builder.makeExactLiteral(BigDecimal.ONE);
98+
99+
RexNode isStartNegative = PPLFuncImpTable.INSTANCE.resolve(builder, LESS, startIdx, zero);
100+
RexNode startNegativeCase =
101+
PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, arrayLen, startIdx);
102+
RexNode normalizedStart =
103+
PPLFuncImpTable.INSTANCE.resolve(builder, IF, isStartNegative, startNegativeCase, startIdx);
104+
105+
RexNode isEndNegative = PPLFuncImpTable.INSTANCE.resolve(builder, LESS, endIdx, zero);
106+
RexNode endNegativeCase =
107+
PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, arrayLen, endIdx);
108+
RexNode normalizedEnd =
109+
PPLFuncImpTable.INSTANCE.resolve(builder, IF, isEndNegative, endNegativeCase, endIdx);
110+
111+
// Calculate length: (normalizedEnd - normalizedStart) + 1
112+
RexNode diff =
113+
PPLFuncImpTable.INSTANCE.resolve(builder, SUBTRACT, normalizedEnd, normalizedStart);
114+
RexNode length = PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, diff, one);
115+
116+
// Call ARRAY_SLICE(array, normalizedStart, length)
117+
return PPLFuncImpTable.INSTANCE.resolve(builder, ARRAY_SLICE, array, normalizedStart, length);
118+
}
119+
}

core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@
9898
import org.opensearch.sql.expression.function.udf.math.ConvFunction;
9999
import org.opensearch.sql.expression.function.udf.math.DivideFunction;
100100
import org.opensearch.sql.expression.function.udf.math.EulerFunction;
101-
import org.opensearch.sql.expression.function.udf.math.MaxFunction;
102-
import org.opensearch.sql.expression.function.udf.math.MinFunction;
103101
import org.opensearch.sql.expression.function.udf.math.ModFunction;
104102
import org.opensearch.sql.expression.function.udf.math.NumberToStringFunction;
103+
import org.opensearch.sql.expression.function.udf.math.ScalarMaxFunction;
104+
import org.opensearch.sql.expression.function.udf.math.ScalarMinFunction;
105105

106106
/** Defines functions and operators that are implemented only by PPL */
107107
public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
@@ -131,8 +131,8 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
131131
public static final SqlOperator DIVIDE = new DivideFunction().toUDF("DIVIDE");
132132
public static final SqlOperator SHA2 = CryptographicFunction.sha2().toUDF("SHA2");
133133
public static final SqlOperator CIDRMATCH = new CidrMatchFunction().toUDF("CIDRMATCH");
134-
public static final SqlOperator MAX = new MaxFunction().toUDF("MAX");
135-
public static final SqlOperator MIN = new MinFunction().toUDF("MIN");
134+
public static final SqlOperator SCALAR_MAX = new ScalarMaxFunction().toUDF("SCALAR_MAX");
135+
public static final SqlOperator SCALAR_MIN = new ScalarMinFunction().toUDF("SCALAR_MIN");
136136

137137
public static final SqlOperator COSH =
138138
adaptMathFunctionToUDF(

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import static org.opensearch.sql.expression.function.BuiltinFunctionName.AND;
1818
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY;
1919
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_LENGTH;
20+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_SLICE;
2021
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ASCII;
2122
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ASIN;
2223
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ATAN;
@@ -149,6 +150,7 @@
149150
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLYFUNCTION;
150151
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTI_MATCH;
151152
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MVAPPEND;
153+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MVINDEX;
152154
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MVJOIN;
153155
import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOT;
154156
import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOTEQUAL;
@@ -282,6 +284,7 @@
282284
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
283285
import org.opensearch.sql.exception.ExpressionEvaluationException;
284286
import org.opensearch.sql.executor.QueryType;
287+
import org.opensearch.sql.expression.function.CollectionUDF.MVIndexFunctionImp;
285288

286289
public class PPLFuncImpTable {
287290
private static final Logger logger = LogManager.getLogger(PPLFuncImpTable.class);
@@ -847,8 +850,8 @@ void populate() {
847850
registerOperator(INTERNAL_TRANSLATE3, SqlLibraryOperators.TRANSLATE3);
848851

849852
// Register eval functions for PPL max() and min() calls
850-
registerOperator(MAX, PPLBuiltinOperators.MAX);
851-
registerOperator(MIN, PPLBuiltinOperators.MIN);
853+
registerOperator(MAX, PPLBuiltinOperators.SCALAR_MAX);
854+
registerOperator(MIN, PPLBuiltinOperators.SCALAR_MIN);
852855

853856
// Register PPL UDF operator
854857
registerOperator(COSH, PPLBuiltinOperators.COSH);
@@ -972,12 +975,25 @@ void populate() {
972975
builder.makeCall(SqlLibraryOperators.ARRAY_JOIN, array, delimiter),
973976
PPLTypeChecker.family(SqlTypeFamily.ARRAY, SqlTypeFamily.CHARACTER));
974977

978+
// Register MVINDEX to use Calcite's ITEM/ARRAY_SLICE with index normalization
979+
register(
980+
MVINDEX,
981+
new MVIndexFunctionImp(),
982+
PPLTypeChecker.wrapComposite(
983+
(CompositeOperandTypeChecker)
984+
OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER)
985+
.or(
986+
OperandTypes.family(
987+
SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)),
988+
false));
989+
975990
registerOperator(ARRAY, PPLBuiltinOperators.ARRAY);
976991
registerOperator(MVAPPEND, PPLBuiltinOperators.MVAPPEND);
977992
registerOperator(MAP_APPEND, PPLBuiltinOperators.MAP_APPEND);
978993
registerOperator(MAP_CONCAT, SqlLibraryOperators.MAP_CONCAT);
979994
registerOperator(MAP_REMOVE, PPLBuiltinOperators.MAP_REMOVE);
980995
registerOperator(ARRAY_LENGTH, SqlLibraryOperators.ARRAY_LENGTH);
996+
registerOperator(ARRAY_SLICE, SqlLibraryOperators.ARRAY_SLICE);
981997
registerOperator(FORALL, PPLBuiltinOperators.FORALL);
982998
registerOperator(EXISTS, PPLBuiltinOperators.EXISTS);
983999
registerOperator(FILTER, PPLBuiltinOperators.FILTER);

0 commit comments

Comments
 (0)