Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 186 additions & 15 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io.substrait.expression.Expression.SwitchClause;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.WindowBound;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
Expand Down Expand Up @@ -39,6 +40,7 @@
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
Expand Down Expand Up @@ -73,6 +75,17 @@ public class SubstraitBuilder {

private final SimpleExtension.ExtensionCollection extensions;

/**
* Constructs a new SubstraitBuilder backed by the default extension collection.
*
* <p>Delegates to {@link #SubstraitBuilder(SimpleExtension.ExtensionCollection)} with {@link
* DefaultExtensionCatalog#DEFAULT_COLLECTION}. That collection is expected to provide the standard
* Substrait function extensions (strings, arithmetic, comparison, datetime, ...); its exact
* contents are defined by {@link DefaultExtensionCatalog}, not by this class.
*/
public SubstraitBuilder() {
this(DefaultExtensionCatalog.DEFAULT_COLLECTION);
}

/**
* Constructs a new SubstraitBuilder with the specified extension collection.
*
Expand All @@ -83,6 +96,15 @@ public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) {
this.extensions = extensions;
}

/**
 * Returns the extension collection this builder resolves function declarations against.
 *
 * @return the {@link SimpleExtension.ExtensionCollection} supplied when this builder was
 *     constructed
 */
public SimpleExtension.ExtensionCollection getExtensions() {
  return this.extensions;
}

// Relations

/**
Expand Down Expand Up @@ -124,25 +146,25 @@ public Aggregate aggregate(
}

/**
* Creates an aggregate relation with a single grouping and output field remapping.
* Creates an aggregate relation that groups and aggregates data from an input relation.
*
* @param groupingFn function to derive the grouping from the input relation
* @param measuresFn function to derive the measures from the input relation
* @param remap the output field remapping specification
* <p>This method constructs a Substrait aggregate operation by applying grouping and measure
* functions to the input relation. The grouping function defines how rows are grouped together,
* while the measure function defines the aggregate computations (e.g., SUM, COUNT, AVG) to
* perform on each group.
*
* <p>The optional remap parameter allows reordering or filtering of output columns, which is
* useful for controlling the final schema of the aggregate result.
*
* @param groupingsFn a function that takes the input relation and returns a list of grouping
* expressions defining how to partition the data
* @param measuresFn a function that takes the input relation and returns a list of aggregate
* measures to compute for each group
* @param remap an optional remapping specification to reorder or filter output columns
* @param input the input relation to aggregate
* @return a new {@link Aggregate} relation
* @return an Aggregate relation representing the grouping and aggregation operation
*/
public Aggregate aggregate(
Function<Rel, Aggregate.Grouping> groupingFn,
Function<Rel, List<Aggregate.Measure>> measuresFn,
Rel.Remap remap,
Rel input) {
Function<Rel, List<Aggregate.Grouping>> groupingsFn =
groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList()));
return aggregate(groupingsFn, measuresFn, Optional.of(remap), input);
}

private Aggregate aggregate(
Function<Rel, List<Aggregate.Grouping>> groupingsFn,
Function<Rel, List<Aggregate.Measure>> measuresFn,
Optional<Rel.Remap> remap,
Expand Down Expand Up @@ -853,6 +875,26 @@ public Expression.BoolLiteral bool(boolean v) {
return Expression.BoolLiteral.builder().value(v).build();
}

/**
 * Creates an 8-bit integer literal expression.
 *
 * <p>The value is accepted as an {@code int} and handed to the literal builder unchanged; this
 * method performs no range check of its own.
 *
 * @param value the value of the literal
 * @return a new {@link Expression.I8Literal}
 */
public Expression.I8Literal i8(int value) {
  return Expression.I8Literal.builder().value(value).build();
}

/**
 * Creates a 16-bit integer literal expression.
 *
 * <p>The value is accepted as an {@code int} and handed to the literal builder unchanged.
 *
 * @param value the value of the literal
 * @return a new {@link Expression.I16Literal}
 */
public Expression.I16Literal i16(int value) {
  final Expression.I16Literal literal = Expression.I16Literal.builder().value(value).build();
  return literal;
}

/**
* Creates a 32-bit integer literal expression.
*
Expand All @@ -863,6 +905,26 @@ public Expression.I32Literal i32(int v) {
return Expression.I32Literal.builder().value(v).build();
}

/**
 * Builds a literal expression holding a 64-bit integer.
 *
 * @param value the {@code long} value of the literal
 * @return a new {@link Expression.I64Literal}
 */
public Expression.I64Literal i64(long value) {
  final Expression.I64Literal literal = Expression.I64Literal.builder().value(value).build();
  return literal;
}

/**
 * Builds a literal expression holding a 32-bit floating point number.
 *
 * @param value the {@code float} value of the literal
 * @return a new {@link Expression.FP32Literal}
 */
public Expression.FP32Literal fp32(float value) {
  final Expression.FP32Literal literal = Expression.FP32Literal.builder().value(value).build();
  return literal;
}

/**
* Creates a 64-bit floating point literal expression.
*
Expand Down Expand Up @@ -1439,6 +1501,79 @@ public Expression.ScalarFunctionInvocation or(Expression... args) {
return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "or:bool", outputType, args);
}

/**
 * Builds a logical NOT over a boolean expression.
 *
 * <p>Convenience wrapper that invokes the {@code not:bool} function from {@link
 * DefaultExtensionCatalog#FUNCTIONS_BOOLEAN}. The output type is always declared as a nullable
 * boolean, regardless of the nullability of the argument.
 *
 * @param expression the boolean expression to negate
 * @return a scalar function invocation representing the negation of {@code expression}
 */
public Expression not(Expression expression) {
  final Type resultType = TypeCreator.NULLABLE.BOOLEAN;
  return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "not:bool", resultType, expression);
}

/**
 * Builds a null test over an expression.
 *
 * <p>Convenience wrapper that invokes the {@code is_null:any} function from {@link
 * DefaultExtensionCatalog#FUNCTIONS_COMPARISON} with the given expression as its only argument
 * and an empty list of function options.
 *
 * <p>The output type is declared as a required (non-nullable) boolean.
 *
 * @param expression the expression to test for null
 * @return a scalar function invocation of {@code is_null} applied to {@code expression}
 */
public Expression isNull(Expression expression) {
  // Both lists stay mutable ArrayLists, matching what callers of scalarFn received before.
  final List<Expression> operands = new ArrayList<>(Collections.singletonList(expression));
  final List<FunctionOption> noOptions = new ArrayList<>();

  return scalarFn(
      DefaultExtensionCatalog.FUNCTIONS_COMPARISON,
      "is_null:any",
      TypeCreator.REQUIRED.BOOLEAN,
      operands,
      noOptions);
}

/**
 * Creates a scalar function invocation that carries explicit function options.
 *
 * <p>The function declaration is looked up in this builder's extension collection using the
 * anchor formed from {@code urn} and {@code key}. Function options are passed through to the
 * invocation as given (they are intended to tune function behavior, for example rounding or
 * overflow handling).
 *
 * @param urn the extension URN (e.g., {@link DefaultExtensionCatalog#FUNCTIONS_STRING})
 * @param key the function signature (e.g., "substring:str_i32_i32")
 * @param returnType the output type to declare on the invocation
 * @param args the function arguments
 * @param optionsList the function options to attach to the invocation
 * @return a scalar function invocation expression
 */
public Expression scalarFn(
    String urn,
    String key,
    Type returnType,
    List<? extends FunctionArg> args,
    List<FunctionOption> optionsList) {
  final SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(urn, key);
  final SimpleExtension.ScalarFunctionVariant variant = extensions.getScalarFunction(anchor);
  return Expression.ScalarFunctionInvocation.builder()
      .declaration(variant)
      .options(optionsList)
      .outputType(returnType)
      .arguments(args)
      .build();
}

/**
* Creates a scalar function invocation with specified arguments.
*
Expand All @@ -1459,6 +1594,26 @@ public Expression.ScalarFunctionInvocation scalarFn(
.build();
}

/**
 * Creates a scalar function invocation from a list of arguments, without function options.
 *
 * <p>The function declaration is looked up in this builder's extension collection using the
 * anchor formed from {@code urn} and {@code key}. No function options are set on the resulting
 * invocation.
 *
 * @param urn the extension URN (e.g., {@link DefaultExtensionCatalog#FUNCTIONS_STRING})
 * @param key the function signature (e.g., "substring:str_i32_i32")
 * @param returnType the output type to declare on the invocation
 * @param args the function arguments
 * @return a scalar function invocation expression
 */
public Expression scalarFn(
    String urn, String key, Type returnType, List<? extends FunctionArg> args) {
  SimpleExtension.ScalarFunctionVariant declaration =
      extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(urn, key));
  return Expression.ScalarFunctionInvocation.builder()
      .declaration(declaration)
      .outputType(returnType)
      .arguments(args)
      .build();
}

/**
* Creates a window function invocation with specified arguments and window bounds.
*
Expand Down Expand Up @@ -1532,6 +1687,22 @@ public Plan plan(Plan.Root root) {
return Plan.builder().addRoots(root).build();
}

/**
 * Creates a {@link Plan.Root}, pairing a relation tree with its output column names.
 *
 * <p>The returned root is not itself a {@link Plan}; it is the element a plan is assembled from
 * (see {@link #plan(Plan.Root)}).
 *
 * @param input the root relational expression of the query
 * @param names the ordered list of output column names corresponding to the input relation's
 *     output schema
 * @return a new {@link Plan.Root} wrapping {@code input} and {@code names}
 */
public Plan.Root root(final Rel input, final List<String> names) {
  return Plan.Root.builder().input(input).names(names).build();
}

/**
* Creates a field remapping specification from field indexes.
*
Expand Down
Loading
Loading