Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io.substrait.expression.Expression.SwitchClause;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.WindowBound;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
Expand Down Expand Up @@ -1300,6 +1301,190 @@ public Aggregate.Measure sum0(Expression expr) {
R.I64);
}

/**
 * Builds a population standard deviation measure over one column of a relation.
 *
 * <p>Convenience overload: the column is resolved to a field reference on {@code input} and then
 * handed to {@link #stddevPopulation(Expression)}, which attaches the
 * {@code distribution=POPULATION} option (n denominator, comparable to SQL's STDDEV_POP).
 *
 * @param input the relation that supplies the column
 * @param field the zero-based index of the column within {@code input}
 * @return an aggregate measure for the population standard deviation of that column
 */
public Aggregate.Measure stddevPopulation(Rel input, int field) {
  final Aggregate.Measure populationStdDev =
      this.stddevPopulation(this.fieldReference(input, field));
  return populationStdDev;
}

/**
 * Builds a population standard deviation measure over an arbitrary expression.
 *
 * <p>Uses the population formula (n denominator), i.e. the values are treated as the complete
 * population. Comparable to SQL's STDDEV_POP.
 *
 * <p>The resulting measure has the following shape:
 *
 * <ul>
 *   <li>function {@code std_dev} looked up in the arithmetic extension
 *   <li>option {@code distribution=POPULATION}
 *   <li>output type equal to the nullable form of the type of {@code expr}
 *   <li>aggregation phase {@code INITIAL_TO_RESULT}
 *   <li>invocation {@code ALL}
 * </ul>
 *
 * @param expr the expression to aggregate, typically a numeric field reference
 * @return an aggregate measure for the population standard deviation of {@code expr}
 */
public Aggregate.Measure stddevPopulation(Expression expr) {
  final String function = "std_dev";
  final String distributionKind = "POPULATION";
  return statisticalAggregate(expr, function, distributionKind);
}

/**
 * Builds a sample standard deviation measure over one column of a relation.
 *
 * <p>Convenience overload: the column is resolved to a field reference on {@code input} and then
 * handed to {@link #stddevSample(Expression)}, which attaches the {@code distribution=SAMPLE}
 * option (n-1 denominator, comparable to SQL's STDDEV_SAMP or STDDEV).
 *
 * @param input the relation that supplies the column
 * @param field the zero-based index of the column within {@code input}
 * @return an aggregate measure for the sample standard deviation of that column
 */
public Aggregate.Measure stddevSample(Rel input, int field) {
  final Aggregate.Measure sampleStdDev = this.stddevSample(this.fieldReference(input, field));
  return sampleStdDev;
}

/**
 * Builds a sample standard deviation measure over an arbitrary expression.
 *
 * <p>Uses the sample formula (n-1 denominator, Bessel's correction). Comparable to SQL's
 * STDDEV_SAMP or STDDEV.
 *
 * <p>The resulting measure has the following shape:
 *
 * <ul>
 *   <li>function {@code std_dev} looked up in the arithmetic extension
 *   <li>option {@code distribution=SAMPLE}
 *   <li>output type equal to the nullable form of the type of {@code expr}
 *   <li>aggregation phase {@code INITIAL_TO_RESULT}
 *   <li>invocation {@code ALL}
 * </ul>
 *
 * @param expr the expression to aggregate, typically a numeric field reference
 * @return an aggregate measure for the sample standard deviation of {@code expr}
 */
public Aggregate.Measure stddevSample(Expression expr) {
  final String function = "std_dev";
  final String distributionKind = "SAMPLE";
  return statisticalAggregate(expr, function, distributionKind);
}

/**
 * Builds a population variance measure over one column of a relation.
 *
 * <p>Convenience overload: the column is resolved to a field reference on {@code input} and then
 * handed to {@link #variancePopulation(Expression)}, which attaches the
 * {@code distribution=POPULATION} option (n denominator, comparable to SQL's VAR_POP).
 *
 * @param input the relation that supplies the column
 * @param field the zero-based index of the column within {@code input}
 * @return an aggregate measure for the population variance of that column
 */
public Aggregate.Measure variancePopulation(Rel input, int field) {
  final Aggregate.Measure populationVariance =
      this.variancePopulation(this.fieldReference(input, field));
  return populationVariance;
}

/**
 * Builds a population variance measure over an arbitrary expression.
 *
 * <p>Uses the population formula (n denominator), i.e. the values are treated as the complete
 * population. Comparable to SQL's VAR_POP.
 *
 * <p>The resulting measure has the following shape:
 *
 * <ul>
 *   <li>function {@code variance} looked up in the arithmetic extension
 *   <li>option {@code distribution=POPULATION}
 *   <li>output type equal to the nullable form of the type of {@code expr}
 *   <li>aggregation phase {@code INITIAL_TO_RESULT}
 *   <li>invocation {@code ALL}
 * </ul>
 *
 * @param expr the expression to aggregate, typically a numeric field reference
 * @return an aggregate measure for the population variance of {@code expr}
 */
public Aggregate.Measure variancePopulation(Expression expr) {
  final String function = "variance";
  final String distributionKind = "POPULATION";
  return statisticalAggregate(expr, function, distributionKind);
}

/**
 * Builds a sample variance measure over one column of a relation.
 *
 * <p>Convenience overload: the column is resolved to a field reference on {@code input} and then
 * handed to {@link #varianceSample(Expression)}, which attaches the {@code distribution=SAMPLE}
 * option (n-1 denominator, comparable to SQL's VAR_SAMP or VARIANCE).
 *
 * @param input the relation that supplies the column
 * @param field the zero-based index of the column within {@code input}
 * @return an aggregate measure for the sample variance of that column
 */
public Aggregate.Measure varianceSample(Rel input, int field) {
  final Aggregate.Measure sampleVariance = this.varianceSample(this.fieldReference(input, field));
  return sampleVariance;
}

/**
 * Builds a sample variance measure over an arbitrary expression.
 *
 * <p>Uses the sample formula (n-1 denominator, Bessel's correction). Comparable to SQL's
 * VAR_SAMP or VARIANCE.
 *
 * <p>The resulting measure has the following shape:
 *
 * <ul>
 *   <li>function {@code variance} looked up in the arithmetic extension
 *   <li>option {@code distribution=SAMPLE}
 *   <li>output type equal to the nullable form of the type of {@code expr}
 *   <li>aggregation phase {@code INITIAL_TO_RESULT}
 *   <li>invocation {@code ALL}
 * </ul>
 *
 * @param expr the expression to aggregate, typically a numeric field reference
 * @return an aggregate measure for the sample variance of {@code expr}
 */
public Aggregate.Measure varianceSample(Expression expr) {
  final String function = "variance";
  final String distributionKind = "SAMPLE";
  return statisticalAggregate(expr, function, distributionKind);
}

/**
 * Shared implementation behind the std_dev / variance measure factories.
 *
 * <p>Looks up the aggregate function variant in the arithmetic extension under the key
 * {@code "<functionName>:<type string of expr>"}, attaches a single {@code distribution} option
 * carrying the requested value, and wraps the invocation in a measure. The output type is the
 * nullable form of the argument's type; the phase is {@code INITIAL_TO_RESULT} and the
 * invocation is {@code ALL}.
 *
 * @param expr the expression to aggregate
 * @param functionName the Substrait function name ("std_dev" or "variance")
 * @param distribution the value of the distribution option ("SAMPLE" or "POPULATION")
 * @return an aggregate measure invoking {@code functionName} with the given distribution option
 */
private Aggregate.Measure statisticalAggregate(
    Expression expr, String functionName, String distribution) {
  // Resolve the concrete function variant for this argument type first.
  final String functionKey =
      String.format("%s:%s", functionName, ToTypeString.apply(expr.getType()));
  final SimpleExtension.AggregateFunctionVariant variant =
      extensions.getAggregateFunction(
          SimpleExtension.FunctionAnchor.of(
              DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, functionKey));

  // SAMPLE vs. POPULATION is conveyed as a function option, not as a separate function.
  final FunctionOption distributionSetting =
      FunctionOption.builder().name("distribution").addValues(distribution).build();

  final AggregateFunctionInvocation call =
      AggregateFunctionInvocation.builder()
          .declaration(variant)
          .arguments(Arrays.asList(expr))
          .addOptions(distributionSetting)
          .outputType(TypeCreator.asNullable(expr.getType()))
          .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
          .invocation(Expression.AggregationInvocation.ALL)
          .build();
  return measure(call);
}

private Aggregate.Measure singleArgumentArithmeticAggregate(
Expression expr, String functionName, Type outputType) {
String typeString = ToTypeString.apply(expr.getType());
Expand Down
64 changes: 52 additions & 12 deletions isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,30 @@ public class AggregateFunctions {
/** Substrait-specific AVG aggregate function (nullable return type). */
public static SqlAggFunction AVG = new SubstraitAvgAggFunction(SqlKind.AVG);

/**
 * Standard deviation (population) aggregate function. Maps to Substrait's std_dev function with
 * distribution=POPULATION option.
 *
 * <p>Built from the same {@code SubstraitAvgAggFunction} class as {@link #AVG}, constructed with
 * {@code SqlKind.STDDEV_POP}. NOTE(review): presumably it shares AVG's nullable return type;
 * confirm in {@code SubstraitAvgAggFunction}.
 */
public static SqlAggFunction STDDEV_POP = new SubstraitAvgAggFunction(SqlKind.STDDEV_POP);

/**
 * Standard deviation (sample) aggregate function. Maps to Substrait's std_dev function with
 * distribution=SAMPLE option.
 *
 * <p>Built from the same {@code SubstraitAvgAggFunction} class as {@link #AVG}, constructed with
 * {@code SqlKind.STDDEV_SAMP}.
 */
public static SqlAggFunction STDDEV_SAMP = new SubstraitAvgAggFunction(SqlKind.STDDEV_SAMP);

/**
 * Variance (population) aggregate function. Maps to Substrait's variance function with
 * distribution=POPULATION option.
 *
 * <p>Built from the same {@code SubstraitAvgAggFunction} class as {@link #AVG}, constructed with
 * {@code SqlKind.VAR_POP}.
 */
public static SqlAggFunction VAR_POP = new SubstraitAvgAggFunction(SqlKind.VAR_POP);

/**
 * Variance (sample) aggregate function. Maps to Substrait's variance function with
 * distribution=SAMPLE option.
 *
 * <p>Built from the same {@code SubstraitAvgAggFunction} class as {@link #AVG}, constructed with
 * {@code SqlKind.VAR_SAMP}.
 */
public static SqlAggFunction VAR_SAMP = new SubstraitAvgAggFunction(SqlKind.VAR_SAMP);

/** Substrait-specific SUM aggregate function (nullable return type). */
public static SqlAggFunction SUM = new SubstraitSumAggFunction();

Expand All @@ -42,18 +66,34 @@ public class AggregateFunctions {
* @return optional containing Substrait equivalent if conversion applies
*/
public static Optional<SqlAggFunction> toSubstraitAggVariant(SqlAggFunction aggFunction) {
if (aggFunction instanceof SqlMinMaxAggFunction) {
SqlMinMaxAggFunction fun = (SqlMinMaxAggFunction) aggFunction;
return Optional.of(
fun.getKind() == SqlKind.MIN ? AggregateFunctions.MIN : AggregateFunctions.MAX);
} else if (aggFunction instanceof SqlAvgAggFunction) {
return Optional.of(AggregateFunctions.AVG);
} else if (aggFunction instanceof SqlSumAggFunction) {
return Optional.of(AggregateFunctions.SUM);
} else if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) {
return Optional.of(AggregateFunctions.SUM0);
} else {
return Optional.empty();
// First check by SqlKind to handle all statistical functions
SqlKind kind = aggFunction.getKind();
switch (kind) {
case MIN:
return Optional.of(AggregateFunctions.MIN);
case MAX:
return Optional.of(AggregateFunctions.MAX);
case AVG:
return Optional.of(AggregateFunctions.AVG);
case STDDEV_POP:
return Optional.of(AggregateFunctions.STDDEV_POP);
case STDDEV_SAMP:
return Optional.of(AggregateFunctions.STDDEV_SAMP);
case VAR_POP:
return Optional.of(AggregateFunctions.VAR_POP);
case VAR_SAMP:
return Optional.of(AggregateFunctions.VAR_SAMP);
case SUM:
case SUM0:
// Check instance type for SUM variants
if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) {
return Optional.of(AggregateFunctions.SUM0);
} else if (aggFunction instanceof SqlSumAggFunction) {
return Optional.of(AggregateFunctions.SUM);
}
return Optional.empty();
default:
return Optional.empty();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,9 @@ private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) {
.collect(java.util.stream.Collectors.toList());
Optional<SqlOperator> operator =
aggregateFunctionConverter.getSqlOperatorFromSubstraitFunc(
measure.getFunction().declaration().key(), measure.getFunction().outputType());
measure.getFunction().declaration().key(),
measure.getFunction().outputType(),
measure.getFunction().options());
if (!operator.isPresent()) {
throw new IllegalArgumentException(
String.format(
Expand Down
Loading
Loading