Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions vortex-array/benches/cast_primitive.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::sync::Arc;

use arrow_array::UInt32Array;
use arrow_buffer::NullBuffer;
use arrow_cast::CastOptions;
use arrow_schema::DataType as ArrowDataType;
use divan::Bencher;
use rand::prelude::*;
use vortex_array::Canonical;
Expand All @@ -13,6 +19,9 @@ use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::dtype::PType;
use vortex_array::expr::stats::Stat;
use vortex_array::validity::Validity;
use vortex_buffer::BitBuffer;
use vortex_buffer::Buffer;

fn main() {
divan::main();
Expand Down Expand Up @@ -46,3 +55,52 @@ fn cast_u16_to_u32(bencher: Bencher) {
.execute::<Canonical>(&mut LEGACY_SESSION.create_execution_ctx())
});
}

// Slow-path inputs: u32 -> u8 with mixed validity, all values in-range (so the cast succeeds),
// no precomputed min/max stats — forces `cast_values` and the `Mask::Values` arm.
fn slow_path_inputs() -> (Vec<u32>, BitBuffer) {
let mut rng = StdRng::seed_from_u64(42);
let values: Vec<u32> = (0..N).map(|_| rng.random_range(0..=200u32)).collect();
let validity: BitBuffer = (0..N).map(|_| rng.random_bool(0.7)).collect();
(values, validity)
}

#[divan::bench]
fn cast_u32_u8_vortex(bencher: Bencher) {
let (values, validity) = slow_path_inputs();
let arr = PrimitiveArray::new(Buffer::from(values), Validity::from(validity)).into_array();
bencher.with_inputs(|| arr.clone()).bench_refs(|a| {
#[expect(clippy::unwrap_used)]
a.cast(DType::Primitive(PType::U8, Nullability::Nullable))
.unwrap()
.execute::<Canonical>(&mut LEGACY_SESSION.create_execution_ctx())
});
}

#[divan::bench]
fn cast_u32_u8_arrow(bencher: Bencher) {
let (values, validity) = slow_path_inputs();
let nulls = NullBuffer::from(validity.iter().collect::<Vec<_>>());
let arr: Arc<UInt32Array> = Arc::new(UInt32Array::new(values.into(), Some(nulls)));
let opts = CastOptions { safe: false, ..Default::default() };
bencher.with_inputs(|| Arc::clone(&arr)).bench_refs(|a| {
#[expect(clippy::unwrap_used)]
arrow_cast::cast_with_options(a.as_ref(), &ArrowDataType::UInt8, &opts).unwrap()
});
}

// Pure scalar baseline: no validity mask at all, checked cast on every element. Bails on
// the first overflow (which never happens for our in-range inputs).
#[divan::bench]
fn cast_u32_u8_checked_no_validity(bencher: Bencher) {
let (values, _) = slow_path_inputs();
bencher.with_inputs(|| values.clone()).bench_refs(|vs| {
let mut out = Vec::with_capacity(vs.len());
for &v in vs.iter() {
#[expect(clippy::expect_used)]
out.push(u8::try_from(v).expect("in-range"));
}
out
});
}

129 changes: 95 additions & 34 deletions vortex-array/src/arrays/primitive/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

use num_traits::AsPrimitive;
use num_traits::NumCast;
use vortex_buffer::BitBuffer;
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_mask::Mask;

use crate::ArrayRef;
Expand Down Expand Up @@ -102,9 +102,11 @@ impl CastKernel for Primitive {
}
}

/// Cast values from `F` to `T`. For infallible casts this is a pure pass; for fallible casts
/// each valid value goes through a checked `NumCast::from` and the kernel bails if any of them
/// overflow `T`. Invalid positions use the wrapping `as` cast since their values are masked out.
/// Cast values from `F` to `T`. For infallible casts this is a pure pass. For fallible casts
/// where cached stats can't prove fit, the hot loop is unconditional `as_()` + a parallel range
/// check whose results OR-reduce into a single `fail_acc` word — one pass, no `?` in the inner
/// body, fully SIMD-vectorizable. If `fail_acc` is set, a cold scalar pass walks the array to
/// attribute the failure to a specific index for a precise error message.
fn cast_values<F, T>(
array: ArrayView<'_, Primitive>,
new_validity: Validity,
Expand All @@ -116,43 +118,102 @@ where
{
let values = array.as_slice::<F>();

// Fast path: statically infallible, or cached min/max prove every valid value fits in `T`.
// The cached check never triggers a stats computation — if the bounds aren't already known
// we fall through to the per-lane loop below.
if values_always_fit(F::PTYPE, T::PTYPE) || values_fit_in(array, T::PTYPE, ctx, false) {
return Ok(PrimitiveArray::new(cast::<F, T>(values), new_validity).into_array());
}

// TODO(joe): if the values source and target have the same bit-width we can
// mutate in place.

// Fallible: invalid lanes are pre-multiplied to zero so the checked cast always succeeds for
// them; valid lanes go through `NumCast::from` and the whole cast bails on the first overflow.
let mask = array.validity()?.execute_mask(array.len(), ctx)?;
let overflow = || {
vortex_err!(
let mut buffer = BufferMut::<T>::zeroed(values.len());
let out = buffer.as_mut_slice();
let mut fail_acc: u32 = 0;

match &mask {
Mask::AllFalse(_) => {
// No valid lanes — buffer is already zeroed.
}
Mask::AllTrue(_) => {
for (i, &v) in values.iter().enumerate() {
out[i] = v.as_();
fail_acc |= <T as NumCast>::from(v).is_none() as u32;
}
}
Mask::Values(m) => {
fail_acc = fallible_cast_with_validity::<F, T>(values, m.bit_buffer(), out);
}
}

if fail_acc != 0 {
// Cold scalar fallback: identify the failing index for a precise error.
for (idx, (&v, valid)) in values.iter().zip(mask_iter(&mask, values.len())).enumerate() {
if valid && <T as NumCast>::from(v).is_none() {
vortex_bail!(
Compute: "Cannot cast {} to {} — value at index {} exceeds target range",
F::PTYPE, T::PTYPE, idx,
);
}
}
// Should be unreachable, but emit a generic error if the hot/cold paths disagree.
vortex_bail!(
Compute: "Cannot cast {} to {} — value exceeds target range",
F::PTYPE, T::PTYPE,
)
};
let buffer: Buffer<T> = match &mask {
Mask::AllTrue(_) => BufferMut::try_from_trusted_len_iter(
values
.iter()
.map(|&v| <T as NumCast>::from(v).ok_or_else(overflow)),
)?
.freeze(),
Mask::AllFalse(_) => BufferMut::<T>::zeroed(values.len()).freeze(),
Mask::Values(m) => BufferMut::try_from_trusted_len_iter(
values.iter().zip(m.bit_buffer().iter()).map(|(&v, valid)| {
let factor = if valid { F::one() } else { F::zero() };
<T as NumCast>::from(v * factor).ok_or_else(overflow)
}),
)?
.freeze(),
};

Ok(PrimitiveArray::new(buffer, new_validity).into_array())
);
}

Ok(PrimitiveArray::new(buffer.freeze(), new_validity).into_array())
}

/// Unconditional `as_()` cast of every lane in `values` into `out`, with a SIMD-reducible
/// overflow detector that returns a nonzero failure word iff any valid lane would overflow `T`.
/// Walks validity in 64-lane blocks (`from_fn` lane-mask + uniform inner body, fully unrollable)
/// and bails at the block boundary on the first failure — branch is outside the SIMD region.
#[inline]
fn fallible_cast_with_validity<F, T>(
values: &[F],
bit_buffer: &BitBuffer,
out: &mut [T],
) -> u32
where
F: NativePType + AsPrimitive<T>,
T: NativePType,
{
debug_assert_eq!(values.len(), bit_buffer.len());
debug_assert_eq!(values.len(), out.len());
let bit_chunks = bit_buffer.chunks();
let mut fail_acc: u32 = 0;
let mut idx = 0usize;
for word in bit_chunks.iter() {
let valid: [bool; 64] = std::array::from_fn(|i| (word >> i) & 1 != 0);
for i in 0..64 {
let v = values[idx + i];
out[idx + i] = v.as_();
// Mask invalid lanes to F::zero (always fits any T) so they don't pollute fail_acc.
let v_for_check = if valid[i] { v } else { F::zero() };
fail_acc |= <T as NumCast>::from(v_for_check).is_none() as u32;
}
idx += 64;
if fail_acc != 0 {
return fail_acc;
}
}
let rem = bit_chunks.remainder_bits();
for b in 0..bit_chunks.remainder_len() {
let v = values[idx + b];
out[idx + b] = v.as_();
let valid = (rem >> b) & 1 != 0;
let v_for_check = if valid { v } else { F::zero() };
fail_acc |= <T as NumCast>::from(v_for_check).is_none() as u32;
}
fail_acc
}

/// Cold-path iterator over a `Mask` as a sequence of `bool`s. Only used after `fail_acc != 0`
/// to attribute the failure to a specific index.
fn mask_iter<'a>(mask: &'a Mask, len: usize) -> Box<dyn Iterator<Item = bool> + 'a> {
match mask {
Mask::AllTrue(_) => Box::new(std::iter::repeat_n(true, len)),
Mask::AllFalse(_) => Box::new(std::iter::repeat_n(false, len)),
Mask::Values(m) => Box::new(m.bit_buffer().iter()),
}
}

/// Out-of-range values at invalid positions are truncated/wrapped by `as`, which is fine because
Expand Down
Loading