Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.time.{ZoneId, ZoneOffset}
import java.util.Locale
import java.util.concurrent.TimeUnit._

import org.apache.spark.{QueryContext, SparkArithmeticException, SparkIllegalArgumentException}
import org.apache.spark.{QueryContext, SparkArithmeticException, SparkException, SparkIllegalArgumentException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
Expand Down Expand Up @@ -1988,16 +1988,28 @@ case class Cast(
from: DataType,
to: DataType): CastFunction = {
assert(ansiEnabled)
val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
(c, evPrim, _) =>
code"""
if ($c == ($integralType) $c) {
$evPrim = ($integralType) $c;
} else {
throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
}
"""
if (integralType == "int") {
// Integral -> Int: call the existing *ExactNumeric.toInt directly. It already does the
// bounds check and throws castingCauseOverflowError -- same as the inline body.
// Only LongType reaches this branch today (`castToIntCode` gates on `case LongType`).
val numericObj = (from match {
case LongType => LongExactNumeric
case _ => throw SparkException.internalError(
s"Unexpected source type $from for castIntegralTypeToIntegralTypeExactCode int branch")
}).getClass.getCanonicalName.stripSuffix("$")
(c, evPrim, _) => code"$evPrim = $numericObj.toInt($c);"
} else {
val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
(c, evPrim, _) =>
code"""
if ($c == ($integralType) $c) {
$evPrim = ($integralType) $c;
} else {
throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
}
"""
}
}


Expand All @@ -2017,23 +2029,37 @@ case class Cast(
from: DataType,
to: DataType): CastFunction = {
assert(ansiEnabled)
val (min, max) = lowerAndUpperBound(integralType)
val mathClass = classOf[Math].getName
val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
// When casting floating values to integral types, Spark uses the method `Numeric.toInt`
// or `Numeric.toLong` directly. For positive floating values, it is equivalent to `Math.floor`;
// for negative floating values, it is equivalent to `Math.ceil`.
// So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound`
// to check if the floating value x is in the range of an integral type after rounding.
(c, evPrim, _) =>
code"""
if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) {
$evPrim = ($integralType) $c;
} else {
throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
}
"""
if (integralType == "int" || integralType == "long") {
// Float/Double -> Int/Long: call FloatExactNumeric/DoubleExactNumeric.toInt/toLong
// directly. Each already does the floor/ceil bounds check and throws
// castingCauseOverflowError -- same as the inline body.
val numericObj = (from match {
case FloatType => FloatExactNumeric
case DoubleType => DoubleExactNumeric
case _ => throw SparkException.internalError(
s"Unexpected source type $from for castFractionToIntegralTypeCode")
}).getClass.getCanonicalName.stripSuffix("$")
val method = s"to${integralType.capitalize}"
(c, evPrim, _) => code"$evPrim = $numericObj.$method($c);"
} else {
val (min, max) = lowerAndUpperBound(integralType)
val mathClass = classOf[Math].getName
val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
// When casting floating values to integral types, Spark uses the method `Numeric.toInt`
// or `Numeric.toLong` directly. For positive floating values, it is equivalent to
// `Math.floor`; for negative floating values, it is equivalent to `Math.ceil`.
// So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound`
// to check if the floating value x is in the range of an integral type after rounding.
(c, evPrim, _) =>
code"""
if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) {
$evPrim = ($integralType) $c;
} else {
throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
}
"""
}
}

private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
Expand Down