Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,53 @@
import java.util.Arrays;
import java.util.Comparator;

import org.apache.spark.QueryContext;
import org.apache.spark.sql.catalyst.util.SQLOrderingUtil;
import org.apache.spark.sql.errors.QueryExecutionErrors;

public class ArrayExpressionUtils {

// ANSI index helpers used by ArrayType expression codegen and eval paths.

/**
* Resolves the user-supplied 1-based {@code element_at} index to a
* 0-based array position. Throws when the absolute index exceeds the
* array length (ANSI out-of-bounds) or when {@code index} is zero
* (always invalid).
*
* @param length the array length
* @param index the 1-based index supplied by the user (positive or negative)
* @param context the query context attached to the error
* @return the resolved 0-based position
*/
public static int resolveArrayIndex(int length, int index, QueryContext context) {
  // Widen to long before taking the absolute value: Math.abs(Integer.MIN_VALUE) overflows
  // and stays negative in int arithmetic, which would let that index slip past this bounds
  // check and resolve to a negative array position below.
  if (length < Math.abs((long) index)) {
    throw QueryExecutionErrors.invalidElementAtIndexError(index, length, context);
  }
  // Zero is never a valid 1-based index, regardless of the array length.
  if (index == 0) {
    throw QueryExecutionErrors.invalidIndexOfZeroError(context);
  }
  // Positive indices count from the front (1 -> 0); negative ones count back from the end
  // (-1 -> length - 1).
  return index > 0 ? index - 1 : length + index;
}

/**
* Validates a 0-based {@code arr[idx]} index against the array length
* under ANSI mode. Throws when {@code index} is negative or
* {@code >= length}; otherwise returns {@code index} unchanged so the
* caller can chain into {@code arr.get(idx, dataType)}.
*
* @param length the array length
* @param index the 0-based index supplied by the user
* @param context the query context attached to the error
* @return the validated 0-based position (== {@code index})
*/
public static int checkArrayIndex(int length, int index, QueryContext context) {
  // A valid 0-based position lies in the half-open range [0, length).
  boolean withinBounds = 0 <= index && index < length;
  if (withinBounds) {
    // Hand the index back untouched so the caller can use the result directly.
    return index;
  }
  throw QueryExecutionErrors.invalidArrayIndexError(index, length, context);
}

// comparator
// Boolean ascending nullable comparator
private static final Comparator<Boolean> booleanComp = (o1, o2) -> {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2743,7 +2743,7 @@ case class ElementAt(
case _: ArrayType if failOnError =>
(value, ordinal) => {
val array = value.asInstanceOf[ArrayData]
val idx = ElementAtUtils.resolveArrayIndex(
val idx = ArrayExpressionUtils.resolveArrayIndex(
array.numElements(), ordinal.asInstanceOf[Int], getContextOrNull())
if (arrayElementNullable && array.isNullAt(idx)) null else array.get(idx, dataType)
}
Expand Down Expand Up @@ -2783,7 +2783,7 @@ case class ElementAt(
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("elementAtIndex")
val errorContext = getContextOrNullCode(ctx)
val utils = classOf[ElementAtUtils].getName
val utils = classOf[ArrayExpressionUtils].getName
val assignment = s"${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};"
val body = if (arrayElementNullable) {
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, TreePattern}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -356,13 +356,11 @@ case class GetArrayItem(
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
if (index >= baseValue.numElements() || index < 0) {
if (failOnError) {
throw QueryExecutionErrors.invalidArrayIndexError(
index, baseValue.numElements(), getContextOrNull())
} else {
null
}
if (failOnError) {
ArrayExpressionUtils.checkArrayIndex(baseValue.numElements(), index, getContextOrNull())
if (baseValue.isNullAt(index)) null else baseValue.get(index, dataType)
} else if (index >= baseValue.numElements() || index < 0) {
null
} else if (baseValue.isNullAt(index)) {
null
} else {
Expand All @@ -371,36 +369,52 @@ case class GetArrayItem(
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("index")
val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull
val nullCheck = if (childArrayElementNullable) {
s"""else if ($eval1.isNullAt($index)) {
${ev.isNull} = true;
}
"""
} else {
""
}

val indexOutOfBoundBranch = if (failOnError) {
// ANSI (failOnError) and non-ANSI paths generate different codegen.
if (failOnError) {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("index")
val errorContext = getContextOrNullCode(ctx)
// scalastyle:off line.size.limit
s"throw QueryExecutionErrors.invalidArrayIndexError($index, $eval1.numElements(), $errorContext);"
// scalastyle:on line.size.limit
} else {
s"${ev.isNull} = true;"
}

s"""
final int $index = (int) $eval2;
if ($index >= $eval1.numElements() || $index < 0) {
$indexOutOfBoundBranch
} $nullCheck else {
${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
val utils = classOf[ArrayExpressionUtils].getName
val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull
val assignment = s"${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};"
val body = if (childArrayElementNullable) {
s"""
|if ($eval1.isNullAt($index)) {
| ${ev.isNull} = true;
|} else {
| $assignment
|}
""".stripMargin
} else {
assignment
}
"""
})
s"""
|int $index = $utils.checkArrayIndex($eval1.numElements(), (int) $eval2, $errorContext);
|$body
""".stripMargin
})
} else {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("index")
val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull
val nullCheck = if (childArrayElementNullable) {
s"""else if ($eval1.isNullAt($index)) {
${ev.isNull} = true;
}
"""
} else {
""
}
s"""
final int $index = (int) $eval2;
if ($index >= $eval1.numElements() || $index < 0) {
${ev.isNull} = true;
} $nullCheck else {
${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
}
"""
})
}
}

override protected def withNewChildrenInternal(
Expand Down