Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion encodings/alp/src/alp/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ mod tests {
use vortex_array::arrays::PrimitiveArray;
use vortex_array::assert_arrays_eq;
use vortex_array::session::ArraySession;
use vortex_error::VortexExpect;
use vortex_session::VortexSession;

use super::*;
Expand Down Expand Up @@ -775,7 +776,11 @@ mod tests {
for idx in 0..slice_len {
let expected_value = values[slice_start + idx];

let result_valid = result_primitive.validity().is_valid(idx).unwrap();
let result_valid = result_primitive
.validity()
.vortex_expect("result validity should be derivable")
.is_valid(idx)
.unwrap();
assert_eq!(
result_valid,
expected_value.is_some(),
Expand Down
2 changes: 1 addition & 1 deletion encodings/alp/src/alp/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ where
let (exponents, encoded, exceptional_positions, exceptional_values, mut chunk_offsets) =
T::encode(values_slice, exponents);

let encoded_array = PrimitiveArray::new(encoded, values.validity()).into_array();
let encoded_array = PrimitiveArray::new(encoded, values.validity()?).into_array();

let validity = values.validity_mask()?;
// exceptional_positions may contain exceptions at invalid positions (which contain garbage
Expand Down
7 changes: 5 additions & 2 deletions encodings/alp/src/alp/decompress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use vortex_array::dtype::DType;
use vortex_array::match_each_unsigned_integer_ptype;
use vortex_array::patches::Patches;
use vortex_buffer::BufferMut;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;

use crate::ALPArray;
Expand Down Expand Up @@ -102,7 +103,9 @@ fn decompress_chunked_core(
patches: &Patches,
dtype: DType,
) -> PrimitiveArray {
let validity = encoded.validity();
let validity = encoded
.validity()
.vortex_expect("ALP validity should be derivable");
let ptype = dtype.as_ptype();
let array_len = encoded.len();
let offset_within_chunk = patches.offset_within_chunk().unwrap_or(0);
Expand Down Expand Up @@ -152,7 +155,7 @@ fn decompress_unchunked_core(
dtype: DType,
ctx: &mut ExecutionCtx,
) -> VortexResult<PrimitiveArray> {
let validity = encoded.validity();
let validity = encoded.validity()?;
let ptype = dtype.as_ptype();

let decoded = match_each_alp_float_ptype!(ptype, |T| {
Expand Down
7 changes: 6 additions & 1 deletion encodings/alp/src/alp_rd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,12 @@ impl RDEncoder {
}

// Bit-pack down the encoded left-parts array that have been dictionary encoded.
let primitive_left = PrimitiveArray::new(left_parts, array.validity());
let primitive_left = PrimitiveArray::new(
left_parts,
array
.validity()
.vortex_expect("ALP RD validity should be derivable"),
);
// SAFETY: by construction, all values in left_parts can be packed to left_bit_width.
let packed_left = unsafe {
bitpack_encode_unchecked(primitive_left, left_bit_width as _)
Expand Down
12 changes: 7 additions & 5 deletions encodings/bytebool/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub type vortex_bytebool::ByteBool::ArrayData = vortex_bytebool::ByteBoolData

pub type vortex_bytebool::ByteBool::OperationsVTable = vortex_bytebool::ByteBool

pub type vortex_bytebool::ByteBool::ValidityVTable = vortex_array::array::vtable::validity::ValidityVTableFromValidityHelper
pub type vortex_bytebool::ByteBool::ValidityVTable = vortex_bytebool::ByteBool

pub fn vortex_bytebool::ByteBool::array_eq(array: &vortex_bytebool::ByteBoolData, other: &vortex_bytebool::ByteBoolData, precision: vortex_array::hash::Precision) -> bool

Expand Down Expand Up @@ -62,6 +62,10 @@ impl vortex_array::array::vtable::operations::OperationsVTable<vortex_bytebool::

pub fn vortex_bytebool::ByteBool::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_bytebool::ByteBool>, index: usize, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

impl vortex_array::array::vtable::validity::ValidityVTable<vortex_bytebool::ByteBool> for vortex_bytebool::ByteBool

pub fn vortex_bytebool::ByteBool::validity(array: vortex_array::array::view::ArrayView<'_, vortex_bytebool::ByteBool>) -> vortex_error::VortexResult<vortex_array::validity::Validity>

impl vortex_array::arrays::dict::take::TakeExecute for vortex_bytebool::ByteBool

pub fn vortex_bytebool::ByteBool::take(array: vortex_array::array::view::ArrayView<'_, Self>, indices: &vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::erased::ArrayRef>>
Expand Down Expand Up @@ -96,6 +100,8 @@ pub fn vortex_bytebool::ByteBoolData::new(buffer: vortex_array::buffer::BufferHa

pub fn vortex_bytebool::ByteBoolData::validate(buffer: &vortex_array::buffer::BufferHandle, validity: &vortex_array::validity::Validity, dtype: &vortex_array::dtype::DType, len: usize) -> vortex_error::VortexResult<()>

pub fn vortex_bytebool::ByteBoolData::validity(&self) -> vortex_array::validity::Validity

pub fn vortex_bytebool::ByteBoolData::validity_mask(&self) -> vortex_mask::Mask

impl core::clone::Clone for vortex_bytebool::ByteBoolData
Expand All @@ -114,8 +120,4 @@ impl core::fmt::Debug for vortex_bytebool::ByteBoolData

pub fn vortex_bytebool::ByteBoolData::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl vortex_array::array::vtable::validity::ValidityHelper for vortex_bytebool::ByteBoolData

pub fn vortex_bytebool::ByteBoolData::validity(&self) -> &vortex_array::validity::Validity

pub type vortex_bytebool::ByteBoolArray = vortex_array::array::typed::Array<vortex_bytebool::ByteBool>
40 changes: 21 additions & 19 deletions encodings/bytebool/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ use vortex_array::Precision;
use vortex_array::arrays::BoolArray;
use vortex_array::buffer::BufferHandle;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::scalar::Scalar;
use vortex_array::serde::ArrayChildren;
use vortex_array::validity::Validity;
use vortex_array::vtable;
use vortex_array::vtable::OperationsVTable;
use vortex_array::vtable::VTable;
use vortex_array::vtable::ValidityHelper;
use vortex_array::vtable::ValidityVTableFromValidityHelper;
use vortex_array::vtable::ValidityVTable;
use vortex_array::vtable::child_to_validity;
use vortex_array::vtable::validity_to_child;
use vortex_buffer::BitBuffer;
use vortex_buffer::ByteBuffer;
Expand All @@ -43,24 +44,25 @@ impl VTable for ByteBool {
type ArrayData = ByteBoolData;

type OperationsVTable = Self;
type ValidityVTable = ValidityVTableFromValidityHelper;
type ValidityVTable = Self;

fn id(&self) -> ArrayId {
Self::ID
}

fn validate(&self, data: &Self::ArrayData, dtype: &DType, len: usize) -> VortexResult<()> {
ByteBoolData::validate(data.buffer(), data.validity(), dtype, len)
let validity = data.validity();
ByteBoolData::validate(data.buffer(), &validity, dtype, len)
}

fn array_hash<H: std::hash::Hasher>(array: &ByteBoolData, state: &mut H, precision: Precision) {
array.buffer.array_hash(state, precision);
array.validity.array_hash(state, precision);
array.validity().array_hash(state, precision);
}

fn array_eq(array: &ByteBoolData, other: &ByteBoolData, precision: Precision) -> bool {
array.buffer.array_eq(&other.buffer, precision)
&& array.validity.array_eq(&other.validity, precision)
&& array.validity().array_eq(&other.validity(), precision)
}

fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
Expand Down Expand Up @@ -132,10 +134,6 @@ impl VTable for ByteBool {
NUM_SLOTS,
slots.len()
);
array.validity = match &slots[VALIDITY_SLOT] {
Some(arr) => Validity::Array(arr.clone()),
None => Validity::from(array.validity.nullability()),
};
array.slots = slots;
Ok(())
}
Expand All @@ -150,7 +148,7 @@ impl VTable for ByteBool {

fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
let boolean_buffer = BitBuffer::from(array.as_slice());
let validity = array.validity().clone();
let validity = array.validity()?;
Ok(ExecutionResult::done(
BoolArray::new(boolean_buffer, validity).into_array(),
))
Expand All @@ -174,7 +172,7 @@ pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
#[derive(Clone, Debug)]
pub struct ByteBoolData {
buffer: BufferHandle,
validity: Validity,
nullability: Nullability,
pub(super) slots: Vec<Option<ArrayRef>>,
}

Expand All @@ -194,15 +192,15 @@ impl ByteBool {
/// Construct a [`ByteBoolArray`] from a `Vec<bool>` and validity.
pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> ByteBoolArray {
let data = ByteBoolData::from_vec(data, validity);
let dtype = DType::Bool(data.validity.nullability());
let dtype = DType::Bool(data.nullability);
let len = data.len();
unsafe { Array::from_parts_unchecked(ArrayParts::new(ByteBool, dtype, len, data)) }
}

/// Construct a [`ByteBoolArray`] from optional bools.
pub fn from_option_vec(data: Vec<Option<bool>>) -> ByteBoolArray {
let data = ByteBoolData::from(data);
let dtype = DType::Bool(data.validity.nullability());
let dtype = DType::Bool(data.nullability);
let len = data.len();
unsafe { Array::from_parts_unchecked(ArrayParts::new(ByteBool, dtype, len, data)) }
}
Expand Down Expand Up @@ -235,6 +233,10 @@ impl ByteBoolData {
vec![validity_to_child(validity, len)]
}

pub fn validity(&self) -> Validity {
child_to_validity(&self.slots[VALIDITY_SLOT], self.nullability)
}

pub fn new(buffer: BufferHandle, validity: Validity) -> Self {
let length = buffer.len();
if let Some(vlen) = validity.maybe_len()
Expand All @@ -249,7 +251,7 @@ impl ByteBoolData {
let slots = Self::make_slots(&validity, length);
Self {
buffer,
validity,
nullability: validity.nullability(),
slots,
}
}
Expand All @@ -266,7 +268,7 @@ impl ByteBoolData {

/// Returns the validity mask for this array.
pub fn validity_mask(&self) -> Mask {
self.validity.to_mask(self.len())
self.validity().to_mask(self.len())
}

// TODO(ngates): deprecate construction from vec
Expand All @@ -287,9 +289,9 @@ impl ByteBoolData {
}
}

impl ValidityHelper for ByteBoolData {
fn validity(&self) -> &Validity {
&self.validity
impl ValidityVTable<ByteBool> for ByteBool {
fn validity(array: ArrayView<'_, ByteBool>) -> VortexResult<Validity> {
Ok(array.data().validity())
}
}

Expand Down
10 changes: 3 additions & 7 deletions encodings/bytebool/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ impl CastReduce for ByteBool {
// If just changing nullability, we can optimize
if array.dtype().eq_ignore_nullability(dtype) {
let new_validity = array
.validity()
.clone()
.validity()?
.cast_nullability(dtype.nullability(), array.len())?;

return Ok(Some(
Expand All @@ -45,10 +44,7 @@ impl MaskReduce for ByteBool {
Ok(Some(
ByteBool::new(
array.buffer().clone(),
array
.validity()
.clone()
.and(Validity::Array(mask.clone()))?,
array.validity()?.and(Validity::Array(mask.clone()))?,
)
.into_array(),
))
Expand All @@ -65,7 +61,7 @@ impl TakeExecute for ByteBool {
let bools = array.as_slice();

// This handles combining validity from both source array and nullable indices
let validity = array.validity().take(&indices.clone().into_array())?;
let validity = array.validity()?.take(&indices.clone().into_array())?;

let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
indices
Expand Down
2 changes: 1 addition & 1 deletion encodings/bytebool/src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ impl SliceReduce for ByteBool {
Ok(Some(
ByteBool::new(
array.buffer().slice(range.clone()),
array.validity().slice(range)?,
array.validity()?.slice(range)?,
)
.into_array(),
))
Expand Down
7 changes: 6 additions & 1 deletion encodings/datetime-parts/src/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,12 @@ mod test {
.execute::<PrimitiveArray>(&mut ctx)?;

assert_arrays_eq!(primitive_values, milliseconds);
assert!(primitive_values.validity().mask_eq(&validity, &mut ctx)?);
assert!(
primitive_values
.validity()
.unwrap()
.mask_eq(&validity, &mut ctx)?
);
Ok(())
}
}
14 changes: 11 additions & 3 deletions encodings/datetime-parts/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub fn split_temporal(array: TemporalArray) -> VortexResult<TemporalParts> {
}

Ok(TemporalParts {
days: PrimitiveArray::new(days, temporal_values.validity()).into_array(),
days: PrimitiveArray::new(days, temporal_values.validity()?).into_array(),
seconds: seconds.into_array(),
subseconds: subseconds.into_array(),
})
Expand Down Expand Up @@ -83,6 +83,7 @@ mod tests {
use vortex_array::extension::datetime::TimeUnit;
use vortex_array::validity::Validity;
use vortex_buffer::buffer;
use vortex_error::VortexExpect;

use crate::TemporalParts;
use crate::split_temporal;
Expand Down Expand Up @@ -114,15 +115,22 @@ mod tests {
assert!(
days.to_primitive()
.validity()
.vortex_expect("days validity should be derivable")
.mask_eq(&validity, &mut ctx)
.unwrap()
);
assert!(matches!(
seconds.to_primitive().validity(),
seconds
.to_primitive()
.validity()
.vortex_expect("seconds validity should be derivable"),
Validity::NonNullable
));
assert!(matches!(
subseconds.to_primitive().validity(),
subseconds
.to_primitive()
.validity()
.vortex_expect("subseconds validity should be derivable"),
Validity::NonNullable
));
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ fn to_canonical_decimal(
.dtype()
.as_decimal_opt()
.vortex_expect("must be a decimal dtype"),
prim.validity(),
prim.validity()?,
)
}
.into_array()
Expand Down
6 changes: 3 additions & 3 deletions encodings/fastlanes/src/bitpacking/array/bitpack_compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn bitpack_encode(
let bitpacked = BitPacked::try_new(
BufferHandle::new_host(packed),
array.ptype(),
array.validity(),
array.validity()?,
patches,
bit_width,
array.len(),
Expand Down Expand Up @@ -103,7 +103,7 @@ pub unsafe fn bitpack_encode_unchecked(
let bitpacked = BitPacked::try_new(
BufferHandle::new_host(packed),
array.ptype(),
array.validity(),
array.validity()?,
None,
bit_width,
array.len(),
Expand Down Expand Up @@ -191,7 +191,7 @@ pub fn gather_patches(
bit_width: u8,
num_exceptions_hint: usize,
) -> VortexResult<Option<Patches>> {
let patch_validity = match parray.validity() {
let patch_validity = match parray.validity()? {
Validity::NonNullable => Validity::NonNullable,
_ => Validity::AllValid,
};
Expand Down
2 changes: 1 addition & 1 deletion encodings/fastlanes/src/bitpacking/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ impl CastReduce for BitPacked {
fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
if array.dtype().eq_ignore_nullability(dtype) {
let new_validity = array
.validity(array.dtype().nullability())
.validity()?
.cast_nullability(dtype.nullability(), array.len())?;
return Ok(Some(
BitPacked::try_new(
Expand Down
Loading
Loading