Skip to content

Commit 77a4288

Browse files
committed
fix validity handling
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 849d6a5 commit 77a4288

1 file changed

Lines changed: 14 additions & 7 deletions

File tree

vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,6 @@ use vortex_array::IntoArray;
3939
use vortex_array::arrays::FixedSizeListArray;
4040
use vortex_array::arrays::PrimitiveArray;
4141
use vortex_array::match_each_float_ptype;
42-
use vortex_array::validity::Validity;
4342
use vortex_buffer::BufferMut;
4443
use vortex_error::VortexResult;
4544
use vortex_error::vortex_ensure_eq;
@@ -54,6 +53,7 @@ use crate::utils::extension_element_ptype;
5453
/// [`match_each_float_ptype!`].
5554
#[inline]
5655
fn f32_to_t<T: FromPrimitive + Zero>(v: f32) -> T {
56+
// TODO(connor): Is this actually correct? How should we handle f64 overflow?
5757
FromPrimitive::from_f32(v).unwrap_or_else(T::zero)
5858
}
5959

@@ -70,8 +70,8 @@ fn compute_unit_dots(
7070

7171
let lhs_codes_fsl: FixedSizeListArray = lhs.codes().clone().execute(ctx)?;
7272
let rhs_codes_fsl: FixedSizeListArray = rhs.codes().clone().execute(ctx)?;
73-
let lhs_codes = lhs_codes_fsl.elements().to_canonical()?.into_primitive();
74-
let rhs_codes = rhs_codes_fsl.elements().to_canonical()?.into_primitive();
73+
let lhs_codes: PrimitiveArray = lhs_codes_fsl.elements().clone().execute(ctx)?;
74+
let rhs_codes: PrimitiveArray = rhs_codes_fsl.elements().clone().execute(ctx)?;
7575
let ca = lhs_codes.as_slice::<u8>();
7676
let cb = rhs_codes.as_slice::<u8>();
7777

@@ -116,15 +116,19 @@ pub fn cosine_similarity_quantized_column(
116116
);
117117

118118
let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?;
119+
let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?;
119120
let dots = compute_unit_dots(&lhs, &rhs, ctx)?;
120121

121122
// The unit-norm dot product IS the cosine similarity. Cast from f32 to the native type.
122123
match_each_float_ptype!(element_ptype, |T| {
123124
let mut result = BufferMut::<T>::with_capacity(dots.len());
124125
for &dot in &dots {
125-
result.push(f32_to_t(dot));
126+
// SAFETY: We allocated the correct amount.
127+
unsafe { result.push_unchecked(f32_to_t(dot)) };
126128
}
127-
Ok(PrimitiveArray::new::<T>(result.freeze(), Validity::NonNullable).into_array())
129+
130+
// SAFETY: `result` has the same length as the input arrays, matching `validity`.
131+
Ok(unsafe { PrimitiveArray::new_unchecked(result.freeze(), validity) }.into_array())
128132
})
129133
}
130134

@@ -146,6 +150,7 @@ pub fn dot_product_quantized_column(
146150
);
147151

148152
let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?;
153+
let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?;
149154
let dots = compute_unit_dots(&lhs, &rhs, ctx)?;
150155
let num_rows = lhs.norms().len();
151156

@@ -160,9 +165,11 @@ pub fn dot_product_quantized_column(
160165
let mut result = BufferMut::<T>::with_capacity(num_rows);
161166
for row in 0..num_rows {
162167
let dot_t: T = f32_to_t(dots[row]);
163-
result.push(na[row] * nb[row] * dot_t);
168+
// SAFETY: We allocated the correct amount.
169+
unsafe { result.push_unchecked(na[row] * nb[row] * dot_t) };
164170
}
165171

166-
Ok(PrimitiveArray::new::<T>(result.freeze(), Validity::NonNullable).into_array())
172+
// SAFETY: `result` has the same length as the input arrays, matching `validity`.
173+
Ok(unsafe { PrimitiveArray::new_unchecked(result.freeze(), validity) }.into_array())
167174
})
168175
}

0 commit comments

Comments
 (0)