Skip to content

Commit 11727a4

Browse files
authored
Support mvmap eval function (opensearch-project#4856)
* Support mvmap eval function Signed-off-by: Kai Huang <ahkcs@amazon.com> # Conflicts: # docs/user/ppl/functions/collection.rst # integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteArrayFunctionIT.java # ppl/src/main/antlr/OpenSearchPPLLexer.g4 # ppl/src/main/antlr/OpenSearchPPLParser.g4 # ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLArrayFunctionTest.java # Conflicts: # core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java * update UT Signed-off-by: Kai Huang <ahkcs@amazon.com> * update error handling and test cases Signed-off-by: Kai Huang <ahkcs@amazon.com> * fixes Signed-off-by: Kai Huang <ahkcs@amazon.com> * update collection.md Signed-off-by: Kai Huang <ahkcs@amazon.com> * Update to support referencing single-value field Signed-off-by: Kai Huang <ahkcs@amazon.com> * update to use visitor pattern Signed-off-by: Kai Huang <ahkcs@amazon.com> * Addressing CodeRabbit Signed-off-by: Kai Huang <ahkcs@amazon.com> * Update Signed-off-by: Kai Huang <ahkcs@amazon.com> --------- Signed-off-by: Kai Huang <ahkcs@amazon.com>
1 parent 56569f3 commit 11727a4

File tree

15 files changed

+499
-11
lines changed

15 files changed

+499
-11
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
99

1010
import java.sql.Connection;
11+
import java.util.ArrayList;
1112
import java.util.HashMap;
1213
import java.util.List;
1314
import java.util.Map;
@@ -61,6 +62,16 @@ public class CalcitePlanContext {
6162

6263
@Getter public Map<String, RexLambdaRef> rexLambdaRefMap;
6364

65+
/**
66+
* List of captured variables from outer scope for lambda functions. When a lambda body references
67+
* a field that is not a lambda parameter, it gets captured and stored here. The captured
68+
* variables are passed as additional arguments to the transform function.
69+
*/
70+
@Getter private List<RexNode> capturedVariables;
71+
72+
/** Whether we're currently inside a lambda context. */
73+
@Getter @Setter private boolean inLambdaContext = false;
74+
6475
private CalcitePlanContext(FrameworkConfig config, SysLimit sysLimit, QueryType queryType) {
6576
this.config = config;
6677
this.sysLimit = sysLimit;
@@ -70,6 +81,24 @@ private CalcitePlanContext(FrameworkConfig config, SysLimit sysLimit, QueryType
7081
this.rexBuilder = new ExtendedRexBuilder(relBuilder.getRexBuilder());
7182
this.functionProperties = new FunctionProperties(QueryType.PPL);
7283
this.rexLambdaRefMap = new HashMap<>();
84+
this.capturedVariables = new ArrayList<>();
85+
}
86+
87+
/**
88+
* Private constructor for creating a context that shares relBuilder with parent. Used by clone()
89+
* to create lambda contexts that can resolve fields from the parent context.
90+
*/
91+
private CalcitePlanContext(CalcitePlanContext parent) {
92+
this.config = parent.config;
93+
this.sysLimit = parent.sysLimit;
94+
this.queryType = parent.queryType;
95+
this.connection = parent.connection;
96+
this.relBuilder = parent.relBuilder; // Share the same relBuilder
97+
this.rexBuilder = parent.rexBuilder; // Share the same rexBuilder
98+
this.functionProperties = parent.functionProperties;
99+
this.rexLambdaRefMap = new HashMap<>(); // New map for lambda variables
100+
this.capturedVariables = new ArrayList<>(); // New list for captured variables
101+
this.inLambdaContext = true; // Mark that we're inside a lambda
73102
}
74103

75104
public RexNode resolveJoinCondition(
@@ -101,8 +130,13 @@ public Optional<RexCorrelVariable> peekCorrelVar() {
101130
}
102131
}
103132

133+
/**
134+
* Creates a clone of this context that shares the relBuilder with the parent. This allows lambda
135+
* expressions to reference fields from the current row while having their own lambda variable
136+
* mappings.
137+
*/
104138
public CalcitePlanContext clone() {
105-
return new CalcitePlanContext(config, sysLimit, queryType);
139+
return new CalcitePlanContext(this);
106140
}
107141

108142
public static CalcitePlanContext create(
@@ -134,4 +168,42 @@ public static boolean isLegacyPreferred() {
134168
public void putRexLambdaRefMap(Map<String, RexLambdaRef> candidateMap) {
135169
this.rexLambdaRefMap.putAll(candidateMap);
136170
}
171+
172+
/**
173+
* Captures an external variable for use inside a lambda. Returns a RexLambdaRef that references
174+
* the captured variable by its index in the captured variables list. The actual RexNode value is
175+
* stored in capturedVariables and will be passed as additional arguments to the transform
176+
* function.
177+
*
178+
* @param fieldRef The RexInputRef representing the external field
179+
* @param fieldName The name of the field being captured
180+
* @return A RexLambdaRef that can be used inside the lambda to reference the captured value
181+
*/
182+
public RexLambdaRef captureVariable(RexNode fieldRef, String fieldName) {
183+
// Check if this variable is already captured
184+
for (int i = 0; i < capturedVariables.size(); i++) {
185+
if (capturedVariables.get(i).equals(fieldRef)) {
186+
// Return existing reference - offset by number of lambda params (1 for array element)
187+
return rexLambdaRefMap.get("__captured_" + i);
188+
}
189+
}
190+
191+
// Add to captured variables list
192+
int captureIndex = capturedVariables.size();
193+
capturedVariables.add(fieldRef);
194+
195+
// Create a lambda ref for this captured variable
196+
// The index is offset by the number of lambda parameters (1 for single-param lambda)
197+
// Count only actual lambda parameters, not captured variables
198+
int lambdaParamCount =
199+
(int)
200+
rexLambdaRefMap.keySet().stream().filter(key -> !key.startsWith("__captured_")).count();
201+
RexLambdaRef lambdaRef =
202+
new RexLambdaRef(lambdaParamCount + captureIndex, fieldName, fieldRef.getType());
203+
204+
// Store it so we can find it again if the same field is referenced multiple times
205+
rexLambdaRefMap.put("__captured_" + captureIndex, lambdaRef);
206+
207+
return lambdaRef;
208+
}
137209
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,20 @@ public RexNode visitLambdaFunction(LambdaFunction node, CalcitePlanContext conte
297297
TYPE_FACTORY.createSqlType(SqlTypeName.ANY))))
298298
.collect(Collectors.toList());
299299
RexNode body = node.getFunction().accept(this, context);
300+
301+
// Add captured variables as additional lambda parameters
302+
// They are stored with keys like "__captured_0", "__captured_1", etc.
303+
List<RexNode> capturedVars = context.getCapturedVariables();
304+
if (capturedVars != null && !capturedVars.isEmpty()) {
305+
args = new ArrayList<>(args);
306+
for (int i = 0; i < capturedVars.size(); i++) {
307+
RexLambdaRef capturedRef = context.getRexLambdaRefMap().get("__captured_" + i);
308+
if (capturedRef != null) {
309+
args.add(capturedRef);
310+
}
311+
}
312+
}
313+
300314
RexNode lambdaNode = context.rexBuilder.makeLambdaCall(body, args);
301315
return lambdaNode;
302316
} catch (Exception e) {
@@ -390,6 +404,7 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
390404
context.setInCoalesceFunction(true);
391405
}
392406

407+
List<RexNode> capturedVars = null;
393408
try {
394409
for (UnresolvedExpression arg : args) {
395410
if (arg instanceof LambdaFunction) {
@@ -408,6 +423,8 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
408423
lambdaNode = analyze(arg, lambdaContext);
409424
}
410425
arguments.add(lambdaNode);
426+
// Capture any external variables that were referenced in the lambda
427+
capturedVars = lambdaContext.getCapturedVariables();
411428
} else {
412429
arguments.add(analyze(arg, context));
413430
}
@@ -418,6 +435,15 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
418435
}
419436
}
420437

438+
// For transform/mvmap functions with captured variables, add them as additional arguments
439+
if (capturedVars != null && !capturedVars.isEmpty()) {
440+
if (node.getFuncName().equalsIgnoreCase("mvmap")
441+
|| node.getFuncName().equalsIgnoreCase("transform")) {
442+
arguments = new ArrayList<>(arguments);
443+
arguments.addAll(capturedVars);
444+
}
445+
}
446+
421447
if ("LIKE".equalsIgnoreCase(node.getFuncName()) && arguments.size() == 2) {
422448
RexNode defaultCaseSensitive =
423449
CalcitePlanContext.isLegacyPreferred()

core/src/main/java/org/opensearch/sql/calcite/QualifiedNameResolver.java

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,30 @@ private static RexNode resolveInNonJoinCondition(
6464
QualifiedName nameNode, CalcitePlanContext context) {
6565
log.debug("resolveInNonJoinCondition() called with nameNode={}", nameNode);
6666

67-
return resolveLambdaVariable(nameNode, context)
68-
.or(() -> resolveFieldDirectly(nameNode, context, 1))
69-
.or(() -> resolveFieldWithAlias(nameNode, context, 1))
70-
.or(() -> resolveFieldWithoutAlias(nameNode, context, 1))
71-
.or(() -> resolveRenamedField(nameNode, context))
72-
.or(() -> resolveCorrelationField(nameNode, context))
67+
// First try to resolve as lambda variable
68+
Optional<RexNode> lambdaVar = resolveLambdaVariable(nameNode, context);
69+
if (lambdaVar.isPresent()) {
70+
return lambdaVar.get();
71+
}
72+
73+
// Try to resolve as regular field
74+
Optional<RexNode> fieldRef =
75+
resolveFieldDirectly(nameNode, context, 1)
76+
.or(() -> resolveFieldWithAlias(nameNode, context, 1))
77+
.or(() -> resolveFieldWithoutAlias(nameNode, context, 1))
78+
.or(() -> resolveRenamedField(nameNode, context));
79+
80+
if (fieldRef.isPresent()) {
81+
// If we're in a lambda context and this is not a lambda variable,
82+
// we need to capture it as an external variable
83+
if (context.isInLambdaContext()) {
84+
log.debug("Capturing external field {} in lambda context", nameNode);
85+
return context.captureVariable(fieldRef.get(), nameNode.toString());
86+
}
87+
return fieldRef.get();
88+
}
89+
90+
return resolveCorrelationField(nameNode, context)
7391
.or(() -> replaceWithNullLiteralInCoalesce(context))
7492
.orElseThrow(() -> getNotFoundException(nameNode));
7593
}

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ public enum BuiltinFunctionName {
7979
MVZIP(FunctionName.of("mvzip")),
8080
SPLIT(FunctionName.of("split")),
8181
MVDEDUP(FunctionName.of("mvdedup")),
82+
MVMAP(FunctionName.of("mvmap")),
8283
FORALL(FunctionName.of("forall")),
8384
EXISTS(FunctionName.of("exists")),
8485
FILTER(FunctionName.of("filter")),

core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/TransformFunctionImpl.java

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ public static Object eval(Object... args) {
7373
List<Object> target = (List<Object>) args[0];
7474
List<Object> results = new ArrayList<>();
7575
SqlTypeName returnType = (SqlTypeName) args[args.length - 1];
76+
77+
// Check for captured variables: args structure is [array, lambda, captured1, captured2, ...,
78+
// returnType]
79+
// If there are more than 3 args (array, lambda, returnType), we have captured variables
80+
boolean hasCapturedVars = args.length > 3;
81+
7682
if (args[1] instanceof org.apache.calcite.linq4j.function.Function1) {
7783
org.apache.calcite.linq4j.function.Function1 lambdaFunction =
7884
(org.apache.calcite.linq4j.function.Function1) args[1];
@@ -90,9 +96,27 @@ public static Object eval(Object... args) {
9096
org.apache.calcite.linq4j.function.Function2 lambdaFunction =
9197
(org.apache.calcite.linq4j.function.Function2) args[1];
9298
try {
93-
for (int i = 0; i < target.size(); i++) {
94-
results.add(
95-
transferLambdaOutputToTargetType(lambdaFunction.apply(target.get(i), i), returnType));
99+
if (hasCapturedVars) {
100+
// Lambda has captured variables - pass the first captured variable as second arg
101+
// LIMITATION: Currently only the first captured variable (args[2]) is supported.
102+
// Supporting multiple captured variables would require either:
103+
// 1. Packing args[2..args.length-1] into an Object[] and modifying lambda generation
104+
// to accept a container as the second parameter, or
105+
// 2. Using higher-arity function interfaces (Function3, Function4, etc.)
106+
// For now, lambdas that capture multiple external variables may not work correctly.
107+
Object capturedVar = args[2];
108+
for (Object candidate : target) {
109+
results.add(
110+
transferLambdaOutputToTargetType(
111+
lambdaFunction.apply(candidate, capturedVar), returnType));
112+
}
113+
} else {
114+
// Original behavior: lambda with index (x, i) -> expr
115+
for (int i = 0; i < target.size(); i++) {
116+
results.add(
117+
transferLambdaOutputToTargetType(
118+
lambdaFunction.apply(target.get(i), i), returnType));
119+
}
96120
}
97121
} catch (Exception e) {
98122
throw new RuntimeException(e);

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MVFIND;
156156
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MVINDEX;
157157
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MVJOIN;
158+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MVMAP;
158159
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MVZIP;
159160
import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOT;
160161
import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOTEQUAL;
@@ -1039,6 +1040,7 @@ void populate() {
10391040
registerOperator(MVDEDUP, SqlLibraryOperators.ARRAY_DISTINCT);
10401041
registerOperator(MVFIND, PPLBuiltinOperators.MVFIND);
10411042
registerOperator(MVZIP, PPLBuiltinOperators.MVZIP);
1043+
registerOperator(MVMAP, PPLBuiltinOperators.TRANSFORM);
10421044
registerOperator(MAP_APPEND, PPLBuiltinOperators.MAP_APPEND);
10431045
registerOperator(MAP_CONCAT, SqlLibraryOperators.MAP_CONCAT);
10441046
registerOperator(MAP_REMOVE, PPLBuiltinOperators.MAP_REMOVE);

docs/user/ppl/functions/collection.md

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,74 @@ fetched rows / total rows = 1/1
806806
+--------------------------+
807807
```
808808

809+
## MVMAP
810+
811+
### Description
812+
813+
Usage: mvmap(array, expression) iterates over each element of a multivalue array, applies the expression to each element, and returns a multivalue array with the transformed results. The field name in the expression is implicitly bound to each element value.
814+
Argument type: array: ARRAY, expression: EXPRESSION
815+
Return type: ARRAY
816+
Example
817+
818+
```ppl
819+
source=people
820+
| eval array = array(1, 2, 3), result = mvmap(array, array * 10)
821+
| fields result
822+
| head 1
823+
```
824+
825+
Expected output:
826+
827+
```text
828+
fetched rows / total rows = 1/1
829+
+------------+
830+
| result |
831+
|------------|
832+
| [10,20,30] |
833+
+------------+
834+
```
835+
836+
```ppl
837+
source=people
838+
| eval array = array(1, 2, 3), result = mvmap(array, array + 5)
839+
| fields result
840+
| head 1
841+
```
842+
843+
Expected output:
844+
845+
```text
846+
fetched rows / total rows = 1/1
847+
+---------+
848+
| result |
849+
|---------|
850+
| [6,7,8] |
851+
+---------+
852+
```
853+
854+
Note: For nested expressions like ``mvmap(mvindex(arr, 1, 3), arr * 2)``, the field name (``arr``) is extracted from the first argument and must match the field referenced in the expression.
855+
856+
The expression can also reference other single-value fields:
857+
858+
```ppl
859+
source=people
860+
| eval array = array(1, 2, 3), multiplier = 10, result = mvmap(array, array * multiplier)
861+
| fields result
862+
| head 1
863+
```
864+
865+
Expected output:
866+
867+
```text
868+
fetched rows / total rows = 1/1
869+
+------------+
870+
| result |
871+
|------------|
872+
| [10,20,30] |
873+
+------------+
874+
```
875+
876+
809877
## MVZIP
810878

811879
### Description

0 commit comments

Comments
 (0)