Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion native/core/src/execution/expressions/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ use datafusion::logical_expr::ScalarUDF;
use datafusion::physical_expr::{PhysicalExpr, ScalarFunctionExpr};
use datafusion_comet_proto::spark_expression::Expr;
use datafusion_comet_spark_expr::{
SparkHour, SparkMinute, SparkSecond, SparkUnixTimestamp, TimestampTruncExpr,
SparkHour, SparkHoursTransform, SparkMinute, SparkSecond, SparkUnixTimestamp,
TimestampTruncExpr,
};

use crate::execution::{
Expand Down Expand Up @@ -160,3 +161,29 @@ impl ExpressionBuilder for TruncTimestampBuilder {
Ok(Arc::new(TimestampTruncExpr::new(child, format, timezone)))
}
}

/// Builds the physical expression for Spark's `hours` V2 partition transform.
pub struct HoursTransformBuilder;

impl ExpressionBuilder for HoursTransformBuilder {
    fn build(
        &self,
        spark_expr: &Expr,
        input_schema: SchemaRef,
        planner: &PhysicalPlanner,
    ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
        let hours_transform = extract_expr!(spark_expr, HoursTransform);

        // Plan the single child (the timestamp column) against the input schema.
        let child_expr = planner.create_expr(
            hours_transform.child.as_ref().unwrap(),
            Arc::clone(&input_schema),
        )?;

        // The transform is implemented as a scalar UDF producing a nullable Int32.
        let udf = Arc::new(ScalarUDF::new_from_impl(SparkHoursTransform::new()));
        let return_field = Arc::new(Field::new("hours_transform", DataType::Int32, true));

        Ok(Arc::new(ScalarFunctionExpr::new(
            "hours_transform",
            udf,
            vec![child_expr],
            return_field,
            Arc::new(ConfigOptions::default()),
        )))
    }
}
6 changes: 6 additions & 0 deletions native/core/src/execution/planner/expression_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ pub enum ExpressionType {
Second,
TruncTimestamp,
UnixTimestamp,
HoursTransform,
}

/// Registry for expression builders
Expand Down Expand Up @@ -310,6 +311,10 @@ impl ExpressionRegistry {
ExpressionType::TruncTimestamp,
Box::new(TruncTimestampBuilder),
);
self.builders.insert(
ExpressionType::HoursTransform,
Box::new(HoursTransformBuilder),
);
}

/// Extract expression type from Spark protobuf expression
Expand Down Expand Up @@ -382,6 +387,7 @@ impl ExpressionRegistry {
Some(ExprStruct::Second(_)) => Ok(ExpressionType::Second),
Some(ExprStruct::TruncTimestamp(_)) => Ok(ExpressionType::TruncTimestamp),
Some(ExprStruct::UnixTimestamp(_)) => Ok(ExpressionType::UnixTimestamp),
Some(ExprStruct::HoursTransform(_)) => Ok(ExpressionType::HoursTransform),

Some(other) => Err(ExecutionError::GeneralError(format!(
"Unsupported expression type: {:?}",
Expand Down
5 changes: 5 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ message Expr {
UnixTimestamp unix_timestamp = 65;
FromJson from_json = 66;
ToCsv to_csv = 67;
HoursTransform hours_transform = 68;
}

// Optional QueryContext for error reporting (contains SQL text and position)
Expand Down Expand Up @@ -356,6 +357,10 @@ message Hour {
string timezone = 2;
}

// Spark `hours` V2 partition transform: hours since the Unix epoch,
// computed from a microsecond-precision timestamp child expression.
message HoursTransform {
Expr child = 1;
}

message Minute {
Expr child = 1;
string timezone = 2;
Expand Down
281 changes: 281 additions & 0 deletions native/spark-expr/src/datetime_funcs/hours.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Spark-compatible `hours` V2 partition transform.
//!
//! Computes the number of hours since the Unix epoch (1970-01-01 00:00:00 UTC).
//!
//! Both `TimestampType` and `TimestampNTZType` are computationally identical. They
//! extract the absolute hours since the epoch by directly dividing the microsecond
//! value by the number of microseconds in an hour, ignoring session timezone offsets.

use arrow::array::cast::as_primitive_array;
use arrow::array::types::TimestampMicrosecondType;
use arrow::array::{Array, Int32Array};
use arrow::datatypes::{DataType, TimeUnit::Microsecond};
use datafusion::common::{internal_datafusion_err, DataFusionError};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use num::integer::div_floor;
use std::{any::Any, fmt::Debug, sync::Arc};

// Number of microseconds in one hour (3600 * 1_000_000).
const MICROS_PER_HOUR: i64 = 3_600_000_000;

/// Scalar UDF implementing Spark's `hours` V2 partition transform
/// (hours since the Unix epoch) for microsecond timestamps.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkHoursTransform {
    // User-defined signature; argument types are validated in `invoke_with_args`.
    signature: Signature,
}

impl SparkHoursTransform {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
}
}
}

impl Default for SparkHoursTransform {
fn default() -> Self {
Self::new()
}
}

impl ScalarUDFImpl for SparkHoursTransform {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "hours_transform"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
        Ok(DataType::Int32)
    }

    /// Maps each microsecond timestamp to floored hours since the epoch.
    ///
    /// Only microsecond-precision timestamp arrays are accepted; the timezone
    /// component of the type is ignored (the raw epoch micros are divided
    /// directly). Scalar inputs are rejected — Spark folds those on the JVM side.
    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion::common::Result<ColumnarValue> {
        let mut inputs = args.args;
        if inputs.len() != 1 {
            return Err(internal_datafusion_err!(
                "hours_transform expects exactly one argument"
            ));
        }

        match inputs.remove(0) {
            ColumnarValue::Array(array) => match array.data_type() {
                DataType::Timestamp(Microsecond, _) => {
                    let micros = as_primitive_array::<TimestampMicrosecondType>(&array);
                    // `unary` preserves the null buffer; div_floor gives Spark's
                    // floor semantics for pre-epoch (negative) timestamps.
                    let hours: Int32Array =
                        arrow::compute::kernels::arity::unary(micros, |v| {
                            div_floor(v, MICROS_PER_HOUR) as i32
                        });
                    Ok(ColumnarValue::Array(Arc::new(hours)))
                }
                other => Err(DataFusionError::Execution(format!(
                    "hours_transform does not support input type: {:?}",
                    other
                ))),
            },
            _ => Err(DataFusionError::Execution(
                "hours_transform(scalar) should be folded on Spark JVM side.".to_string(),
            )),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::TimestampMicrosecondArray;
    use arrow::datatypes::Field;
    use datafusion::config::ConfigOptions;
    use std::sync::Arc;

    /// Invokes the `hours_transform` UDF on a single-row timestamp array and
    /// returns the resulting Int32 array. Centralizes the `ScalarFunctionArgs`
    /// boilerplate that was previously duplicated verbatim in every test.
    fn invoke_hours(ts: TimestampMicrosecondArray) -> Int32Array {
        let udf = SparkHoursTransform::new();
        let return_field = Arc::new(Field::new("hours_transform", DataType::Int32, true));
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::new(ts))],
            number_rows: 1,
            return_field,
            config_options: Arc::new(ConfigOptions::default()),
            arg_fields: vec![],
        };
        match udf.invoke_with_args(args).unwrap() {
            ColumnarValue::Array(arr) => arr
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("hours_transform should produce an Int32Array")
                .clone(),
            _ => panic!("Expected array"),
        }
    }

    #[test]
    fn test_hours_transform_utc() {
        // 2023-10-01 14:30:00 UTC = 1696171800 seconds = 1696171800000000 micros
        // Expected hours since epoch = 1696171800000000 / 3600000000 = 471158
        let ts = TimestampMicrosecondArray::from(vec![Some(1_696_171_800_000_000i64)])
            .with_timezone("UTC");
        assert_eq!(invoke_hours(ts).value(0), 471158);
    }

    #[test]
    fn test_hours_transform_ntz() {
        // Same timestamp but NTZ (no timezone on array)
        let ts = TimestampMicrosecondArray::from(vec![Some(1_696_171_800_000_000i64)]);
        assert_eq!(invoke_hours(ts).value(0), 471158);
    }

    #[test]
    fn test_hours_transform_negative_epoch() {
        // 1969-12-31 23:30:00 UTC = -1800 seconds = -1800000000 micros
        // Expected: floor_div(-1800000000, 3600000000) = -1
        let ts =
            TimestampMicrosecondArray::from(vec![Some(-1_800_000_000i64)]).with_timezone("UTC");
        assert_eq!(invoke_hours(ts).value(0), -1);
    }

    #[test]
    fn test_hours_transform_null() {
        // Nulls must propagate through the kernel untouched.
        let ts = TimestampMicrosecondArray::from(vec![None as Option<i64>]).with_timezone("UTC");
        assert!(invoke_hours(ts).is_null(0));
    }

    #[test]
    fn test_hours_transform_epoch_zero() {
        let ts = TimestampMicrosecondArray::from(vec![Some(0i64)]).with_timezone("UTC");
        assert_eq!(invoke_hours(ts).value(0), 0);
    }

    #[test]
    fn test_hours_transform_non_utc_timezone() {
        // Spark's Hours partition transform evaluates absolute hours since epoch. Thus, a UTC
        // timestamp of 1970-01-01 00:00:00 UTC (micros=0) maps to 0 hours, even if the
        // timestamp array itself contains timezone metadata like Asia/Tokyo.
        let ts = TimestampMicrosecondArray::from(vec![Some(0i64)]).with_timezone("Asia/Tokyo");
        assert_eq!(invoke_hours(ts).value(0), 0);
    }

    #[test]
    fn test_hours_transform_ntz_ignores_timezone() {
        // NTZ with micros=0 always returns 0 because NTZ is pure wall-clock time.
        // There is no timezone offset logic applied to either TimestampType or NTZ.
        let ts = TimestampMicrosecondArray::from(vec![Some(0i64)]); // No timezone on array
        assert_eq!(invoke_hours(ts).value(0), 0); // NOT 9, because NTZ ignores timezone
    }
}
2 changes: 2 additions & 0 deletions native/spark-expr/src/datetime_funcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
mod date_diff;
mod date_trunc;
mod extract_date_part;
mod hours;
mod make_date;
mod timestamp_trunc;
mod unix_timestamp;
Expand All @@ -27,6 +28,7 @@ pub use date_trunc::SparkDateTrunc;
pub use extract_date_part::SparkHour;
pub use extract_date_part::SparkMinute;
pub use extract_date_part::SparkSecond;
pub use hours::SparkHoursTransform;
pub use make_date::SparkMakeDate;
pub use timestamp_trunc::TimestampTruncExpr;
pub use unix_timestamp::SparkUnixTimestamp;
4 changes: 2 additions & 2 deletions native/spark-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ pub use comet_scalar_funcs::{
};
pub use csv_funcs::*;
pub use datetime_funcs::{
SparkDateDiff, SparkDateTrunc, SparkHour, SparkMakeDate, SparkMinute, SparkSecond,
SparkUnixTimestamp, TimestampTruncExpr,
SparkDateDiff, SparkDateTrunc, SparkHour, SparkHoursTransform, SparkMakeDate, SparkMinute,
SparkSecond, SparkUnixTimestamp, TimestampTruncExpr,
};
pub use error::{decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult};
pub use hash_funcs::*;
Expand Down
Loading
Loading