Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ interface Literal extends Expression {
/** Default nullability of a literal: {@code false} unless an implementation overrides it. */
default boolean nullable() {
return false;
}

/**
 * Returns a copy of this literal with the specified nullability.
 *
 * <p>This method is implemented by all concrete Literal classes via Immutables code generation.
 *
 * @param nullable the nullability the returned literal should carry
 * @return a copy of this literal whose nullability is {@code nullable}
 */
Literal withNullable(boolean nullable);
}

interface Nested extends Expression {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public SubstraitRelNodeConverter(
this.expressionRexConverter.setRelNodeConverter(this);
}

private static ScalarFunctionConverter createScalarFunctionConverter(
static ScalarFunctionConverter createScalarFunctionConverter(
SimpleExtension.ExtensionCollection extensions,
RelDataTypeFactory typeFactory,
boolean allowDynamicUdfs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.SubstraitRelNodeConverter.Context;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.plan.Plan;
import io.substrait.relation.Rel;
import io.substrait.util.EmptyVisitationContext;
Expand Down Expand Up @@ -104,7 +107,20 @@ protected RelBuilder createRelBuilder(CalciteSchema schema) {
* <p>Override this method to customize the {@link SubstraitRelNodeConverter}.
*/
protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) {
  // Build every function converter explicitly and hand them to the node converter together
  // with this instance's typeConverter, so the converter is wired from these collaborators.
  // The scalar converter is the only one that takes the dynamic-UDF feature flag.
  ScalarFunctionConverter scalarFunctionConverter =
      SubstraitRelNodeConverter.createScalarFunctionConverter(
          extensions, typeFactory, featureBoard.allowDynamicUdfs());
  AggregateFunctionConverter aggregateFunctionConverter =
      new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory);
  WindowFunctionConverter windowFunctionConverter =
      new WindowFunctionConverter(extensions.windowFunctions(), typeFactory);

  return new SubstraitRelNodeConverter(
      typeFactory,
      relBuilder,
      scalarFunctionConverter,
      aggregateFunctionConverter,
      windowFunctionConverter,
      typeConverter);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,18 @@ public class CallConverters {
* {@link SqlKind#REINTERPRET} is utilized by Isthmus to represent and store {@link
* Expression.UserDefinedLiteral}s within Calcite.
*
* <p>When converting from Substrait to Calcite, the {@link
* Expression.UserDefinedAnyLiteral#value()} is stored within a {@link
* org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link org.apache.calcite.rex.RexLiteral} and
* then re-interpreted to have the correct type.
* <p>When converting from Substrait to Calcite, the user-defined literal value is stored either
* as a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link
* org.apache.calcite.rex.RexLiteral} (for ANY-encoded values) or a {@link SqlKind#ROW} (for
* struct-encoded values) and then re-interpreted to have the correct user-defined type.
*
* <p>See {@link ExpressionRexConverter#visit(Expression.UserDefinedAnyLiteral,
* SubstraitRelNodeConverter.Context)} and {@link
* ExpressionRexConverter#visit(Expression.UserDefinedStructLiteral,
* SubstraitRelNodeConverter.Context)} for this conversion.
*
* <p>When converting from Calcite to Substrait, this call converter extracts the {@link
* Expression.UserDefinedAnyLiteral} that was stored.
* <p>When converting from Calcite to Substrait, this call converter extracts the stored {@link
* Expression.UserDefinedLiteral}.
*/
public static Function<TypeConverter, SimpleCallConverter> REINTERPRET =
typeConverter ->
Expand Down Expand Up @@ -86,8 +88,24 @@ public class CallConverters {
} catch (com.google.protobuf.InvalidProtocolBufferException e) {
throw new IllegalStateException("Failed to parse UserDefinedAnyLiteral value", e);
}
} else if (operand instanceof Expression.StructLiteral
&& type instanceof Type.UserDefined) {
Expression.StructLiteral structLiteral = (Expression.StructLiteral) operand;
Type.UserDefined t = (Type.UserDefined) type;

return Expression.UserDefinedStructLiteral.builder()
.nullable(t.nullable())
.urn(t.urn())
.name(t.name())
.addAllTypeParameters(t.typeParameters())
.addAllFields(structLiteral.fields())
.build();
}
return null;
throw new IllegalStateException(
"Unexpected REINTERPRET operand type: "
+ operand.getClass().getSimpleName()
+ " with target type: "
+ type.getClass().getSimpleName());
};

// public static SimpleCallConverter OrAnd(FunctionConverter c) {
Expand All @@ -100,6 +118,51 @@ public class CallConverters {
// return null;
// };
// }
/**
 * Call converter that maps a Calcite {@code ROW(...)} constructor onto a Substrait struct
 * literal.
 *
 * <p>The produced struct literal is always created with {@code nullable=false}: a ROW
 * constructor denotes a concrete value, and an actually-null value is expected to be expressed as
 * a NullLiteral rather than a StructLiteral. For UDT struct encodings the struct-level
 * nullability lives on the REINTERPRET target type instead.
 *
 * <p>Per-field nullability is taken from the field types of the ROW's Calcite type. A converted
 * operand that is non-nullable while its declared field type is nullable is replaced by a
 * nullable copy so that the literal agrees with the declared type.
 *
 * <p>Returns {@code null} for calls that are not of kind {@link SqlKind#ROW}, and throws {@link
 * IllegalArgumentException} when any converted operand is not a literal.
 */
public static SimpleCallConverter ROW =
    (call, visitor) -> {
      if (call.getKind() != SqlKind.ROW) {
        // Not a ROW constructor: nothing for this converter to do.
        return null;
      }

      // Convert every operand up front, then validate the whole batch.
      List<Expression> converted =
          call.getOperands().stream().map(visitor).collect(java.util.stream.Collectors.toList());
      for (Expression candidate : converted) {
        if (!(candidate instanceof Expression.Literal)) {
          throw new IllegalArgumentException("ROW operands must be literals.");
        }
      }

      java.util.List<org.apache.calcite.rel.type.RelDataTypeField> rowFields =
          call.getType().getFieldList();
      List<Expression.Literal> structFields = new java.util.ArrayList<>();

      int position = 0;
      for (Expression candidate : converted) {
        Expression.Literal field = (Expression.Literal) candidate;
        boolean declaredNullable = rowFields.get(position).getType().isNullable();
        position++;

        // Widen the literal when the declared field type is nullable but the literal is not.
        boolean needsWidening = declaredNullable && !field.nullable();
        structFields.add(needsWidening ? field.withNullable(true) : field);
      }

      // The struct itself is a concrete value, hence nullable=false.
      return ExpressionCreator.struct(false, structFields);
    };

/** */
public static SimpleCallConverter CASE =
(call, visitor) -> {
Expand Down Expand Up @@ -150,6 +213,7 @@ public static List<CallConverter> defaults(TypeConverter typeConverter) {
return ImmutableList.of(
new FieldSelectionConverter(typeConverter),
CallConverters.CASE,
CallConverters.ROW,
CallConverters.CAST.apply(typeConverter),
CallConverters.REINTERPRET.apply(typeConverter),
new SqlArrayValueConstructorCallConverter(typeConverter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,12 @@ public RexNode visit(Expression.UserDefinedAnyLiteral expr, Context context)
@Override
public RexNode visit(Expression.UserDefinedStructLiteral expr, Context context)
    throws RuntimeException {
  // The struct is only the encoding/representation of the UDT value, so it is emitted as a ROW
  // call and wrapped in a REINTERPRET whose target type is the user-defined type:
  // REINTERPRET(ROW(...), udt{nullable=true/false}). UDT nullability is carried by that target
  // type, not by the ROW.
  RelDataType type = typeConverter.toCalcite(typeFactory, expr.getType());
  RexNode structValue = toStructEncoding(expr.fields(), context);

  // The literal `false` is the third argument of makeReinterpretCast (its overflow-check flag).
  return rexBuilder.makeReinterpretCast(type, structValue, rexBuilder.makeLiteral(false));
}

@Override
Expand Down Expand Up @@ -320,6 +324,14 @@ public RexNode visit(Expression.DecimalLiteral expr, Context context) throws Run
return rexBuilder.makeLiteral(decimal, typeConverter.toCalcite(typeFactory, expr.getType()));
}

/**
 * Converts a Substrait struct literal into a Calcite {@code ROW} call whose type is the Calcite
 * equivalent of the struct's own type.
 */
@Override
public RexNode visit(Expression.StructLiteral expr, Context context) throws RuntimeException {
  // Convert each field with this same visitor, preserving field order.
  List<RexNode> operands = new java.util.ArrayList<>();
  for (Expression.Literal field : expr.fields()) {
    operands.add(field.accept(this, context));
  }

  RelDataType rowType = typeConverter.toCalcite(typeFactory, expr.getType());
  return rexBuilder.makeCall(rowType, SqlStdOperatorTable.ROW, operands);
}

@Override
public RexNode visit(Expression.ListLiteral expr, Context context) throws RuntimeException {
List<RexNode> args =
Expand Down Expand Up @@ -723,4 +735,35 @@ public RexNode visit(SetPredicate expr, Context context) throws RuntimeException
"Cannot handle SetPredicate when PredicateOp is %s.", expr.predicateOp().name()));
}
}

/**
 * Builds the Calcite {@code ROW} call used as the encoding of a UDT struct literal.
 *
 * <p>Intended for {@link Expression.UserDefinedStructLiteral}, where the struct merely
 * represents the UDT value. The row type is assembled here from the individual field types; the
 * UDT's own nullability is carried by the REINTERPRET target type rather than by this ROW.
 *
 * <p>Regular {@link Expression.StructLiteral}s should use the struct's own type via {@code
 * expr.getType()} instead of this helper.
 *
 * @param fields the literal field values, in positional order
 * @param context the visitation context passed on to each field conversion
 * @return a {@code ROW} call over the converted fields
 */
private RexNode toStructEncoding(List<? extends Expression.Literal> fields, Context context) {
  // First pass: convert every field value to its Calcite expression.
  List<RexNode> operands = new java.util.ArrayList<>();
  for (Expression.Literal field : fields) {
    operands.add(field.accept(this, context));
  }

  // Second pass: assemble the row type. Calcite requires field names, so positional
  // placeholders ("field0", "field1", ...) are used; Substrait struct literals are
  // position-based, so these names are dropped again when converting back.
  RelDataTypeFactory.Builder rowTypeBuilder = typeFactory.builder();
  for (int index = 0; index < fields.size(); index++) {
    RelDataType fieldType = typeConverter.toCalcite(typeFactory, fields.get(index).getType());
    rowTypeBuilder.add("field" + index, fieldType);
  }
  RelDataType rowType = rowTypeBuilder.build();

  return rexBuilder.makeCall(rowType, SqlStdOperatorTable.ROW, operands);
}
}
72 changes: 72 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.apache.calcite.rex.RexLiteral;
Expand Down Expand Up @@ -388,6 +389,77 @@ void tStruct() {
false));
}

@Test
void tStructRoundtripNullableFields() {
  // Non-nullable struct whose two i32 fields are both nullable.
  Expression.StructLiteral expected =
      ExpressionCreator.struct(
          false, ExpressionCreator.i32(true, 4), ExpressionCreator.i32(true, -1));

  // Substrait -> Calcite -> Substrait must give back an equal literal.
  RexNode calciteNode = expected.accept(expressionRexConverter, Context.newContext());
  Expression actual = calciteNode.accept(rexExpressionConverter);

  assertEquals(expected, actual);
}

@Test
void tStructRoundtripMixedFieldNullability() {
  // Non-nullable struct with one nullable and one non-nullable i32 field.
  Expression.StructLiteral expected =
      ExpressionCreator.struct(
          false, ExpressionCreator.i32(true, 4), ExpressionCreator.i32(false, -1));

  // Substrait -> Calcite -> Substrait must give back an equal literal.
  RexNode calciteNode = expected.accept(expressionRexConverter, Context.newContext());
  Expression actual = calciteNode.accept(rexExpressionConverter);

  assertEquals(expected, actual);
}

@Test
void tStructRoundtripWithNullFieldValues() {
  // A typed NULL (nullable i32) placed next to a concrete, non-nullable i32 value.
  Expression.NullLiteral typedNull =
      Expression.NullLiteral.builder()
          .nullable(true)
          .type(io.substrait.type.Type.I32.builder().nullable(true).build())
          .build();
  Expression.StructLiteral expected =
      ExpressionCreator.struct(false, typedNull, ExpressionCreator.i32(false, 100));

  // Substrait -> Calcite -> Substrait must give back an equal literal.
  RexNode calciteNode = expected.accept(expressionRexConverter, Context.newContext());
  Expression actual = calciteNode.accept(rexExpressionConverter);

  assertEquals(expected, actual);
}

@Test
void tStructRoundtripNested() {
  // A struct whose first field is itself a struct of two non-nullable i32 values.
  Expression.StructLiteral nested =
      ExpressionCreator.struct(
          false, ExpressionCreator.i32(false, 1), ExpressionCreator.i32(false, 2));
  Expression.StructLiteral expected =
      ExpressionCreator.struct(false, nested, ExpressionCreator.i32(false, 3));

  // Substrait -> Calcite -> Substrait must give back an equal literal.
  RexNode calciteNode = expected.accept(expressionRexConverter, Context.newContext());
  Expression actual = calciteNode.accept(rexExpressionConverter);

  assertEquals(expected, actual);
}

@Test
void tStructRoundtripEmpty() {
  // A struct literal with no fields at all.
  Expression.StructLiteral expected = ExpressionCreator.struct(false, Collections.emptyList());

  // Substrait -> Calcite -> Substrait must give back an equal literal.
  RexNode calciteNode = expected.accept(expressionRexConverter, Context.newContext());
  Expression actual = calciteNode.accept(rexExpressionConverter);

  assertEquals(expected, actual);
}

@Test
void tFixedBinary() {
byte[] val = "my test".getBytes(StandardCharsets.UTF_8);
Expand Down
Loading