Skip to content

Commit 83fa865

Browse files
committed
feat: substrait builder extra apifeat: substrait builder extra api
Signed-off-by: MBWhite <whitemat@uk.ibm.com>
1 parent af52a69 commit 83fa865

4 files changed

Lines changed: 474 additions & 13 deletions

File tree

.bob/notes/pending-notes.txt

Whitespace-only changes.

build-logic/.kotlin/sessions/kotlin-compiler-17666991750707048222.salive

Whitespace-only changes.

core/src/main/java/io/substrait/dsl/SubstraitBuilder.java

Lines changed: 209 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import io.substrait.expression.Expression.SwitchClause;
1313
import io.substrait.expression.FieldReference;
1414
import io.substrait.expression.FunctionArg;
15+
import io.substrait.expression.FunctionOption;
1516
import io.substrait.expression.WindowBound;
1617
import io.substrait.extension.DefaultExtensionCatalog;
1718
import io.substrait.extension.SimpleExtension;
@@ -39,6 +40,7 @@
3940
import io.substrait.type.NamedStruct;
4041
import io.substrait.type.Type;
4142
import io.substrait.type.TypeCreator;
43+
import java.util.ArrayList;
4244
import java.util.Arrays;
4345
import java.util.Collections;
4446
import java.util.LinkedList;
@@ -73,6 +75,17 @@ public class SubstraitBuilder {
7375

7476
private final SimpleExtension.ExtensionCollection extensions;
7577

78+
/**
79+
* Constructs a new SubstraitBuilder with the default extension collection.
80+
*
81+
* <p>The builder is initialized with {@link DefaultExtensionCatalog#DEFAULT_COLLECTION}, which
82+
* includes standard Substrait functions for strings, arithmetic, comparison, datetime, and other
83+
* operations.
84+
*/
85+
public SubstraitBuilder() {
86+
this(DefaultExtensionCatalog.DEFAULT_COLLECTION);
87+
}
88+
7689
/**
7790
* Constructs a new SubstraitBuilder with the specified extension collection.
7891
*
@@ -83,6 +96,18 @@ public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) {
8396
this.extensions = extensions;
8497
}
8598

99+
/**
100+
* Gets the default extension collection used by this builder.
101+
*
102+
* <p>This collection includes standard Substrait functions for strings, arithmetic, comparison,
103+
* datetime, and other operations from {@link DefaultExtensionCatalog#DEFAULT_COLLECTION}.
104+
*
105+
* @return the ExtensionCollection containing standard Substrait functions
106+
*/
107+
public SimpleExtension.ExtensionCollection getExtensions() {
108+
return extensions;
109+
}
110+
86111
// Relations
87112

88113
/**
@@ -142,13 +167,32 @@ public Aggregate aggregate(
142167
return aggregate(groupingsFn, measuresFn, Optional.of(remap), input);
143168
}
144169

145-
private Aggregate aggregate(
146-
Function<Rel, List<Aggregate.Grouping>> groupingsFn,
147-
Function<Rel, List<Aggregate.Measure>> measuresFn,
148-
Optional<Rel.Remap> remap,
149-
Rel input) {
150-
List<Aggregate.Grouping> groupings = groupingsFn.apply(input);
151-
List<Aggregate.Measure> measures = measuresFn.apply(input);
170+
/**
171+
* Creates an aggregate relation that groups and aggregates data from an input relation.
172+
*
173+
* <p>This method constructs a Substrait aggregate operation by applying grouping and measure
174+
* functions to the input relation. The grouping function defines how rows are grouped together,
175+
* while the measure function defines the aggregate computations (e.g., SUM, COUNT, AVG) to
176+
* perform on each group.
177+
*
178+
* <p>The optional remap parameter allows reordering or filtering of output columns, which is
179+
* useful for controlling the final schema of the aggregate result.
180+
*
181+
* @param groupingsFn a function that takes the input relation and returns a list of grouping
182+
* expressions defining how to partition the data
183+
* @param measuresFn a function that takes the input relation and returns a list of aggregate
184+
* measures to compute for each group
185+
* @param remap an optional remapping specification to reorder or filter output columns
186+
* @param input the input relation to aggregate
187+
* @return an Aggregate relation representing the grouping and aggregation operation
188+
*/
189+
public Aggregate aggregate(
190+
final Function<Rel, List<Aggregate.Grouping>> groupingsFn,
191+
final Function<Rel, List<Aggregate.Measure>> measuresFn,
192+
final Optional<Rel.Remap> remap,
193+
final Rel input) {
194+
final List<Aggregate.Grouping> groupings = groupingsFn.apply(input);
195+
final List<Aggregate.Measure> measures = measuresFn.apply(input);
152196
return Aggregate.builder()
153197
.groupings(groupings)
154198
.measures(measures)
@@ -853,24 +897,64 @@ public Expression.BoolLiteral bool(boolean v) {
853897
return Expression.BoolLiteral.builder().value(v).build();
854898
}
855899

900+
/**
901+
* Create i16 literal.
902+
*
903+
* @param value value to create
904+
* @return i16 instance
905+
*/
906+
public Expression.I8Literal i8(final int value) {
907+
return Expression.I8Literal.builder().value(value).build();
908+
}
909+
910+
/**
911+
* Create i16 literal.
912+
*
913+
* @param value value to create
914+
* @return i16 instance
915+
*/
916+
public Expression.I16Literal i16(final int value) {
917+
return Expression.I16Literal.builder().value(value).build();
918+
}
919+
856920
/**
857921
* Creates a 32-bit integer literal expression.
858922
*
859-
* @param v the integer value
923+
* @param value the integer value
860924
* @return a new {@link Expression.I32Literal}
861925
*/
862-
public Expression.I32Literal i32(int v) {
863-
return Expression.I32Literal.builder().value(v).build();
926+
public Expression.I32Literal i32(final int value) {
927+
return Expression.I32Literal.builder().value(value).build();
928+
}
929+
930+
/**
931+
* Creates a 64-bit integer literal expression.
932+
*
933+
* @param value value to create
934+
* @return i64 instance
935+
*/
936+
public Expression.I64Literal i64(final long value) {
937+
return Expression.I64Literal.builder().value(value).build();
938+
}
939+
940+
/**
941+
* Creates a 32-bit floating point literal expression.
942+
*
943+
* @param value the float value
944+
* @return a new {@link Expression.FP32Literal}
945+
*/
946+
public Expression.FP32Literal fp32(final float value) {
947+
return Expression.FP32Literal.builder().value(value).build();
864948
}
865949

866950
/**
867951
* Creates a 64-bit floating point literal expression.
868952
*
869-
* @param v the double value
953+
* @param value the double value
870954
* @return a new {@link Expression.FP64Literal}
871955
*/
872-
public Expression.FP64Literal fp64(double v) {
873-
return Expression.FP64Literal.builder().value(v).build();
956+
public Expression.FP64Literal fp64(final double value) {
957+
return Expression.FP64Literal.builder().value(value).build();
874958
}
875959

876960
/**
@@ -1439,6 +1523,79 @@ public Expression.ScalarFunctionInvocation or(Expression... args) {
14391523
return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "or:bool", outputType, args);
14401524
}
14411525

1526+
/**
1527+
* Creates a logical NOT expression that negates a boolean expression.
1528+
*
1529+
* <p>This is a convenience method that wraps the boolean NOT function from the Substrait standard
1530+
* library. The result is nullable to handle NULL input values according to three-valued logic.
1531+
*
1532+
* @param expression the boolean expression to negate
1533+
* @return a scalar function invocation representing the logical NOT of the input expression
1534+
*/
1535+
public Expression not(final Expression expression) {
1536+
return this.scalarFn(
1537+
DefaultExtensionCatalog.FUNCTIONS_BOOLEAN,
1538+
"not:bool",
1539+
TypeCreator.NULLABLE.BOOLEAN,
1540+
expression);
1541+
}
1542+
1543+
/**
1544+
* Creates a null-check expression that tests whether an expression is null.
1545+
*
1546+
* <p>This is a convenience method that wraps the is_null function from the Substrait comparison
1547+
* function library. The function evaluates the input expression and returns true if it is null,
1548+
* false otherwise. This is commonly used in conditional logic and filtering operations.
1549+
*
1550+
* <p>The return type is always a required (non-nullable) boolean, as the null check itself always
1551+
* produces a definite true/false result.
1552+
*
1553+
* @param expression the expression to test for null
1554+
* @return a scalar function invocation that returns true if the expression is null, false
1555+
* otherwise
1556+
*/
1557+
public Expression isNull(final Expression expression) {
1558+
1559+
final List<Expression> args = new ArrayList<>();
1560+
args.add(expression);
1561+
1562+
return this.scalarFn(
1563+
DefaultExtensionCatalog.FUNCTIONS_COMPARISON,
1564+
"is_null:any",
1565+
TypeCreator.REQUIRED.BOOLEAN,
1566+
args,
1567+
new ArrayList<FunctionOption>());
1568+
}
1569+
1570+
/**
1571+
* Creates a scalar function invocation with function options.
1572+
*
1573+
* <p>This method extends the base builder's functionality by supporting function options, which
1574+
* control function behavior (e.g., rounding modes, overflow handling).
1575+
*
1576+
* @param urn the extension URI (e.g., {@link DefaultExtensionCatalog#FUNCTIONS_STRING})
1577+
* @param key the function signature (e.g., "substring:str_i32_i32")
1578+
* @param returnType the return type of the function
1579+
* @param args the function arguments
1580+
* @param optionsList the function options controlling behavior
1581+
* @return a scalar function invocation expression
1582+
*/
1583+
public Expression scalarFn(
1584+
final String urn,
1585+
final String key,
1586+
final Type returnType,
1587+
final List<? extends FunctionArg> args,
1588+
final List<FunctionOption> optionsList) {
1589+
final SimpleExtension.ScalarFunctionVariant declaration =
1590+
extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(urn, key));
1591+
return Expression.ScalarFunctionInvocation.builder()
1592+
.declaration(declaration)
1593+
.options(optionsList)
1594+
.outputType(returnType)
1595+
.arguments(args)
1596+
.build();
1597+
}
1598+
14421599
/**
14431600
* Creates a scalar function invocation with specified arguments.
14441601
*
@@ -1459,6 +1616,29 @@ public Expression.ScalarFunctionInvocation scalarFn(
14591616
.build();
14601617
}
14611618

1619+
/**
1620+
* Creates a scalar function invocation with function options.
1621+
*
1622+
* @param urn the extension URI (e.g., {@link DefaultExtensionCatalog#FUNCTIONS_STRING})
1623+
* @param key the function signature (e.g., "substring:str_i32_i32")
1624+
* @param returnType the return type of the function
1625+
* @param args the function arguments
1626+
* @return a scalar function invocation expression
1627+
*/
1628+
public Expression scalarFn(
1629+
final String urn,
1630+
final String key,
1631+
final Type returnType,
1632+
final List<? extends FunctionArg> args) {
1633+
final SimpleExtension.ScalarFunctionVariant declaration =
1634+
extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(urn, key));
1635+
return Expression.ScalarFunctionInvocation.builder()
1636+
.declaration(declaration)
1637+
.outputType(returnType)
1638+
.arguments(args)
1639+
.build();
1640+
}
1641+
14621642
/**
14631643
* Creates a window function invocation with specified arguments and window bounds.
14641644
*
@@ -1532,6 +1712,22 @@ public Plan plan(Plan.Root root) {
15321712
return Plan.builder().addRoots(root).build();
15331713
}
15341714

1715+
/**
1716+
* Creates a Plan.Root, which is the top-level container for a Substrait query plan.
1717+
*
1718+
* <p>The {@link Plan} wraps a relational expression tree and associates output column names with
1719+
* the plan. This is the final step in building a complete Substrait plan that can be serialized
1720+
* and executed by a Substrait consumer.
1721+
*
1722+
* @param input the root relational expression of the query plan
1723+
* @param names the ordered list of output column names corresponding to the input relation's
1724+
* output schema
1725+
* @return a new {@link Plan}
1726+
*/
1727+
public Plan.Root root(final Rel input, final List<String> names) {
1728+
return Plan.Root.builder().input(input).names(names).build();
1729+
}
1730+
15351731
/**
15361732
* Creates a field remapping specification from field indexes.
15371733
*

0 commit comments

Comments
 (0)