Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
321 changes: 319 additions & 2 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,204 @@ impl ScalarValue {
}
}

#[inline]
fn can_use_direct_add(lhs: &ScalarValue, rhs: &ScalarValue) -> bool {
matches!(
(lhs, rhs),
(ScalarValue::Int8(_), ScalarValue::Int8(_))
| (ScalarValue::Int16(_), ScalarValue::Int16(_))
| (ScalarValue::Int32(_), ScalarValue::Int32(_))
| (ScalarValue::Int64(_), ScalarValue::Int64(_))
| (ScalarValue::UInt8(_), ScalarValue::UInt8(_))
| (ScalarValue::UInt16(_), ScalarValue::UInt16(_))
| (ScalarValue::UInt32(_), ScalarValue::UInt32(_))
| (ScalarValue::UInt64(_), ScalarValue::UInt64(_))
| (ScalarValue::Float16(_), ScalarValue::Float16(_))
| (ScalarValue::Float32(_), ScalarValue::Float32(_))
| (ScalarValue::Float64(_), ScalarValue::Float64(_))
| (
ScalarValue::Decimal32(_, _, _),
ScalarValue::Decimal32(_, _, _)
)
| (
ScalarValue::Decimal64(_, _, _),
ScalarValue::Decimal64(_, _, _)
)
| (
ScalarValue::Decimal128(_, _, _),
ScalarValue::Decimal128(_, _, _),
)
| (
ScalarValue::Decimal256(_, _, _),
ScalarValue::Decimal256(_, _, _),
)
)
}

/// Adds `rhs` into `lhs` in place with SQL NULL semantics: a NULL `rhs`
/// makes the result NULL, and a NULL `lhs` stays NULL. When `checked` is
/// true an overflowing add returns an error; otherwise it wraps.
#[inline]
fn add_optional<T: ArrowNativeTypeOp>(
    lhs: &mut Option<T>,
    rhs: Option<T>,
    checked: bool,
) -> Result<()> {
    // NULL + anything is NULL, so a missing rhs clears the accumulator.
    let Some(rhs) = rhs else {
        *lhs = None;
        return Ok(());
    };
    // A NULL lhs absorbs the addition; only a present value is updated.
    if let Some(value) = lhs.as_mut() {
        *value = if checked {
            value.add_checked(rhs).map_err(|e| arrow_datafusion_err!(e))?
        } else {
            value.add_wrapping(rhs)
        };
    }
    Ok(())
}

/// Adds `rhs` into `lhs` in place for decimal scalars, first aligning both
/// operands to a common (larger) scale.
///
/// On success the `lhs_*` out-params hold the sum and the widened result
/// precision/scale. On error the `lhs_*` values are left unchanged, because
/// they are only written after every fallible step has succeeded.
#[inline]
fn add_decimal_values<T: DecimalType>(
    lhs_value: &mut Option<T::Native>,
    lhs_precision: &mut u8,
    lhs_scale: &mut i8,
    rhs_value: Option<T::Native>,
    rhs_precision: u8,
    rhs_scale: i8,
) -> Result<()>
where
    T::Native: ArrowNativeTypeOp,
{
    // Result scale is the finer of the two; result precision must hold the
    // larger integer part plus one carry digit, capped at the decimal
    // type's maximum precision.
    let result_scale = (*lhs_scale).max(rhs_scale);
    let result_precision = (result_scale.saturating_add(
        (*lhs_precision as i8 - *lhs_scale).max(rhs_precision as i8 - rhs_scale),
    ) as u8)
        .saturating_add(1)
        .min(T::MAX_PRECISION);

    Self::validate_decimal_or_internal_err::<T>(result_precision, result_scale)?;

    // Rescale factors (powers of ten) to bring each operand up to the
    // common result scale; pow_checked errors rather than silently wrapping.
    let lhs_mul = T::Native::usize_as(10)
        .pow_checked((result_scale - *lhs_scale) as u32)
        .map_err(|e| arrow_datafusion_err!(e))?;
    let rhs_mul = T::Native::usize_as(10)
        .pow_checked((result_scale - rhs_scale) as u32)
        .map_err(|e| arrow_datafusion_err!(e))?;

    // NULL-propagating checked arithmetic: any overflow while rescaling or
    // in the final add surfaces as an error before lhs_* is mutated.
    let result_value = match (*lhs_value, rhs_value) {
        (Some(lhs_value), Some(rhs_value)) => Some(
            lhs_value
                .mul_checked(lhs_mul)
                .and_then(|lhs| {
                    rhs_value
                        .mul_checked(rhs_mul)
                        .and_then(|rhs| lhs.add_checked(rhs))
                })
                .map_err(|e| arrow_datafusion_err!(e))?,
        ),
        _ => None,
    };

    *lhs_value = result_value;
    *lhs_precision = result_precision;
    *lhs_scale = result_scale;

    Ok(())
}

#[inline]
fn try_add_in_place_impl(
&mut self,
other: &ScalarValue,
checked: bool,
) -> Result<bool> {
match (self, other) {
(ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
Self::add_optional(lhs, *rhs, checked)?;
}
(ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
Self::add_optional(lhs, *rhs, checked)?;
}
(ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
Self::add_optional(lhs, *rhs, checked)?;
}
(ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
Self::add_optional(lhs, *rhs, checked)?;
}
(ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
Self::add_optional(lhs, *rhs, checked)?;
}
(ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
Self::add_optional(lhs, *rhs, checked)?;
}
(ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
Self::add_optional(lhs, *rhs, checked)?;
}
(ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
Self::add_optional(lhs, *rhs, checked)?;
}
(ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => {
Self::add_optional(lhs, *rhs, checked)?;
}
(ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
Self::add_optional(lhs, *rhs, checked)?;
}
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
Self::add_optional(lhs, *rhs, checked)?;
}
(
ScalarValue::Decimal32(lhs, p, s),
ScalarValue::Decimal32(rhs, rhs_p, rhs_s),
) => {
Self::add_decimal_values::<Decimal32Type>(
lhs, p, s, *rhs, *rhs_p, *rhs_s,
)?;
}
(
ScalarValue::Decimal64(lhs, p, s),
ScalarValue::Decimal64(rhs, rhs_p, rhs_s),
) => {
Self::add_decimal_values::<Decimal64Type>(
lhs, p, s, *rhs, *rhs_p, *rhs_s,
)?;
}
(
ScalarValue::Decimal128(lhs, p, s),
ScalarValue::Decimal128(rhs, rhs_p, rhs_s),
) => {
Self::add_decimal_values::<Decimal128Type>(
lhs, p, s, *rhs, *rhs_p, *rhs_s,
)?;
}
(
ScalarValue::Decimal256(lhs, p, s),
ScalarValue::Decimal256(rhs, rhs_p, rhs_s),
) => {
Self::add_decimal_values::<Decimal256Type>(
lhs, p, s, *rhs, *rhs_p, *rhs_s,
)?;
}
_ => return Ok(false),
}

Ok(true)
}

/// In-place wrapping addition of `other` into `self`.
///
/// Returns `Ok(true)` if this variant pair was handled by the direct fast
/// path, `Ok(false)` if the caller must fall back to Arrow kernels.
#[inline]
pub(crate) fn try_add_wrapping_in_place(
    &mut self,
    other: &ScalarValue,
) -> Result<bool> {
    self.try_add_in_place_impl(other, false)
}

/// In-place overflow-checked addition of `other` into `self`.
///
/// Returns `Ok(true)` if this variant pair was handled by the direct fast
/// path, `Ok(false)` if the caller must fall back to Arrow kernels; an
/// arithmetic overflow is reported as `Err`.
#[inline]
pub(crate) fn try_add_checked_in_place(
    &mut self,
    other: &ScalarValue,
) -> Result<bool> {
    self.try_add_in_place_impl(other, true)
}

/// Calculate arithmetic negation for a scalar value
pub fn arithmetic_negate(&self) -> Result<Self> {
fn neg_checked_with_ctx<T: ArrowNativeTypeOp>(
Expand Down Expand Up @@ -2135,7 +2333,16 @@ impl ScalarValue {
/// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code
/// should operate on Arrays directly, using vectorized array kernels
pub fn add<T: Borrow<ScalarValue>>(&self, other: T) -> Result<ScalarValue> {
    let rhs = other.borrow();
    // Same-variant primitive/decimal pairs take the direct in-place add,
    // skipping the scalar -> array -> scalar round trip entirely.
    if Self::can_use_direct_add(self, rhs) {
        let mut sum = self.clone();
        if sum.try_add_wrapping_in_place(rhs)? {
            return Ok(sum);
        }
        // Eligibility said yes but the impl declined: stay correct in
        // release builds via the fallback; make the drift loud in debug.
        debug_assert!(false, "fast-path eligibility drifted from implementation");
    }

    let result = add_wrapping(&self.to_scalar()?, &rhs.to_scalar()?)?;
    Self::try_from_array(result.as_ref(), 0)
}

Expand All @@ -2144,7 +2351,16 @@ impl ScalarValue {
/// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code
/// should operate on Arrays directly, using vectorized array kernels
pub fn add_checked<T: Borrow<ScalarValue>>(&self, other: T) -> Result<ScalarValue> {
    let rhs = other.borrow();
    // Same-variant primitive/decimal pairs take the direct in-place add,
    // skipping the scalar -> array -> scalar round trip entirely.
    if Self::can_use_direct_add(self, rhs) {
        let mut sum = self.clone();
        if sum.try_add_checked_in_place(rhs)? {
            return Ok(sum);
        }
        // Eligibility said yes but the impl declined: stay correct in
        // release builds via the fallback; make the drift loud in debug.
        debug_assert!(false, "fast-path eligibility drifted from implementation");
    }

    let result = add(&self.to_scalar()?, &rhs.to_scalar()?)?;
    Self::try_from_array(result.as_ref(), 0)
}

Expand Down Expand Up @@ -5945,6 +6161,54 @@ mod tests {
Ok(())
}

#[test]
fn scalar_add_trait_null_test() -> Result<()> {
    // NULL on either side of an addition must yield NULL.
    let lhs = ScalarValue::Int32(Some(42));
    let sum = lhs.add(ScalarValue::Int32(None))?;
    assert_eq!(sum, ScalarValue::Int32(None));
    Ok(())
}

#[test]
fn scalar_add_trait_wrapping_overflow_test() -> Result<()> {
    // `add` uses wrapping semantics: i32::MAX + 1 wraps around to i32::MIN.
    let max = ScalarValue::Int32(Some(i32::MAX));
    assert_eq!(
        max.add(ScalarValue::Int32(Some(1)))?,
        ScalarValue::Int32(Some(i32::MIN))
    );
    Ok(())
}

#[test]
fn scalar_add_trait_decimal_scale_test() -> Result<()> {
    // 1.23 (p=10, s=2) + 0.4 (p=9, s=1) = 1.63 with widened precision 11.
    let lhs = ScalarValue::Decimal128(Some(123), 10, 2);
    let rhs = ScalarValue::Decimal128(Some(4), 9, 1);
    let sum = lhs.add(rhs)?;
    assert_eq!(sum, ScalarValue::Decimal128(Some(163), 11, 2));
    Ok(())
}

#[test]
fn scalar_add_trait_decimal256_scale_test() -> Result<()> {
    // Same scale-alignment behavior as Decimal128, but for Decimal256.
    let lhs = ScalarValue::Decimal256(Some(i256::from(123)), 10, 2);
    let rhs = ScalarValue::Decimal256(Some(i256::from(4)), 9, 1);
    let sum = lhs.add(rhs)?;
    assert_eq!(sum, ScalarValue::Decimal256(Some(i256::from(163)), 11, 2));
    Ok(())
}

#[test]
fn scalar_sub_trait_test() -> Result<()> {
let float_value = ScalarValue::Float64(Some(123.));
Expand Down Expand Up @@ -6044,6 +6308,43 @@ mod tests {
Ok(())
}

#[test]
fn scalar_decimal_add_overflow_test() {
    // MAX + 1 at max precision: the scalar fast path must agree with the
    // Arrow kernel on whether the operation succeeds.
    let lhs128 =
        ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0);
    let rhs128 = ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, 0);
    check_scalar_decimal_add_overflow::<Decimal128Type>(lhs128, rhs128);

    let lhs256 =
        ScalarValue::Decimal256(Some(i256::MAX), DECIMAL256_MAX_PRECISION, 0);
    let rhs256 =
        ScalarValue::Decimal256(Some(i256::ONE), DECIMAL256_MAX_PRECISION, 0);
    check_scalar_decimal_add_overflow::<Decimal256Type>(lhs256, rhs256);
}

#[test]
fn scalar_decimal_in_place_add_error_preserves_lhs() {
    // A failed checked add must leave the accumulator untouched.
    let mut acc =
        ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0);
    let before = acc.clone();

    let rhs = ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, 0);
    let err = acc
        .try_add_checked_in_place(&rhs)
        .unwrap_err()
        .strip_backtrace();

    let expected = format!(
        "Arrow error: Arithmetic overflow: Overflow happened on: {} + 1",
        i128::MAX
    );
    assert_eq!(err, expected);
    assert_eq!(acc, before);
}

// Verifies that ScalarValue has the same behavior with compute kernel when it overflows.
fn check_scalar_add_overflow<T>(left: ScalarValue, right: ScalarValue)
where
Expand All @@ -6060,6 +6361,22 @@ mod tests {
assert_eq!(scalar_result.is_ok(), arrow_result.is_ok());
}

// Verifies the decimal fast path preserves the same overflow behavior as
// Arrow kernels: only the success/failure outcome is compared, not values.
fn check_scalar_decimal_add_overflow<T>(left: ScalarValue, right: ScalarValue)
where
    T: ArrowPrimitiveType,
{
    let scalar_result = left.add(&right);

    // Run the identical addition through the vectorized kernel.
    let lhs_array = left.to_array().expect("Failed to convert to array");
    let rhs_array = right.to_array().expect("Failed to convert to array");
    let kernel_result =
        add_wrapping(lhs_array.as_primitive::<T>(), rhs_array.as_primitive::<T>());

    assert_eq!(scalar_result.is_ok(), kernel_result.is_ok());
}

#[test]
fn test_interval_add_timestamp() -> Result<()> {
let interval = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
Expand Down
9 changes: 4 additions & 5 deletions datafusion/common/src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,17 +671,16 @@ impl Statistics {
.collect();

// Accumulate all statistics in a single pass.
// Uses precision_add for sum (avoids the expensive
// ScalarValue::add round-trip through Arrow arrays), and
// Precision::min/max which use cheap PartialOrd comparison.
// Uses precision_add for sum (reuses the lhs accumulator for
// direct numeric addition), and Precision::min/max which use
// cheap PartialOrd comparison.
for stat in items.iter().skip(1) {
for (col_idx, col_stats) in column_statistics.iter_mut().enumerate() {
let item_cs = &stat.column_statistics[col_idx];

col_stats.null_count = col_stats.null_count.add(&item_cs.null_count);
col_stats.byte_size = col_stats.byte_size.add(&item_cs.byte_size);
col_stats.sum_value =
precision_add(&col_stats.sum_value, &item_cs.sum_value);
precision_add(&mut col_stats.sum_value, &item_cs.sum_value);
col_stats.min_value = col_stats.min_value.min(&item_cs.min_value);
col_stats.max_value = col_stats.max_value.max(&item_cs.max_value);
}
Expand Down
Loading
Loading