Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ use datafusion_spark::function::math::hex::SparkHex;
use datafusion_spark::function::math::width_bucket::SparkWidthBucket;
use datafusion_spark::function::string::char::CharFunc;
use datafusion_spark::function::string::concat::SparkConcat;
use datafusion_spark::function::url::url_decode::UrlDecode;
use futures::poll;
use futures::stream::StreamExt;
use jni::objects::JByteBuffer;
Expand Down Expand Up @@ -400,6 +401,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkWidthBucket::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkCrc32::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(UrlDecode::default()));
}

/// Prepares arrow arrays for output.
Expand Down
25 changes: 23 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/statics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, UrlCodec}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType}

object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {

Expand All @@ -34,7 +36,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
: Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] =
Map(
("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
"read_side_padding"))
"read_side_padding"),
("decode", UrlCodec.getClass) -> CometUrlDecode)

override def convert(
expr: StaticInvoke,
Expand All @@ -52,3 +55,21 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
}
}
}

/**
* Handler for UrlCodec.decode StaticInvoke (Spark 3.4+). Maps to datafusion-spark's built-in
* url_decode function.
*/
private object CometUrlDecode extends CometExpressionSerde[StaticInvoke] {

override def convert(
expr: StaticInvoke,
inputs: Seq[Attribute],
binding: Boolean): Option[Expr] = {
// StaticInvoke args: [child, Literal("UTF-8")] — only serialize the first
val childExpr = exprToProtoInternal(expr.arguments.head, inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("url_decode", expr.dataType, false, childExpr)
optExprWithInfo(optExpr, expr, expr.arguments.head)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -478,4 +478,35 @@ class CometStringExpressionSuite extends CometTestBase {
}
}

test("url_decode") {
val data = Seq(
"https%3A%2F%2Fspark.apache.org", // percent-encoded URL
"hello+world", // plus as space
"%E4%B8%AD%E6%96%87", // multi-byte UTF-8 (Chinese)
"no+encoding+needed", // spaces only
"%e4%b8%ad", // lowercase hex digits
"abc%20def%21%40%23", // mixed encoded/unencoded
"%F0%9F%94%A5", // 4-byte UTF-8 (emoji)
"already+decoded+%2B+literal+plus", // encoded plus sign (%2B)
"").map(Tuple1(_))
withParquetTable(data, "tbl") {
checkSparkAnswerAndOperator("SELECT url_decode(_1) FROM tbl")
}
}

test("url_decode - null handling") {
withParquetTable(
Seq(Some("hello+world"), None, Some("%E4%B8%AD")).map(v => Tuple1(v.orNull)),
"tbl") {
checkSparkAnswerAndOperator("SELECT url_decode(_1) FROM tbl")
}
}

test("url_decode - literals") {
withParquetTable(Seq(Tuple1(1)), "tbl") {
checkSparkAnswerAndOperator("SELECT url_decode('hello%20world') FROM tbl")
checkSparkAnswerAndOperator("SELECT url_decode('%E4%B8%AD%E6%96%87') FROM tbl")
}
}

}