@@ -39,7 +39,6 @@ use vortex_array::IntoArray;
3939use vortex_array:: arrays:: FixedSizeListArray ;
4040use vortex_array:: arrays:: PrimitiveArray ;
4141use vortex_array:: match_each_float_ptype;
42- use vortex_array:: validity:: Validity ;
4342use vortex_buffer:: BufferMut ;
4443use vortex_error:: VortexResult ;
4544use vortex_error:: vortex_ensure_eq;
@@ -54,6 +53,7 @@ use crate::utils::extension_element_ptype;
5453/// [`match_each_float_ptype!`].
5554#[ inline]
5655fn f32_to_t < T : FromPrimitive + Zero > ( v : f32 ) -> T {
56+ // TODO(connor): Is this actually correct? A failed conversion is silently replaced with `T::zero()` below. How should we handle f64 overflow? NOTE(review): `v` is an f32 here, so confirm which overflow direction is meant.
5757 FromPrimitive :: from_f32 ( v) . unwrap_or_else ( T :: zero)
5858}
5959
@@ -70,8 +70,8 @@ fn compute_unit_dots(
7070
7171 let lhs_codes_fsl: FixedSizeListArray = lhs. codes ( ) . clone ( ) . execute ( ctx) ?;
7272 let rhs_codes_fsl: FixedSizeListArray = rhs. codes ( ) . clone ( ) . execute ( ctx) ?;
73- let lhs_codes = lhs_codes_fsl. elements ( ) . to_canonical ( ) ? . into_primitive ( ) ;
74- let rhs_codes = rhs_codes_fsl. elements ( ) . to_canonical ( ) ? . into_primitive ( ) ;
73+ let lhs_codes: PrimitiveArray = lhs_codes_fsl. elements ( ) . clone ( ) . execute ( ctx ) ? ;
74+ let rhs_codes: PrimitiveArray = rhs_codes_fsl. elements ( ) . clone ( ) . execute ( ctx ) ? ;
7575 let ca = lhs_codes. as_slice :: < u8 > ( ) ;
7676 let cb = rhs_codes. as_slice :: < u8 > ( ) ;
7777
@@ -116,15 +116,19 @@ pub fn cosine_similarity_quantized_column(
116116 ) ;
117117
118118 let element_ptype = extension_element_ptype ( lhs. dtype ( ) . as_extension ( ) ) ?;
119+ let validity = lhs. norms ( ) . validity ( ) ?. and ( rhs. norms ( ) . validity ( ) ?) ?;
119120 let dots = compute_unit_dots ( & lhs, & rhs, ctx) ?;
120121
121122 // The unit-norm dot product IS the cosine similarity. Cast from f32 to the native type.
122123 match_each_float_ptype ! ( element_ptype, |T | {
123124 let mut result = BufferMut :: <T >:: with_capacity( dots. len( ) ) ;
124125 for & dot in & dots {
125- result. push( f32_to_t( dot) ) ;
126+ // SAFETY: We allocated the correct amount.
127+ unsafe { result. push_unchecked( f32_to_t( dot) ) } ;
126128 }
127- Ok ( PrimitiveArray :: new:: <T >( result. freeze( ) , Validity :: NonNullable ) . into_array( ) )
129+
130+ // SAFETY: `result` has the same length as the input arrays, matching `validity`.
131+ Ok ( unsafe { PrimitiveArray :: new_unchecked( result. freeze( ) , validity) } . into_array( ) )
128132 } )
129133}
130134
@@ -146,6 +150,7 @@ pub fn dot_product_quantized_column(
146150 ) ;
147151
148152 let element_ptype = extension_element_ptype ( lhs. dtype ( ) . as_extension ( ) ) ?;
153+ let validity = lhs. norms ( ) . validity ( ) ?. and ( rhs. norms ( ) . validity ( ) ?) ?;
149154 let dots = compute_unit_dots ( & lhs, & rhs, ctx) ?;
150155 let num_rows = lhs. norms ( ) . len ( ) ;
151156
@@ -160,9 +165,11 @@ pub fn dot_product_quantized_column(
160165 let mut result = BufferMut :: <T >:: with_capacity( num_rows) ;
161166 for row in 0 ..num_rows {
162167 let dot_t: T = f32_to_t( dots[ row] ) ;
163- result. push( na[ row] * nb[ row] * dot_t) ;
168+ // SAFETY: We allocated the correct amount.
169+ unsafe { result. push_unchecked( na[ row] * nb[ row] * dot_t) } ;
164170 }
165171
166- Ok ( PrimitiveArray :: new:: <T >( result. freeze( ) , Validity :: NonNullable ) . into_array( ) )
172+ // SAFETY: `result` has the same length as the input arrays, matching `validity`.
173+ Ok ( unsafe { PrimitiveArray :: new_unchecked( result. freeze( ) , validity) } . into_array( ) )
167174 } )
168175}
0 commit comments