Skip to content

Commit 5bb514b

Browse files
committed
[SPARK-56912][SQL] Refactor Cast to boolean codegen under ANSI mode
### What changes were proposed in this pull request? Extend `CastUtils.java` with `stringToBooleanExact(UTF8String, QueryContext)` and use it from `Cast.scala` for the ANSI `String -> Boolean` cast path (both eval and codegen). The non-ANSI path keeps the inline `if/else if/else evNull = true` form because it has no error to throw. ### Why are the changes needed? Part of SPARK-56908 (umbrella). The ANSI String->Boolean cast emits an 8-line `if (isTrueString) … else if (isFalseString) … else throw` block in codegen. This PR collapses it to a one-line `CastUtils .stringToBooleanExact(...)` call. ### Does this PR introduce _any_ user-facing change? No. The compiled behavior is identical; only the emitted Java source text changes. ### How was this patch tested? ``` build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite \ *AnsiCastSuite *TryCastSuite" ``` 204/204 pass. ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Cursor 1.x
1 parent 0beac87 commit 5bb514b

2 files changed

Lines changed: 18 additions & 12 deletions

File tree

  • sql/catalyst/src/main

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
package org.apache.spark.sql.catalyst.expressions;
1919

2020
import org.apache.spark.QueryContext;
21+
import org.apache.spark.sql.catalyst.util.StringUtils;
2122
import org.apache.spark.sql.errors.QueryExecutionErrors;
2223
import org.apache.spark.sql.types.DataType;
2324
import org.apache.spark.sql.types.DataTypes;
2425
import org.apache.spark.sql.types.Decimal;
26+
import org.apache.spark.unsafe.types.UTF8String;
2527

2628
/**
2729
* Static helpers invoked from {@code Cast.doGenCode} so the generated Java
@@ -137,4 +139,12 @@ public static Decimal changePrecisionExact(
137139
public static Decimal changePrecisionOrNull(Decimal d, int precision, int scale) {
138140
return d.changePrecision(precision, scale) ? d : null;
139141
}
142+
143+
// ----- string -> boolean (ANSI: throw on invalid syntax) -----
144+
145+
public static boolean stringToBooleanExact(UTF8String s, QueryContext context) {
146+
if (StringUtils.isTrueString(s)) return true;
147+
if (StringUtils.isFalseString(s)) return false;
148+
throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, context);
149+
}
140150
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -695,18 +695,16 @@ case class Cast(
695695

696696
// UDFToBoolean
697697
private[this] def castToBoolean(from: DataType): Any => Any = from match {
698+
case _: StringType if ansiEnabled =>
699+
buildCast[UTF8String](_, s => CastUtils.stringToBooleanExact(s, getContextOrNull()))
698700
case _: StringType =>
699701
buildCast[UTF8String](_, s => {
700702
if (StringUtils.isTrueString(s)) {
701703
true
702704
} else if (StringUtils.isFalseString(s)) {
703705
false
704706
} else {
705-
if (ansiEnabled) {
706-
throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, getContextOrNull())
707-
} else {
708-
null
709-
}
707+
null
710708
}
711709
})
712710
case TimestampType =>
@@ -1891,22 +1889,20 @@ case class Cast(
18911889
private[this] def castToBooleanCode(
18921890
from: DataType,
18931891
ctx: CodegenContext): CastFunction = from match {
1892+
case _: StringType if ansiEnabled =>
1893+
val castUtils = classOf[CastUtils].getName
1894+
val errorContext = getContextOrNullCode(ctx)
1895+
(c, evPrim, _) => code"$evPrim = $castUtils.stringToBooleanExact($c, $errorContext);"
18941896
case _: StringType =>
18951897
val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
18961898
(c, evPrim, evNull) =>
1897-
val castFailureCode = if (ansiEnabled) {
1898-
val errorContext = getContextOrNullCode(ctx)
1899-
s"throw QueryExecutionErrors.invalidInputSyntaxForBooleanError($c, $errorContext);"
1900-
} else {
1901-
s"$evNull = true;"
1902-
}
19031899
code"""
19041900
if ($stringUtils.isTrueString($c)) {
19051901
$evPrim = true;
19061902
} else if ($stringUtils.isFalseString($c)) {
19071903
$evPrim = false;
19081904
} else {
1909-
$castFailureCode
1905+
$evNull = true;
19101906
}
19111907
"""
19121908
case TimestampType =>

0 commit comments

Comments
 (0)