Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions native/spark-expr/src/comet_scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ use crate::math_funcs::abs::abs;
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff,
SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace,
spark_ceil, spark_char_type_write_side_check, spark_decimal_div,
spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal,
spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
spark_varchar_type_write_side_check, EvalMode, SparkBitwiseCount, SparkContains,
SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
Expand Down Expand Up @@ -112,6 +113,14 @@ pub fn create_comet_physical_fun_with_eval_mode(
let func = Arc::new(spark_read_side_padding);
make_comet_scalar_udf!("read_side_padding", func, without data_type)
}
"char_type_write_side_check" => {
let func = Arc::new(spark_char_type_write_side_check);
make_comet_scalar_udf!("char_type_write_side_check", func, without data_type)
}
"varchar_type_write_side_check" => {
let func = Arc::new(spark_varchar_type_write_side_check);
make_comet_scalar_udf!("varchar_type_write_side_check", func, without data_type)
}
"rpad" => {
let func = Arc::new(spark_rpad);
make_comet_scalar_udf!("rpad", func, without data_type)
Expand Down
2 changes: 2 additions & 0 deletions native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,7 @@
// under the License.

mod read_side_padding;
mod write_side_check;

pub use read_side_padding::{spark_lpad, spark_read_side_padding, spark_rpad};
pub use write_side_check::{spark_char_type_write_side_check, spark_varchar_type_write_side_check};
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::builder::GenericStringBuilder;
use arrow::array::cast::as_dictionary_array;
use arrow::array::types::Int32Type;
use arrow::array::{make_array, Array, DictionaryArray};
use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
use datafusion::physical_plan::ColumnarValue;
use std::sync::Arc;

/// Spark's charTypeWriteSideCheck: pad if shorter, trim trailing spaces if longer.
/// Throws if string exceeds limit after trimming.
///
/// Mirrors `CharVarcharCodegenUtils.charTypeWriteSideCheck` for CHAR(limit) columns.
/// `args` must be `[string array, Int32 scalar limit]`; see `write_side_check_impl`
/// for the accepted array types (Utf8, LargeUtf8, Int32-keyed dictionary of either).
pub fn spark_char_type_write_side_check(
args: &[ColumnarValue],
) -> Result<ColumnarValue, DataFusionError> {
write_side_check_impl(args, true)
}

/// Spark's varcharTypeWriteSideCheck: return as-is if within limit, trim trailing spaces if longer.
/// Throws if string exceeds limit after trimming.
///
/// Mirrors `CharVarcharCodegenUtils.varcharTypeWriteSideCheck` for VARCHAR(limit)
/// columns. Unlike the CHAR variant, values shorter than the limit are NOT padded.
/// `args` must be `[string array, Int32 scalar limit]`; see `write_side_check_impl`.
pub fn spark_varchar_type_write_side_check(
args: &[ColumnarValue],
) -> Result<ColumnarValue, DataFusionError> {
write_side_check_impl(args, false)
}

/// Dispatches a char/varchar write-side length check over the supported array
/// layouts.
///
/// Expects `args` to be `[string array, Int32 scalar limit]`. Supported array
/// types are `Utf8`, `LargeUtf8`, and Int32-keyed dictionaries whose values are
/// one of those two; dictionaries are checked on their values array only, so
/// each distinct value is processed once.
///
/// `pad_if_shorter` selects CHAR semantics (right-pad with spaces) versus
/// VARCHAR semantics (leave shorter values untouched).
///
/// Returns an `Internal` error for unsupported argument shapes or data types,
/// and propagates `Execution` errors from the per-value length check.
fn write_side_check_impl(
    args: &[ColumnarValue],
    pad_if_shorter: bool,
) -> Result<ColumnarValue, DataFusionError> {
    match args {
        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(limit)))] => {
            // CHAR(n)/VARCHAR(n) lengths are positive in Spark. Guard against a
            // negative limit explicitly: `as usize` would wrap it into a huge
            // value and blow up the builder's capacity pre-allocation.
            let limit = usize::try_from(*limit).map_err(|_| {
                DataFusionError::Internal(format!(
                    "Invalid negative length limit {limit} for write_side_check"
                ))
            })?;
            match array.data_type() {
                DataType::Utf8 => write_side_check_internal::<i32>(array, limit, pad_if_shorter),
                DataType::LargeUtf8 => {
                    write_side_check_internal::<i64>(array, limit, pad_if_shorter)
                }
                DataType::Dictionary(_, value_type) => {
                    let dict = as_dictionary_array::<Int32Type>(array);
                    // Match the value type explicitly instead of assuming
                    // "anything not Utf8 is LargeUtf8", so unsupported
                    // dictionary values get a clear error rather than a
                    // confusing downcast failure.
                    let col = match value_type.as_ref() {
                        DataType::Utf8 => {
                            write_side_check_internal::<i32>(dict.values(), limit, pad_if_shorter)?
                        }
                        DataType::LargeUtf8 => {
                            write_side_check_internal::<i64>(dict.values(), limit, pad_if_shorter)?
                        }
                        other => {
                            return Err(DataFusionError::Internal(format!(
                                "Unsupported dictionary value type {other:?} for write_side_check",
                            )))
                        }
                    };
                    let values = col.to_array(0)?;
                    // Rebuild the dictionary with the original keys over the
                    // checked/rewritten values.
                    let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
                    Ok(ColumnarValue::Array(make_array(result.into())))
                }
                other => Err(DataFusionError::Internal(format!(
                    "Unsupported data type {other:?} for write_side_check",
                ))),
            }
        }
        other => Err(DataFusionError::Internal(format!(
            "Unsupported arguments {other:?} for write_side_check",
        ))),
    }
}

/// Applies the char/varchar write-side check to every element of a string array.
///
/// Per-element semantics, following Spark's `CharVarcharCodegenUtils`:
/// - length < limit: right-pad with spaces to `limit` when `pad_if_shorter`
///   (CHAR), otherwise keep as-is (VARCHAR);
/// - length == limit: keep as-is;
/// - length > limit: every character past position `limit` must be a space,
///   otherwise an `Execution` error is raised; the value is truncated to its
///   first `limit` characters. Note Spark keeps spaces that fall WITHIN the
///   limit (e.g. VARCHAR(3) turns "ab   " into "ab ", not "ab"), so we must
///   not strip all trailing spaces.
///
/// Nulls pass through unchanged. Lengths are counted in characters, not bytes.
fn write_side_check_internal<T: OffsetSizeTrait>(
    array: &ArrayRef,
    limit: usize,
    pad_if_shorter: bool,
) -> Result<ColumnarValue, DataFusionError> {
    let string_array = as_generic_string_array::<T>(array)?;

    let mut builder =
        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * limit);
    // Scratch buffer reused across rows to avoid a fresh allocation per padded value.
    let mut buffer = String::with_capacity(limit);

    for string in string_array.iter() {
        match string {
            Some(s) => {
                let char_len = s.chars().count();
                if char_len < limit {
                    if pad_if_shorter {
                        // CHAR: pad with spaces to reach the declared length.
                        buffer.clear();
                        buffer.push_str(s);
                        buffer.extend(std::iter::repeat(' ').take(limit - char_len));
                        builder.append_value(&buffer);
                    } else {
                        builder.append_value(s);
                    }
                } else if char_len == limit {
                    builder.append_value(s);
                } else {
                    // Over the limit: only trailing spaces beyond the limit are
                    // tolerated. Find the byte offset of the `limit`-th character
                    // (safe: char_len > limit guarantees it exists).
                    let byte_end = s
                        .char_indices()
                        .nth(limit)
                        .map(|(i, _)| i)
                        .unwrap_or_else(|| s.len());
                    if s[byte_end..].chars().any(|c| c != ' ') {
                        return Err(DataFusionError::Execution(format!(
                            "Exceeds char/varchar type length limitation: {limit}"
                        )));
                    }
                    // Keep the first `limit` characters, preserving any spaces
                    // that fall within the limit (matches Spark's
                    // trimTrailingSpaces, which returns substring(0, limit)).
                    builder.append_value(&s[..byte_end]);
                }
            }
            None => builder.append_null(),
        }
    }
    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
5 changes: 4 additions & 1 deletion native/spark-expr/src/static_invoke/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,7 @@

mod char_varchar_utils;

pub use char_varchar_utils::{spark_lpad, spark_read_side_padding, spark_rpad};
pub use char_varchar_utils::{
spark_char_type_write_side_check, spark_lpad, spark_read_side_padding, spark_rpad,
spark_varchar_type_write_side_check,
};
6 changes: 5 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/statics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
: Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] =
Map(
("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
"read_side_padding"))
"read_side_padding"),
("charTypeWriteSideCheck", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
"char_type_write_side_check"),
("varcharTypeWriteSideCheck", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
"varchar_type_write_side_check"))

override def convert(
expr: StaticInvoke,
Expand Down
26 changes: 26 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2252,6 +2252,32 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("charTypeWriteSideCheck") {
  val table = "test"
  withTable(table) {
    sql(s"create table $table(col CHAR(5)) using parquet")
    // Cover: shorter than limit, exactly at limit, trailing spaces (equals
    // limit after trim+pad), and the empty string.
    Seq("ab", "abcde", "abc ", "").foreach { value =>
      sql(s"insert into $table values('$value')")
    }
    // Read back — CHAR(5) should pad to 5 characters
    checkSparkAnswerAndOperator(s"SELECT col FROM $table")
  }
}

test("varcharTypeWriteSideCheck") {
  val table = "test"
  withTable(table) {
    sql(s"create table $table(col VARCHAR(5)) using parquet")
    // Cover: shorter than limit, exactly at limit, trailing spaces within
    // limit, and the empty string.
    Seq("ab", "abcde", "abc ", "").foreach { value =>
      sql(s"insert into $table values('$value')")
    }
    // Read back — VARCHAR(5) should NOT pad
    checkSparkAnswerAndOperator(s"SELECT col FROM $table")
  }
}

test("isnan") {
Seq("true", "false").foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary) {
Expand Down