Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ use datafusion_spark::function::math::hex::SparkHex;
use datafusion_spark::function::math::width_bucket::SparkWidthBucket;
use datafusion_spark::function::string::char::CharFunc;
use datafusion_spark::function::string::concat::SparkConcat;
use datafusion_spark::function::string::luhn_check::SparkLuhnCheck;
use futures::poll;
use futures::stream::StreamExt;
use jni::objects::JByteBuffer;
Expand Down Expand Up @@ -400,6 +401,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkWidthBucket::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkCrc32::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkLuhnCheck::default()));
}

/// Prepares arrow arrays for output.
Expand Down
24 changes: 22 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/statics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.types.BooleanType

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType}

object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {

Expand All @@ -34,7 +36,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
: Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] =
Map(
("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
"read_side_padding"))
"read_side_padding"),
("isLuhnNumber", classOf[ExpressionImplUtils]) -> CometLuhnCheck)

override def convert(
expr: StaticInvoke,
Expand All @@ -52,3 +55,20 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
}
}
}

/**
 * Serde handler for the `StaticInvoke` form of `ExpressionImplUtils.isLuhnNumber`, which Spark
 * 3.5+ emits for `luhn_check`. Translated to datafusion-spark's built-in `luhn_check` scalar
 * function.
 */
private object CometLuhnCheck extends CometExpressionSerde[StaticInvoke] {

  override def convert(
      expr: StaticInvoke,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // The isLuhnNumber StaticInvoke carries a single argument: the string to validate.
    val input = expr.arguments.head
    val inputProto = exprToProtoInternal(input, inputs, binding)
    val luhnProto =
      scalarFunctionExprToProtoWithReturnType("luhn_check", BooleanType, false, inputProto)
    optExprWithInfo(luhnProto, expr, input)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,24 @@ class CometStringExpressionSuite extends CometTestBase {
}
}

test("luhn_check") {
  // Inputs covering both Luhn outcomes plus edge cases: single digit, empty,
  // non-numeric, and NULL.
  val inputs = Seq(
    "79927398710", // fails the Luhn checksum
    "79927398713", // passes the Luhn checksum
    "1234567812345670", // credit-card-like value that passes
    "0", // shortest valid input
    "", // empty string
    "abc", // contains non-digit characters
    null)
  withParquetTable(inputs.map(Tuple1(_)), "tbl") {
    // column input
    checkSparkAnswerAndOperator("SELECT luhn_check(_1) FROM tbl")
    // literal input
    checkSparkAnswerAndOperator("SELECT luhn_check('79927398713') FROM tbl")
    // NULL literal handling
    checkSparkAnswerAndOperator("SELECT luhn_check(NULL) FROM tbl")
  }
}

test("split string basic") {
withSQLConf("spark.comet.expression.StringSplit.allowIncompatible" -> "true") {
withParquetTable((0 until 5).map(i => (s"value$i,test$i", i)), "tbl") {
Expand Down