Skip to content

Commit 83901ac

Browse files
committed
gauntlet: cycle 3 fixes for tq-l2-norm
Apply 3 fixes from the cycle-3 phase-3 gauntlet review of the cycle-2 fix-commit (8c02fe56d). - Cover `-0.0` in the negative-stored-norm guard (l2_norm.rs): cycle-2 used `*n < T::zero()` which is `false` for `-0.0` per IEEE 754, so a hand-constructed `-0.0` stored norm slipped through and the kernel returned `-0.0` while canonical computed `sqrt(sum_sq) == +0.0`. Switch to `n.is_sign_negative()`, which covers both strictly-negative values and `-0.0`. The comment now documents the IEEE 754 subtlety and the cache-warm `O(rows)` cost of the scan. - Add `l2_norm_over_tq_decode_with_negative_stored_norm_falls_back` (kernels.rs), parameterized over `-5.0` and `-0.0`, that hand-builds a TurboQuant array with a sign-negative stored norm and asserts the result is non-negative and finite (proving the kernel fell back to the canonical path, which always returns `|stored_norm|`). - Add `l2_norm_over_tq_decode_rejects_codes_validity_narrower_than_struct` (kernels.rs) that hand-builds a TurboQuant array whose `codes` child has row validity narrower than the outer struct's, mirroring the existing canonical-path `decode_rejects_child_masks_that_disagree_with_struct_validity` test. Asserts `L2Norm(TQDecode(_))` errors via `parse_storage_norms_only`'s validation, pinning the fast/slow-path validation parity that the cycle-1 fix-commit added. Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent d5644f6 commit 83901ac

2 files changed

Lines changed: 120 additions & 8 deletions

File tree

vortex-turboquant/src/scalar_fns/compute/l2_norm.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
//! `L2Norm` execute-parent kernel that intercepts `L2Norm(TQDecode(tq))` and returns the
55
//! stored per-row norms directly instead of decoding and recomputing.
66
7-
use num_traits::Zero;
87
use vortex_array::ArrayRef;
98
use vortex_array::ExecutionCtx;
109
use vortex_array::IntoArray;
@@ -64,14 +63,25 @@ fn l2_norm_tq_decode_execute_parent(
6463
let parsed = parse_storage_norms_only(tq_array, ctx)?;
6564

6665
// Fall back to the canonical `L2Norm` path on the (adversarial) case where any stored
67-
// norm is strictly negative. Encode always produces non-negative norms (via `L2Norm`,
68-
// which returns `sqrt(sum_sq)`), but a hand-constructed TurboQuant storage could carry
69-
// arbitrary values in the `norms` child. Returning the stored bits verbatim would then
70-
// violate `L2Norm`'s always-non-negative output invariant. The canonical path runs the
71-
// in-flight decode rescaling and reapplies the stored norm, so its `L2Norm` output is
72-
// `|stored_norm|` for every row by construction.
66+
// norm has its sign bit set. Encode always produces non-negative norms (via `L2Norm`,
67+
// which returns `sqrt(sum_sq)` and never yields `-0.0`), but a hand-constructed
68+
// TurboQuant storage could carry arbitrary values in the `norms` child. Returning the
69+
// stored bits verbatim would then violate `L2Norm`'s always-non-negative output
70+
// invariant. The canonical path runs the in-flight decode rescaling and reapplies the
71+
// stored norm, so its `L2Norm` output is `|stored_norm|` for every row by construction.
72+
//
73+
// Using `is_sign_negative` rather than `< T::zero()` is load-bearing: `-0.0 < 0.0` is
74+
// `false` per IEEE 754, so a literal comparison would miss a stored `-0.0` while the
75+
// canonical path would still collapse it to `+0.0` via `sqrt(sum_sq)`.
76+
//
77+
// The scan is `O(rows)` over a buffer the just-completed `parse_storage_norms_only`
78+
// materialized, so it does not move the kernel out of its constant-time-per-row regime.
7379
let has_negative_norm = match_each_float_ptype!(parsed.norms.ptype(), |T| {
74-
parsed.norms.as_slice::<T>().iter().any(|n| *n < T::zero())
80+
parsed
81+
.norms
82+
.as_slice::<T>()
83+
.iter()
84+
.any(|n| n.is_sign_negative())
7585
});
7686
if has_negative_norm {
7787
return Ok(None);

vortex-turboquant/src/tests/kernels.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,105 @@ fn l2_norm_over_tq_decode_matches_canonical(#[case] dim: u32) -> VortexResult<()
269269
}
270270
Ok(())
271271
}
272+
273+
/// Adversarial: a hand-constructed TurboQuant storage with a `-5.0` or `-0.0` stored norm
274+
/// makes the fast path fall back to the canonical `L2Norm(execute(TQDecode))` path so that
275+
/// the result preserves `L2Norm`'s always-non-negative output invariant. The kernel scans
276+
/// the parsed `norms` once and triggers fallback via `is_sign_negative`, which covers both
277+
/// strictly-negative values and `-0.0` (where the literal `< 0` comparison would fail per
278+
/// IEEE 754).
279+
#[rstest]
280+
#[case::strict_negative(-5.0_f32)]
281+
#[case::negative_zero(-0.0_f32)]
282+
fn l2_norm_over_tq_decode_with_negative_stored_norm_falls_back(
283+
#[case] stored: f32,
284+
) -> VortexResult<()> {
285+
let session = test_session();
286+
let mut ctx = session.create_execution_ctx();
287+
let metadata = TurboQuantMetadata {
288+
element_ptype: PType::F32,
289+
dimensions: DIM,
290+
bit_width: 1,
291+
seed: 42,
292+
num_rounds: 3,
293+
};
294+
295+
let norms =
296+
PrimitiveArray::new::<f32>(Buffer::copy_from([stored]), Validity::NonNullable).into_array();
297+
let codes = PrimitiveArray::new::<u8>(vec![0u8; DIM as usize], Validity::NonNullable);
298+
let codes = FixedSizeListArray::try_new(codes.into_array(), DIM, Validity::NonNullable, 1)?
299+
.into_array();
300+
let storage = StructArray::try_new(
301+
FieldNames::from(["norms", "codes"]),
302+
vec![norms, codes],
303+
1,
304+
Validity::NonNullable,
305+
)?;
306+
let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())?
307+
.into_array();
308+
309+
let decoded = TQDecode::try_new_array(tq)?.into_array();
310+
let result: PrimitiveArray = L2Norm::try_new_array(decoded, 1)?
311+
.into_array()
312+
.execute(&mut ctx)?;
313+
314+
// Whatever path runs, the result is an `L2Norm` output and must be non-negative; in
315+
// particular the kernel must NOT return the stored sign-negative value verbatim. The
316+
// exact magnitude depends on which centroid the all-zero codes decode to; we only
317+
// assert the sign and finiteness, which is what `L2Norm`'s contract pins.
318+
assert_eq!(result.as_slice::<f32>().len(), 1);
319+
let value = result.as_slice::<f32>()[0];
320+
assert!(
321+
value.is_finite() && !value.is_sign_negative(),
322+
"L2Norm result must be non-negative and finite (got {value})"
323+
);
324+
Ok(())
325+
}
326+
327+
/// Adversarial: a hand-constructed TurboQuant storage whose `codes` child has row validity
328+
/// narrower than the outer struct's must fail the fast path the same way it fails the
329+
/// canonical decode path (see `malformed::decode_rejects_child_masks_that_disagree_with_struct_validity`).
330+
/// `parse_storage_norms_only` executes the `codes` FSL wrapper specifically to enforce this
331+
/// invariant.
332+
#[test]
333+
fn l2_norm_over_tq_decode_rejects_codes_validity_narrower_than_struct() -> VortexResult<()> {
334+
let session = test_session();
335+
let mut ctx = session.create_execution_ctx();
336+
let metadata = TurboQuantMetadata {
337+
element_ptype: PType::F32,
338+
dimensions: DIM,
339+
bit_width: 1,
340+
seed: 42,
341+
num_rounds: 3,
342+
};
343+
344+
let norms =
345+
PrimitiveArray::new::<f32>(Buffer::copy_from([1.0f32, 1.0, 1.0]), Validity::NonNullable)
346+
.into_array();
347+
let codes = PrimitiveArray::new::<u8>(vec![0u8; 3 * DIM as usize], Validity::NonNullable);
348+
let codes = FixedSizeListArray::try_new(
349+
codes.into_array(),
350+
DIM,
351+
Validity::from_iter([true, false, true]),
352+
3,
353+
)?
354+
.into_array();
355+
let storage = StructArray::try_new(
356+
FieldNames::from(["norms", "codes"]),
357+
vec![norms, codes],
358+
3,
359+
Validity::NonNullable,
360+
)?;
361+
let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())?
362+
.into_array();
363+
364+
let decoded = TQDecode::try_new_array(tq)?.into_array();
365+
let result: VortexResult<PrimitiveArray> = L2Norm::try_new_array(decoded, 3)?
366+
.into_array()
367+
.execute(&mut ctx);
368+
assert!(
369+
result.is_err(),
370+
"kernel must reject codes-validity narrower than struct-validity"
371+
);
372+
Ok(())
373+
}

0 commit comments

Comments
 (0)