Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions native/spark-expr/src/comet_scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ use crate::math_funcs::abs::abs;
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff,
SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace,
spark_binary_lpad, spark_binary_rpad, spark_ceil, spark_decimal_div,
spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal,
spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode,
SparkBitwiseCount, SparkContains, SparkDateDiff, SparkDateTrunc, SparkMakeDate,
SparkSizeFunc, SparkStringSpace,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
Expand Down Expand Up @@ -112,6 +113,14 @@ pub fn create_comet_physical_fun_with_eval_mode(
let func = Arc::new(spark_read_side_padding);
make_comet_scalar_udf!("read_side_padding", func, without data_type)
}
"binary_lpad" => {
let func = Arc::new(spark_binary_lpad);
make_comet_scalar_udf!("binary_lpad", func, without data_type)
}
"binary_rpad" => {
let func = Arc::new(spark_binary_rpad);
make_comet_scalar_udf!("binary_rpad", func, without data_type)
}
"rpad" => {
let func = Arc::new(spark_rpad);
make_comet_scalar_udf!("rpad", func, without data_type)
Expand Down
114 changes: 114 additions & 0 deletions native/spark-expr/src/static_invoke/binary_pad.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::builder::BinaryBuilder;
use arrow::array::{Array, ArrayRef, AsArray};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, ScalarValue};
use datafusion::physical_plan::ColumnarValue;
use std::sync::Arc;

/// Spark's `ByteArray.lpad`: left-pad each value of a binary array to a target
/// length by cyclically repeating a pad pattern.
///
/// Expected `args`: `[Binary array, scalar Int32 target length, scalar Binary pad]`
/// (validated by `binary_pad_impl`; any other shape yields an internal error).
/// Values longer than the target length are truncated to their prefix, and
/// null entries remain null.
pub fn spark_binary_lpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    binary_pad_impl(args, true)
}

/// Spark's `ByteArray.rpad`: right-pad each value of a binary array to a target
/// length by cyclically repeating a pad pattern.
///
/// Expected `args`: `[Binary array, scalar Int32 target length, scalar Binary pad]`
/// (validated by `binary_pad_impl`; any other shape yields an internal error).
/// Values longer than the target length are truncated to their prefix, and
/// null entries remain null.
pub fn spark_binary_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    binary_pad_impl(args, false)
}

/// Shared implementation for `binary_lpad` / `binary_rpad`.
///
/// Expects exactly three arguments: a `Binary` array, a non-null scalar `Int32`
/// target length, and a non-null scalar `Binary` pad pattern. Any other
/// argument shape or data type is rejected with `DataFusionError::Internal`.
/// Null input rows produce null output rows.
fn binary_pad_impl(
    args: &[ColumnarValue],
    is_left_pad: bool,
) -> Result<ColumnarValue, DataFusionError> {
    match args {
        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(len))), ColumnarValue::Scalar(ScalarValue::Binary(Some(pad)))] =>
        {
            // Clamp negative lengths to 0 (empty result), matching Spark's
            // behavior for non-positive pad lengths. A plain `as usize` cast
            // would wrap a negative i32 into a huge value and make the
            // per-row `vec![0u8; len]` allocation in `pad_bytes` blow up.
            let len = usize::try_from(*len).unwrap_or(0);
            match array.data_type() {
                DataType::Binary => {
                    let binary_array = array.as_binary::<i32>();
                    let mut builder = BinaryBuilder::with_capacity(binary_array.len(), 0);

                    for i in 0..binary_array.len() {
                        if binary_array.is_null(i) {
                            builder.append_null();
                        } else {
                            let bytes = binary_array.value(i);
                            let result = pad_bytes(bytes, len, pad, is_left_pad);
                            builder.append_value(&result);
                        }
                    }
                    Ok(ColumnarValue::Array(Arc::new(builder.finish()) as ArrayRef))
                }
                other => Err(DataFusionError::Internal(format!(
                    "Unsupported data type {other:?} for binary_pad",
                ))),
            }
        }
        other => Err(DataFusionError::Internal(format!(
            "Unsupported arguments {other:?} for binary_pad",
        ))),
    }
}

/// Pad `bytes` to exactly `len` bytes using a cyclic repetition of `pad`,
/// mirroring Spark's `ByteArray.lpad`/`ByteArray.rpad` behavior.
///
/// * `len == 0` yields an empty result.
/// * An empty `pad` pattern cannot pad anything: the first
///   `min(bytes.len(), len)` bytes of the input are returned unchanged.
/// * Inputs longer than `len` are truncated to their first `len` bytes.
fn pad_bytes(bytes: &[u8], len: usize, pad: &[u8], is_left_pad: bool) -> Vec<u8> {
    if len == 0 {
        return Vec::new();
    }

    // With an empty pattern there is nothing to pad with; return the
    // (possibly truncated) input as-is. This guard also keeps the cyclic
    // fill below well-defined.
    if pad.is_empty() {
        return bytes[..bytes.len().min(len)].to_vec();
    }

    let copy_len = bytes.len().min(len);
    let fill_len = len.saturating_sub(bytes.len());
    let mut out = vec![0u8; len];

    // Split the output into the region receiving the input bytes and the
    // region receiving the cyclic pad pattern (empty when no padding needed).
    let (fill_range, copy_range) = if is_left_pad {
        (0..fill_len, len - copy_len..len)
    } else {
        (copy_len..copy_len + fill_len, 0..copy_len)
    };

    out[copy_range].copy_from_slice(&bytes[..copy_len]);
    // `cycle()` repeats the pattern; `zip` truncates the final repetition to
    // the exact region length.
    for (dst, &src) in out[fill_range].iter_mut().zip(pad.iter().cycle()) {
        *dst = src;
    }

    out
}

/// Fill `result[first_pos..beyond_pos]` with the cyclic `pad` pattern, with
/// the final repetition truncated when the region length is not a multiple of
/// `pad.len()`.
///
/// An empty `pad` leaves the region untouched: the previous manual loop
/// advanced `pos` by `pad.len().min(remaining)`, which is 0 for an empty
/// pattern and therefore spun forever. `cycle()` on an empty iterator
/// terminates immediately, making that case a safe no-op.
fn fill_with_pattern(result: &mut [u8], first_pos: usize, beyond_pos: usize, pad: &[u8]) {
    for (dst, &byte) in result[first_pos..beyond_pos]
        .iter_mut()
        .zip(pad.iter().cycle())
    {
        *dst = byte;
    }
}
2 changes: 2 additions & 0 deletions native/spark-expr/src/static_invoke/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

mod binary_pad;
mod char_varchar_utils;

pub use binary_pad::{spark_binary_lpad, spark_binary_rpad};
pub use char_varchar_utils::{spark_lpad, spark_read_side_padding, spark_rpad};
46 changes: 45 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/statics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ package org.apache.comet.serde
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.unsafe.types.ByteArray

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType}

object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {

Expand All @@ -34,7 +37,9 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
: Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] =
Map(
("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
"read_side_padding"))
"read_side_padding"),
("lpad", classOf[ByteArray]) -> CometBinaryPad("binary_lpad"),
("rpad", classOf[ByteArray]) -> CometBinaryPad("binary_rpad"))

override def convert(
expr: StaticInvoke,
Expand All @@ -52,3 +57,42 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
}
}
}

/**
 * Handler for ByteArray.lpad/rpad StaticInvoke (Spark 3.2+, via BinaryPad). Maps to Comet's
 * binary_lpad/binary_rpad UDFs. The native implementation requires an array `str` argument
 * and scalar (foldable) `len` and `pad` arguments; any other shape falls back to Spark.
 */
private case class CometBinaryPad(funcName: String) extends CometExpressionSerde[StaticInvoke] {

  override def convert(
      expr: StaticInvoke,
      inputs: Seq[Attribute],
      binding: Boolean): Option[Expr] = {
    val str = expr.arguments(0)
    val len = expr.arguments(1)
    val pad = expr.arguments(2)

    if (str.foldable) {
      withInfo(expr, "Scalar values are not supported for the str argument", str)
      None
    } else if (!len.foldable) {
      withInfo(expr, "Only scalar values are supported for the len argument", len)
      None
    } else if (!pad.foldable) {
      withInfo(expr, "Only scalar values are supported for the pad argument", pad)
      None
    } else {
      val strExpr = exprToProtoInternal(str, inputs, binding)
      val lenExpr = exprToProtoInternal(len, inputs, binding)
      val padExpr = exprToProtoInternal(pad, inputs, binding)
      val optExpr = scalarFunctionExprToProtoWithReturnType(
        funcName,
        expr.dataType,
        false,
        strExpr,
        lenExpr,
        padExpr)
      optExprWithInfo(optExpr, expr, expr.arguments: _*)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,20 @@ class CometStringExpressionSuite extends CometTestBase {
// all arguments are literal, so Spark constant folding will kick in
// and pad function will not be evaluated by Comet
checkSparkAnswerAndOperator(sql)
} else {
// Comet will fall back to Spark because the plan contains a staticinvoke instruction
// which is not supported
} else if (isLiteralStr) {
checkSparkAnswerAndFallbackReason(
sql,
"Scalar values are not supported for the str argument")
} else if (!isLiteralLen) {
checkSparkAnswerAndFallbackReason(
sql,
s"Static invoke expression: $expr is not supported")
"Only scalar values are supported for the len argument")
} else if (!isLiteralPad) {
checkSparkAnswerAndFallbackReason(
sql,
"Only scalar values are supported for the pad argument")
} else {
checkSparkAnswerAndOperator(sql)
}
}
}
Expand Down