Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import io.substrait.function.ParameterizedTypeVisitor;
import io.substrait.function.TypeExpression;
import io.substrait.type.Type;
import io.substrait.type.TypeExpressionEvaluator;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
Expand All @@ -18,7 +17,9 @@
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.fun.SqlBasicAggFunction;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
Expand Down Expand Up @@ -139,6 +140,67 @@ private static SqlFunction toSqlFunction(
RelDataTypeFactory typeFactory,
TypeConverter typeConverter) {

if (function instanceof SimpleExtension.AggregateFunctionVariant) {
return toAggregateSqlFunction(function, typeFactory, typeConverter);
}

if (function instanceof SimpleExtension.WindowFunctionVariant) {
return toWindowSqlFunction(function, typeFactory, typeConverter);
}

return toScalarSqlFunction(function, typeFactory, typeConverter);
}

/**
 * Creates the Calcite aggregate operator for a Substrait aggregate function variant.
 *
 * <p>The operator is registered under the Substrait function name with kind {@code
 * OTHER_FUNCTION}; its return type is resolved by {@link AggregateReturnTypeInference} and its
 * operands are validated by the checker built from the function's arguments.
 *
 * @param function the Substrait function definition
 * @param typeFactory the Calcite type factory handed to the return type inference
 * @param typeConverter the converter handed to the return type inference
 * @return the aggregate operator produced by {@code SqlBasicAggFunction.create}
 */
private static SqlFunction toAggregateSqlFunction(
    SimpleExtension.Function function,
    RelDataTypeFactory typeFactory,
    TypeConverter typeConverter) {
  final SqlReturnTypeInference aggregateReturnType =
      new AggregateReturnTypeInference(function, typeFactory, typeConverter);
  final String sqlName = function.name();
  final SqlOperandTypeChecker operandChecker = createOperandTypeChecker(function);

  return SqlBasicAggFunction.create(
      sqlName, SqlKind.OTHER_FUNCTION, aggregateReturnType, operandChecker);
}

/**
 * Creates the Calcite operator for a Substrait window function variant.
 *
 * <p>The result is a plain {@link SqlFunction} in the {@code USER_DEFINED_FUNCTION} category
 * whose return type is resolved by {@link WindowReturnTypeInference}.
 *
 * @param function the Substrait function definition
 * @param typeFactory the Calcite type factory handed to the return type inference
 * @param typeConverter the converter handed to the return type inference
 * @return the newly constructed operator
 */
private static SqlFunction toWindowSqlFunction(
    SimpleExtension.Function function,
    RelDataTypeFactory typeFactory,
    TypeConverter typeConverter) {
  final SqlReturnTypeInference windowReturnType =
      new WindowReturnTypeInference(function, typeFactory, typeConverter);
  final String sqlName = function.name();
  final SqlOperandTypeChecker operandChecker = createOperandTypeChecker(function);

  // The fourth constructor argument is deliberately null, exactly as before.
  return new SqlFunction(
      sqlName,
      SqlKind.OTHER_FUNCTION,
      windowReturnType,
      null,
      operandChecker,
      SqlFunctionCategory.USER_DEFINED_FUNCTION);
}

/**
 * Creates the Calcite operator for a Substrait scalar function variant.
 *
 * <p>The result is a {@link SqlFunction} in the {@code USER_DEFINED_FUNCTION} category whose
 * return type is resolved by {@link ScalarReturnTypeInference}.
 *
 * @param function the Substrait function definition
 * @param typeFactory the Calcite type factory handed to the return type inference
 * @param typeConverter the converter handed to the return type inference
 * @return the newly constructed operator
 */
private static SqlFunction toScalarSqlFunction(
    SimpleExtension.Function function,
    RelDataTypeFactory typeFactory,
    TypeConverter typeConverter) {
  final SqlReturnTypeInference scalarReturnType =
      new ScalarReturnTypeInference(function, typeFactory, typeConverter);
  final String sqlName = function.name();
  final SqlOperandTypeChecker operandChecker = createOperandTypeChecker(function);

  // The fourth constructor argument is deliberately null, exactly as before.
  return new SqlFunction(
      sqlName,
      SqlKind.OTHER_FUNCTION,
      scalarReturnType,
      null,
      operandChecker,
      SqlFunctionCategory.USER_DEFINED_FUNCTION);
}

private static SqlOperandTypeChecker createOperandTypeChecker(SimpleExtension.Function function) {
List<SqlTypeFamily> argFamilies = new ArrayList<>();

for (SimpleExtension.Argument arg : function.requiredArguments()) {
Expand All @@ -152,25 +214,16 @@ private static SqlFunction toSqlFunction(
}
}

SqlReturnTypeInference returnTypeInference =
new SubstraitReturnTypeInference(function, typeFactory, typeConverter);

return new SqlFunction(
function.name(),
SqlKind.OTHER_FUNCTION,
returnTypeInference,
null,
OperandTypes.family(argFamilies),
SqlFunctionCategory.USER_DEFINED_FUNCTION);
return OperandTypes.family(argFamilies);
}

private static class SubstraitReturnTypeInference implements SqlReturnTypeInference {
/** Base class for return type inference with common logic for handling concrete types. */
private abstract static class BaseReturnTypeInference implements SqlReturnTypeInference {
protected final SimpleExtension.Function function;
protected final RelDataTypeFactory typeFactory;
protected final TypeConverter typeConverter;

private final SimpleExtension.Function function;
private final RelDataTypeFactory typeFactory;
private final TypeConverter typeConverter;

private SubstraitReturnTypeInference(
protected BaseReturnTypeInference(
SimpleExtension.Function function,
RelDataTypeFactory typeFactory,
TypeConverter typeConverter) {
Expand All @@ -181,34 +234,177 @@ private SubstraitReturnTypeInference(

/**
 * Infers the Calcite return type for a call to the wrapped Substrait function.
 *
 * <p>When the declared return type is a concrete {@link Type} it is converted directly, with
 * nullability decided by the function's nullability rule. Any other type expression is handed to
 * {@link #inferParameterizedReturnType(SqlOperatorBinding)}.
 *
 * @param opBinding the operator binding with operand information
 * @return the inferred Calcite return type
 */
@Override
public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
  TypeExpression returnExpression = function.returnType();

  // Concrete return type: no need to look at the operand types to know the result type.
  if (returnExpression instanceof Type) {
    return inferConcreteReturnType((Type) returnExpression, opBinding);
  }

  // Parameterized return type: the function-kind-specific subclass decides.
  return inferParameterizedReturnType(opBinding);
}

/**
 * Converts a concrete (non-parameterized) Substrait return type into a Calcite type.
 *
 * <p>Nullability is decided first via {@code determineNullability}; the Substrait type is then
 * converted and the decided nullability is applied to the converted type.
 *
 * @param resolvedSubstraitType the concrete Substrait type
 * @param opBinding the operator binding with operand information
 * @return the Calcite type carrying the decided nullability
 */
private RelDataType inferConcreteReturnType(
    Type resolvedSubstraitType, SqlOperatorBinding opBinding) {
  final boolean nullable = determineNullability(resolvedSubstraitType, opBinding);
  final RelDataType converted = typeConverter.toCalcite(typeFactory, resolvedSubstraitType);

  return typeFactory.createTypeWithNullability(converted, nullable);
}

/**
 * Determines the nullability of the return type based on the function's nullability rule.
 *
 * <ul>
 *   <li>{@code MIRROR}: nullable if any operand type is nullable.
 *   <li>{@code DISCRETE}, {@code DECLARED_OUTPUT} and any other value: the nullability declared
 *       on the resolved Substrait type.
 * </ul>
 *
 * @param resolvedSubstraitType the resolved Substrait type
 * @param opBinding the operator binding with operand information
 * @return true if the return type should be nullable
 */
private boolean determineNullability(Type resolvedSubstraitType, SqlOperatorBinding opBinding) {
  switch (function.nullability()) {
    case MIRROR:
      // If any input is nullable, the output is nullable.
      return opBinding.collectOperandTypes().stream().anyMatch(RelDataType::isNullable);
    case DISCRETE:
    case DECLARED_OUTPUT:
    default:
      // Use the nullability declared on the resolved Substrait type.
      return resolvedSubstraitType.nullable();
  }
}

RelDataType baseCalciteType = typeConverter.toCalcite(typeFactory, resolvedSubstraitType);
/**
 * Infers the return type for parameterized type expressions (e.g., any1, T), i.e. when the
 * function's declared return type is not a concrete {@link Type}.
 *
 * <p>{@code inferReturnType} delegates here after its {@code instanceof Type} check fails. Each
 * function kind implements its own logic.
 *
 * @param opBinding the operator binding with operand information
 * @return the inferred Calcite return type
 */
protected abstract RelDataType inferParameterizedReturnType(SqlOperatorBinding opBinding);
}

return typeFactory.createTypeWithNullability(baseCalciteType, finalIsNullable);
/**
 * Return type inference for scalar functions.
 *
 * <p>For parameterized return types the type of the first operand is used as the result type,
 * adjusted by the function's nullability rule.
 */
private static final class ScalarReturnTypeInference extends BaseReturnTypeInference {
  private ScalarReturnTypeInference(
      SimpleExtension.Function function,
      RelDataTypeFactory typeFactory,
      TypeConverter typeConverter) {
    super(function, typeFactory, typeConverter);
  }

  /**
   * Infers a parameterized return type from the first operand.
   *
   * @param opBinding the operator binding with operand information
   * @return the first operand's type with the nullability rule applied
   * @throws IllegalStateException if the call has no operands to infer from
   */
  @Override
  protected RelDataType inferParameterizedReturnType(SqlOperatorBinding opBinding) {
    List<RelDataType> operandTypes = opBinding.collectOperandTypes();
    if (operandTypes.isEmpty()) {
      throw new IllegalStateException(
          String.format(
              "Scalar function '%s' has parameterized return type but no arguments to infer from",
              function.name()));
    }

    return applyNullabilityRules(operandTypes.get(0), operandTypes);
  }

  /**
   * Applies the function's nullability rule to the candidate result type.
   *
   * @param baseType the candidate result type (the first operand's type)
   * @param operandTypes the types of all operands of the call
   * @return {@code baseType} with the appropriate nullability
   */
  private RelDataType applyNullabilityRules(
      RelDataType baseType, List<RelDataType> operandTypes) {
    SimpleExtension.Nullability nullability = function.nullability();

    if (nullability == SimpleExtension.Nullability.DECLARED_OUTPUT) {
      return typeFactory.createTypeWithNullability(baseType, true);
    }

    if (nullability == SimpleExtension.Nullability.MIRROR) {
      // MIRROR means "nullable if ANY input is nullable" (the same rule the base class applies
      // to concrete return types), so a nullable operand other than the first one must also
      // make the result nullable.
      boolean anyOperandNullable = operandTypes.stream().anyMatch(RelDataType::isNullable);
      if (anyOperandNullable && !baseType.isNullable()) {
        return typeFactory.createTypeWithNullability(baseType, true);
      }
      return baseType;
    }

    // Other cases: keep the first operand's nullability.
    return baseType;
  }
}

/**
 * Return type inference for aggregate functions.
 *
 * <p>For parameterized return types the result is derived from the first operand; a call without
 * operands falls back to a nullable BIGINT.
 */
private static final class AggregateReturnTypeInference extends BaseReturnTypeInference {
  private AggregateReturnTypeInference(
      SimpleExtension.Function function,
      RelDataTypeFactory typeFactory,
      TypeConverter typeConverter) {
    super(function, typeFactory, typeConverter);
  }

  /**
   * Infers a parameterized return type for an aggregate call.
   *
   * @param opBinding the operator binding with operand information
   * @return nullable BIGINT when there are no operands; otherwise the first operand's type,
   *     unchanged for {@code MIRROR} functions and forced nullable for every other rule
   */
  @Override
  protected RelDataType inferParameterizedReturnType(SqlOperatorBinding opBinding) {
    final List<RelDataType> argTypes = opBinding.collectOperandTypes();

    // Fallback for aggregates without arguments (e.g., COUNT(*)).
    if (argTypes.isEmpty()) {
      return asNullable(typeFactory.createSqlType(SqlTypeName.BIGINT));
    }

    final RelDataType firstArg = argTypes.get(0);
    final boolean mirrorsInput = function.nullability() == SimpleExtension.Nullability.MIRROR;

    // MIRROR keeps the operand's own nullability; DECLARED_OUTPUT and the rest become nullable.
    return mirrorsInput ? firstArg : asNullable(firstArg);
  }

  /** Returns the nullable variant of the given type. */
  private RelDataType asNullable(RelDataType type) {
    return typeFactory.createTypeWithNullability(type, true);
  }
}

/**
 * Return type inference for window functions.
 *
 * <p>For parameterized return types, functions recognized by name as ranking functions yield
 * BIGINT; all others derive the result from the first operand, falling back to a nullable BIGINT
 * when the call has no operands.
 */
private static final class WindowReturnTypeInference extends BaseReturnTypeInference {
  private WindowReturnTypeInference(
      SimpleExtension.Function function,
      RelDataTypeFactory typeFactory,
      TypeConverter typeConverter) {
    super(function, typeFactory, typeConverter);
  }

  /**
   * Infers a parameterized return type for a window function call.
   *
   * @param opBinding the operator binding with operand information
   * @return BIGINT for ranking functions; nullable BIGINT when there are no operands; otherwise
   *     the first operand's type with the window nullability rule applied
   */
  @Override
  protected RelDataType inferParameterizedReturnType(SqlOperatorBinding opBinding) {
    if (isRankingFunction()) {
      return typeFactory.createSqlType(SqlTypeName.BIGINT);
    }

    List<RelDataType> operandTypes = opBinding.collectOperandTypes();
    if (operandTypes.isEmpty()) {
      // Fallback for window functions without arguments
      return createNullableBigInt();
    }

    RelDataType firstArgType = operandTypes.get(0);
    return applyWindowNullabilityRules(firstArgType);
  }

  /**
   * Tells whether the function name identifies a ranking function: any name containing "rank",
   * or exactly "row_number" or "ntile", compared case-insensitively.
   */
  private boolean isRankingFunction() {
    // Locale.ROOT keeps the match independent of the JVM default locale; with a Turkish default
    // locale "NTILE".toLowerCase() yields a dotless i and would never equal "ntile".
    String funcName = function.name().toLowerCase(Locale.ROOT);
    return funcName.contains("rank") || "row_number".equals(funcName) || "ntile".equals(funcName);
  }

  /**
   * Applies the function's nullability rule to the candidate result type.
   *
   * @param baseType the candidate result type (the first operand's type)
   * @return {@code baseType} unchanged for {@code MIRROR}; its nullable variant otherwise
   */
  private RelDataType applyWindowNullabilityRules(RelDataType baseType) {
    if (function.nullability() == SimpleExtension.Nullability.MIRROR) {
      return baseType; // Keep original nullability
    }
    // DECLARED_OUTPUT and other cases: always nullable
    return typeFactory.createTypeWithNullability(baseType, true);
  }

  /** Creates a nullable BIGINT type. */
  private RelDataType createNullableBigInt() {
    return typeFactory.createTypeWithNullability(
        typeFactory.createSqlType(SqlTypeName.BIGINT), true);
  }
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package io.substrait.isthmus;

import org.junit.jupiter.api.Test;

/**
 * Runs {@code assertFullRoundTrip} on queries using {@code any_value} and several window
 * functions, with an {@link AutomaticDynamicFunctionMappingConverterProvider} as the converter
 * provider.
 */
class AnyValueFunctionTest extends PlanTestBase {

  AnyValueFunctionTest() {
    super(new AutomaticDynamicFunctionMappingConverterProvider());
  }

  /** Aggregate: {@code any_value} over a single column. */
  @Test
  void simpleAnyValue() throws Exception {
    assertFullRoundTrip("SELECT any_value(l_orderkey) FROM lineitem");
  }

  /** Window: {@code ROW_NUMBER()}, which takes no arguments. */
  @Test
  void windowFunctionRowNumber() throws Exception {
    assertFullRoundTrip(
        "SELECT l_orderkey, ROW_NUMBER() OVER (PARTITION BY l_suppkey ORDER BY l_orderkey) as rn FROM lineitem");
  }

  /** Window: {@code LAG} over a single column. */
  @Test
  void windowFunctionLag() throws Exception {
    assertFullRoundTrip(
        "SELECT l_orderkey, LAG(l_quantity) OVER (PARTITION BY l_suppkey ORDER BY l_orderkey) as prev_qty FROM lineitem");
  }

  /** Window: {@code FIRST_VALUE} over a single column. */
  @Test
  void windowFunctionFirstValue() throws Exception {
    assertFullRoundTrip(
        "SELECT l_orderkey, FIRST_VALUE(l_quantity) OVER (PARTITION BY l_suppkey ORDER BY l_orderkey) as first_qty FROM lineitem");
  }
}
4 changes: 3 additions & 1 deletion isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,9 @@ protected void assertFullRoundTrip(String sqlQuery, Prepare.CatalogReader catalo
ExtensionCollector extensionCollector = new ExtensionCollector();

// SQL -> Calcite 1
RelRoot calcite1 = SubstraitSqlToCalcite.convertQuery(sqlQuery, catalogReader);
RelRoot calcite1 =
SubstraitSqlToCalcite.convertQuery(
sqlQuery, catalogReader, converterProvider.getSqlOperatorTable());

// Calcite 1 -> Substrait POJO 1
Plan.Root root1 = SubstraitRelVisitor.convert(calcite1, converterProvider);
Expand Down
Loading