Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/spark_expressions_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@
### string_funcs

- [x] ascii
- [ ] base64
- [x] base64
- [x] bit_length
- [x] btrim
- [x] char
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ object QueryPlanSerde extends Logging with CometExprShim {

private val stringExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
classOf[Ascii] -> CometScalarFunction("ascii"),
classOf[Base64] -> CometBase64,
classOf[BitLength] -> CometScalarFunction("bit_length"),
classOf[Chr] -> CometScalarFunction("char"),
classOf[ConcatWs] -> CometConcatWs,
Expand Down
31 changes: 29 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/statics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, Base64, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.types.BooleanType

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}

object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {

Expand All @@ -34,7 +36,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
: Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] =
Map(
("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
"read_side_padding"))
"read_side_padding"),
("encode", classOf[Base64]) -> CometBase64Encode)

override def convert(
expr: StaticInvoke,
Expand All @@ -52,3 +55,27 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
}
}
}

/**
 * Handler for Base64.encode StaticInvoke (Spark 3.5+, where Base64 is RuntimeReplaceable). Maps
 * to DataFusion's built-in encode(input, 'base64') function.
 */
private object CometBase64Encode extends CometExpressionSerde[StaticInvoke] {

  /**
   * Converts a `StaticInvoke` of `Base64.encode` into a protobuf call to DataFusion's
   * `encode(input, 'base64')` scalar function.
   *
   * @param expr
   *   the `StaticInvoke` expression being serialized
   * @param inputs
   *   attributes available for binding child expressions
   * @param binding
   *   whether child expressions should be bound to `inputs`
   * @return
   *   the serialized expression, or `None` (with the fallback reason recorded via `withInfo`)
   *   when chunked encoding is requested or a child expression cannot be converted
   */
  override def convert(
      expr: StaticInvoke,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    expr.arguments match {
      // Spark 3.5+ passes a second chunkBase64 argument; chunked (MIME-style) output has no
      // DataFusion equivalent, so record the reason and fall back to Spark. Yielding None from
      // this arm avoids the non-idiomatic early `return`.
      case Seq(_, Literal(true, BooleanType)) =>
        withInfo(expr, "base64 with chunk encoding is not supported")
        None
      // Either no chunkBase64 argument (Spark 3.4) or chunkBase64 = false.
      case args =>
        val inputExpr = exprToProtoInternal(args.head, inputs, binding)
        val encodingExpr = exprToProtoInternal(Literal("base64"), inputs, binding)
        val optExpr = scalarFunctionExprToProto("encode", inputExpr, encodingExpr)
        optExprWithInfo(optExpr, expr, args.head)
    }
  }
}
19 changes: 18 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/strings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.comet.serde

import java.util.Locale

import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Base64, Cast, Concat, ConcatWs, Expression, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -31,6 +31,23 @@ import org.apache.comet.expressions.{CometCast, CometEvalMode, RegExp}
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}

/**
 * Serde handler for the Base64 expression when Spark produces it directly (Spark 3.4, where it
 * is not RuntimeReplaceable). On Spark 3.5+ the same functionality arrives as a StaticInvoke of
 * Base64.encode and is handled by CometBase64Encode in statics.scala.
 */
object CometBase64 extends CometExpressionSerde[Base64] {

  override def convert(
      expr: Base64,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // Serialize the binary child plus a literal 'base64' encoding name, then map the pair onto
    // DataFusion's encode(input, encoding) scalar function.
    val childProto = exprToProtoInternal(expr.child, inputs, binding)
    val encodingProto = exprToProtoInternal(Literal("base64"), inputs, binding)
    optExprWithInfo(
      scalarFunctionExprToProto("encode", childProto, encodingProto),
      expr,
      expr.child)
  }
}

object CometStringRepeat extends CometExpressionSerde[StringRepeat] {

override def convert(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,31 @@ class CometStringExpressionSuite extends CometTestBase {
}
}

test("base64") {
  withSQLConf("spark.sql.chunkBase64String.enabled" -> "false") {
    // Cover a short ASCII payload, a payload containing a space, the empty array, and null.
    val rows: Seq[Tuple1[Array[Byte]]] = Seq(
      Tuple1("Hello".getBytes("UTF-8")),
      Tuple1("Spark SQL".getBytes("UTF-8")),
      Tuple1(Array.empty[Byte]),
      Tuple1(null))
    withParquetTable(rows, "tbl") {
      checkSparkAnswerAndOperator("SELECT base64(_1) FROM tbl")
      checkSparkAnswerAndOperator("SELECT base64(NULL) FROM tbl")
    }
  }
}

test("base64 with chunk encoding falls back") {
  withSQLConf("spark.sql.chunkBase64String.enabled" -> "true") {
    // Chunked (MIME-style) base64 output is unsupported by Comet, so the query must fall back
    // to Spark with the documented reason.
    val rows = Seq(Tuple1("Hello".getBytes("UTF-8")))
    withParquetTable(rows, "tbl") {
      checkSparkAnswerAndFallbackReason(
        "SELECT base64(_1) FROM tbl",
        "base64 with chunk encoding is not supported")
    }
  }
}

test("split string basic") {
withSQLConf("spark.comet.expression.StringSplit.allowIncompatible" -> "true") {
withParquetTable((0 until 5).map(i => (s"value$i,test$i", i)), "tbl") {
Expand Down