Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions core/src/main/java/io/substrait/expression/EnumArg.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@
*/
@Value.Immutable
public interface EnumArg extends FunctionArg {
/** Constant representing an unspecified enum argument with no value. */
EnumArg UNSPECIFIED_ENUM_ARG = builder().value(Optional.empty()).build();

/**
* Returns the enum option value.
*
* @return the option value, if present
*/
Optional<String> value();

@Override
Expand All @@ -25,6 +31,15 @@ default <R, C extends VisitationContext, E extends Throwable> R accept(
return fnArgVisitor.visitEnumArg(fnDef, argIdx, this, context);
}

/**
* Creates an EnumArg with the specified option value, validating it against the enum argument
* definition.
*
* @param enumArg the enum argument definition
* @param option the option value to use
* @return a new EnumArg instance
* @throws IllegalArgumentException if the option is not valid for the enum argument
*/
static EnumArg of(SimpleExtension.EnumArgument enumArg, String option) {
if (!enumArg.options().contains(option)) {
throw new IllegalArgumentException(
Expand All @@ -33,10 +48,21 @@ static EnumArg of(SimpleExtension.EnumArgument enumArg, String option) {
return builder().value(Optional.of(option)).build();
}

/**
* Creates an EnumArg with the specified value without validation.
*
* @param value the enum value
* @return a new EnumArg instance
*/
static EnumArg of(String value) {
return builder().value(Optional.of(value)).build();
}

/**
* Creates a new builder for constructing an EnumArg.
*
* @return a new builder instance
*/
static ImmutableEnumArg.Builder builder() {
return ImmutableEnumArg.builder();
}
Expand Down
13 changes: 13 additions & 0 deletions core/src/main/java/io/substrait/expression/FunctionArg.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@
*/
public interface FunctionArg {

/**
* Accepts a visitor for this function argument.
*
* @param <R> the return type
* @param <C> the visitation context type
* @param <E> the exception type that may be thrown
* @param fnDef the function definition
* @param argIdx the argument index
* @param fnArgVisitor the visitor
* @param context the visitation context
* @return the result of the visit
* @throws E if the visit fails
*/
<R, C extends VisitationContext, E extends Throwable> R accept(
SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor<R, C, E> fnArgVisitor, C context)
throws E;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,23 @@

import java.util.Map;

/**
* Abstract base class for {@link ExtensionLookup} implementations that use maps to resolve
* extension references to their corresponding function and type anchors.
*/
public abstract class AbstractExtensionLookup implements ExtensionLookup {
/** Map of function reference IDs to their corresponding function anchors. */
protected final Map<Integer, SimpleExtension.FunctionAnchor> functionAnchorMap;

/** Map of type reference IDs to their corresponding type anchors. */
protected final Map<Integer, SimpleExtension.TypeAnchor> typeAnchorMap;

/**
* Constructs an AbstractExtensionLookup with the provided anchor maps.
*
* @param functionAnchorMap map of function reference IDs to function anchors
* @param typeAnchorMap map of type reference IDs to type anchors
*/
public AbstractExtensionLookup(
Map<Integer, SimpleExtension.FunctionAnchor> functionAnchorMap,
Map<Integer, SimpleExtension.TypeAnchor> typeAnchorMap) {
Expand Down
28 changes: 28 additions & 0 deletions core/src/main/java/io/substrait/extension/ExtensionLookup.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,42 @@
* functions or types.
*/
public interface ExtensionLookup {
/**
* Resolves a scalar function reference to its corresponding function variant.
*
* @param reference the function reference ID
* @param extensions the extension collection to search
* @return the scalar function variant
*/
SimpleExtension.ScalarFunctionVariant getScalarFunction(
int reference, SimpleExtension.ExtensionCollection extensions);

/**
* Resolves a window function reference to its corresponding function variant.
*
* @param reference the function reference ID
* @param extensions the extension collection to search
* @return the window function variant
*/
SimpleExtension.WindowFunctionVariant getWindowFunction(
int reference, SimpleExtension.ExtensionCollection extensions);

/**
* Resolves an aggregate function reference to its corresponding function variant.
*
* @param reference the function reference ID
* @param extensions the extension collection to search
* @return the aggregate function variant
*/
SimpleExtension.AggregateFunctionVariant getAggregateFunction(
int reference, SimpleExtension.ExtensionCollection extensions);

/**
* Resolves a type reference to its corresponding type.
*
* @param reference the type reference ID
* @param extensions the extension collection to search
* @return the type
*/
SimpleExtension.Type getType(int reference, SimpleExtension.ExtensionCollection extensions);
}
20 changes: 20 additions & 0 deletions core/src/main/java/io/substrait/relation/AbstractReadRel.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,32 @@
import io.substrait.type.Type;
import java.util.Optional;

/**
* Abstract base class for read relations that scan data from various sources. Provides common
* functionality for schema definition and filtering.
*/
public abstract class AbstractReadRel extends ZeroInputRel implements HasExtension {

/**
* Returns the initial schema of the data being read.
*
* @return the named struct defining the schema
*/
public abstract NamedStruct getInitialSchema();

/**
* Returns an optional filter expression that must be applied during the read.
*
* @return the filter expression, if present
*/
public abstract Optional<Expression> getFilter();

/**
* Returns an optional best-effort filter to apply during the read. If the source doesn't support
* all operations, this filter may not be applied.
*
* @return the best-effort filter expression, if present
*/
public abstract Optional<Expression> getBestEffortFilter();

// TODO:
Expand Down
9 changes: 9 additions & 0 deletions core/src/main/java/io/substrait/relation/AbstractRel.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import io.substrait.util.Util;
import java.util.function.Supplier;

/**
* Abstract base class for relations that provides common functionality for deriving and caching
* record types with optional remapping.
*/
public abstract class AbstractRel implements Rel {

private Supplier<Type.Struct> recordType =
Expand All @@ -13,6 +17,11 @@ public abstract class AbstractRel implements Rel {
return getRemap().map(r -> r.remap(s)).orElse(s);
});

/**
* Derives the record type for this relation before any remapping is applied.
*
* @return the derived record type
*/
protected abstract Type.Struct deriveRecordType();

@Override
Expand Down
15 changes: 15 additions & 0 deletions core/src/main/java/io/substrait/relation/AbstractRelVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,23 @@
import io.substrait.relation.physical.SingleBucketExchange;
import io.substrait.util.VisitationContext;

/**
* Abstract base class for relation visitors that provides default implementations delegating all
* visit methods to a fallback method.
*
* @param <O> the return type of visit methods
* @param <C> the visitation context type
* @param <E> the exception type that may be thrown
*/
public abstract class AbstractRelVisitor<O, C extends VisitationContext, E extends Exception>
implements RelVisitor<O, C, E> {
/**
* Fallback method called by default implementations of all visit methods.
*
* @param rel the relation to visit
* @param context the visitation context
* @return the result of the visit
*/
public abstract O visitFallback(Rel rel, C context);

@Override
Expand Down
46 changes: 46 additions & 0 deletions core/src/main/java/io/substrait/relation/Aggregate.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,25 @@
import java.util.stream.Stream;
import org.immutables.value.Value;

/**
* Represents an aggregate relation that groups input rows and computes aggregate functions.
* Supports multiple grouping sets and measures.
*/
@Value.Immutable
public abstract class Aggregate extends SingleInputRel implements HasExtension {

/**
* Returns the list of grouping sets for this aggregate.
*
* @return list of grouping sets
*/
public abstract List<Grouping> getGroupings();

/**
* Returns the list of aggregate measures to compute.
*
* @return list of measures
*/
public abstract List<Measure> getMeasures();

@Override
Expand Down Expand Up @@ -71,26 +85,58 @@ public <O, C extends VisitationContext, E extends Exception> O accept(
return visitor.visit(this, context);
}

/** Represents a grouping set - a set of expressions to group by. */
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... by which to group :)

@Value.Immutable
public abstract static class Grouping {
/**
* Returns the list of expressions in this grouping set.
*
* @return list of grouping expressions
*/
public abstract List<Expression> getExpressions();

/**
* Creates a new builder for constructing a Grouping.
*
* @return a new builder instance
*/
public static ImmutableGrouping.Builder builder() {
return ImmutableGrouping.builder();
}
}

/** Represents an aggregate measure - an aggregate function to compute. */
@Value.Immutable
public abstract static class Measure {
/**
* Returns the aggregate function invocation for this measure.
*
* @return the aggregate function
*/
public abstract AggregateFunctionInvocation getFunction();

/**
* Returns an optional filter to apply before computing the aggregate.
*
* @return the pre-measure filter, if present
*/
public abstract Optional<Expression> getPreMeasureFilter();

/**
* Creates a new builder for constructing a Measure.
*
* @return a new builder instance
*/
public static ImmutableMeasure.Builder builder() {
return ImmutableMeasure.builder();
}
}

/**
* Creates a new builder for constructing an Aggregate relation.
*
* @return a new builder instance
*/
public static ImmutableAggregate.Builder builder() {
return ImmutableAggregate.builder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,23 @@ public class AggregateFunctionProtoConverter {
private final TypeProtoConverter typeProtoConverter;
private final ExtensionCollector functionCollector;

/**
* Constructs a converter with the specified extension collector.
*
* @param functionCollector the extension collector for tracking function references
*/
public AggregateFunctionProtoConverter(ExtensionCollector functionCollector) {
this.functionCollector = functionCollector;
this.exprProtoConverter = new ExpressionProtoConverter(functionCollector, null);
this.typeProtoConverter = new TypeProtoConverter(functionCollector);
}

/**
* Converts an aggregate measure to its protobuf representation.
*
* @param measure the aggregate measure to convert
* @return the protobuf aggregate function
*/
public AggregateFunction toProto(Aggregate.Measure measure) {
FunctionArg.FuncArgVisitor<FunctionArgument, EmptyVisitationContext, RuntimeException>
argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter);
Expand Down
11 changes: 11 additions & 0 deletions core/src/main/java/io/substrait/relation/BiRel.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,21 @@
import java.util.Arrays;
import java.util.List;

/** Abstract base class for binary relations that have exactly two input relations. */
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you have a binary relation without two inputs :-)

public abstract class BiRel extends AbstractRel {

/**
* Returns the left input relation.
*
* @return the left input
*/
public abstract Rel getLeft();

/**
* Returns the right input relation.
*
* @return the right input
*/
public abstract Rel getRight();

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ public static <T> Optional<T> or(Optional<T> left, Supplier<? extends Optional<T
}
}

/**
* Functional interface for transforming values during copy-on-write operations.
*
* @param <T> the type of value to transform
* @param <C> the visitation context type
* @param <E> the exception type that may be thrown
*/
@FunctionalInterface
public interface TransformFunction<T, C extends VisitationContext, E extends Exception> {

Expand Down
9 changes: 9 additions & 0 deletions core/src/main/java/io/substrait/relation/Cross.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import java.util.stream.Stream;
import org.immutables.value.Value;

/**
* Represents a cross product (Cartesian product) relation that combines all rows from the left
* input with all rows from the right input.
*/
@Value.Immutable
public abstract class Cross extends BiRel implements HasExtension {

Expand All @@ -23,6 +27,11 @@ public <O, C extends VisitationContext, E extends Exception> O accept(
return visitor.visit(this, context);
}

/**
* Creates a new builder for constructing a Cross relation.
*
* @return a new builder instance
*/
public static ImmutableCross.Builder builder() {
return ImmutableCross.builder();
}
Expand Down
Loading
Loading