Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 165 additions & 3 deletions vortex-tensor/src/encodings/turboquant/centroids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,41 @@ impl HalfIntExponent {
}
}

/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm.
/// How far to spread the initial centroids, as a multiple of the coordinate standard deviation
/// `sigma = 1 / sqrt(dimension)`.
///
/// Seeding centroids across the full support `[-1, 1]` strands most of them in the near-zero-mass
/// tails, where the zero-denominator guard in [`mean_between_centroids`] freezes them for every
/// iteration; scaling the seed by `sigma` keeps every cell on live probability mass.
#[derive(Clone, Copy, Debug)]
enum InitSpread {
/// A constant multiple of `sigma`, independent of bit width. Only the sweep test constructs
/// this; production uses [`InitSpread::SqrtRate`].
#[cfg_attr(not(test), allow(dead_code))]
Fixed(f64),
/// `coeff * sqrt(bit_width)` multiples of `sigma`. A codebook with more levels needs a wider
/// seed to keep its outermost cells on live probability mass, so the spread grows with the bit
/// width — mirroring how a quantizer's optimal loading factor grows with rate.
SqrtRate(f64),
}

impl InitSpread {
/// The seed half-width, in multiples of `sigma`, for the given bit width.
fn sigmas(self, bit_width: u8) -> f64 {
match self {
InitSpread::Fixed(sigmas) => sigmas,
InitSpread::SqrtRate(coeff) => coeff * f64::from(bit_width).sqrt(),
}
}
}

/// Default centroid initialization. The seed half-width grows as `sqrt(bit_width)` standard
/// deviations, tracking the bit-width-dependent optimum and beating every fixed multiple in
/// `sweep_centroid_init` (including vLLM's `3.5 sigma`).
const DEFAULT_INIT_SPREAD: InitSpread = InitSpread::SqrtRate(1.0);

/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm with the
/// [default initialization](DEFAULT_INIT_SPREAD).
///
/// Operates on the marginal distribution of a single coordinate of a randomly rotated unit vector
/// in d dimensions.
Expand All @@ -93,15 +127,26 @@ impl HalfIntExponent {
/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]`
/// where `C_d` is the normalizing constant.
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer<f32> {
max_lloyd_centroids_with(dimension, bit_width, DEFAULT_INIT_SPREAD)
}

/// Compute Max-Lloyd centroids for an explicit [`InitSpread`]. Production code calls
/// [`max_lloyd_centroids`]; the sweep test explores alternatives through this entry point.
fn max_lloyd_centroids_with(dimension: u32, bit_width: u8, init: InitSpread) -> Buffer<f32> {
debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width));
let num_centroids = 1usize << bit_width;

// For the marginal distribution on [-1, 1], we use the exponent (d-3)/2.
let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3);

// Initialize centroids uniformly on [-1, 1].
// The coordinate marginal concentrates around 0 with this standard deviation.
let sigma = 1.0 / f64::from(dimension).sqrt();
let init_half = (init.sigmas(bit_width) * sigma).min(1.0);

// Initialize centroids uniformly on [-init_half, init_half], where the mass lives, so no cell
// starts in a zero-mass region and freezes.
let mut centroids: Vec<f64> = (0..num_centroids)
.map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64))
.map(|idx| -init_half + (2.0 * (idx as f64) + 1.0) * init_half / (num_centroids as f64))
.collect();

let mut boundaries: Vec<f64> = vec![0.0; num_centroids + 1];
Expand Down Expand Up @@ -222,6 +267,8 @@ pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 {

#[cfg(test)]
mod tests {
use std::f64::consts::PI;

use rstest::rstest;
use vortex_error::VortexResult;

Expand Down Expand Up @@ -329,4 +376,119 @@ mod tests {
assert!(compute_or_get_centroids(1, 2).is_err());
assert!(compute_or_get_centroids(127, 2).is_err());
}

/// Fine-grained reference measurement of a codebook's quality on the coordinate marginal,
/// computed independently of the solver's own (coarser) integration grid.
struct QuantizerQuality {
/// Implied normalized reconstruction error `E[||x - x_hat||^2 / ||x||^2]` under an ideal
/// orthogonal rotation: `dimension * E[(X - q(X))^2]`.
normalized_mse: f64,
/// `normalized_mse` divided by the Theorem 1 high-rate bound `sqrt(3) * pi / 2 / 4^b`.
ratio_to_bound: f64,
/// Number of centroids whose decision cell carries less than 1e-6 of the total mass, i.e.
/// codes that are wasted because the solver froze them in a near-zero-mass region.
wasted: usize,
}

/// Measure how well `centroids` quantize the coordinate marginal for `dimension`.
#[expect(
clippy::cast_possible_truncation,
reason = "integration samples are cast f64 -> f32 only to drive find_nearest_centroid"
)]
fn measure_quantizer(dimension: u32, bit_width: u8, centroids: &[f32]) -> QuantizerQuality {
const POINTS: usize = 100_000;
let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3);
let boundaries = compute_centroid_boundaries(centroids);
let count = centroids.len();
let mut mass = vec![0.0f64; count];
let mut distortion = vec![0.0f64; count];
let mut total = 0.0f64;
let dx = 2.0 / POINTS as f64;
for step in 0..=POINTS {
let x = -1.0 + step as f64 * dx;
let trapezoid = if step == 0 || step == POINTS {
0.5
} else {
1.0
};
let weight = trapezoid * pdf_unnormalized(x, exponent);
let idx = usize::from(find_nearest_centroid(x as f32, &boundaries));
let delta = x - f64::from(centroids[idx]);
mass[idx] += weight;
distortion[idx] += weight * delta * delta;
total += weight;
}
let per_coord_mse = distortion.iter().sum::<f64>() / total;
let normalized_mse = f64::from(dimension) * per_coord_mse;
let bound = 3.0f64.sqrt() * PI / 2.0 / 4.0f64.powi(i32::from(bit_width));
let wasted = mass.iter().filter(|&&m| m / total < 1e-6).count();
QuantizerQuality {
normalized_mse,
ratio_to_bound: normalized_mse / bound,
wasted,
}
}

/// Every code in the production codebook must land on live probability mass. This is the
/// invariant the legacy `[-1, 1]` initialization violated for `dimension >= 256`, where most
/// cells froze in the zero-mass tails and wasted their codes.
#[rstest]
#[case(128)]
#[case(256)]
#[case(1024)]
#[case(2048)]
fn production_centroids_have_no_wasted_cells(#[case] dimension: u32) -> VortexResult<()> {
for bit_width in 1..=MAX_BIT_WIDTH {
let centroids = compute_or_get_centroids(dimension, bit_width)?;
let quality = measure_quantizer(dimension, bit_width, &centroids);
assert_eq!(
quality.wasted, 0,
"dim={dimension} bits={bit_width}: {} codes landed on zero-mass cells",
quality.wasted
);
}
Ok(())
}

/// Exploratory sweep over centroid-init and outer-edge configurations. Not a pass/fail gate;
/// run with `cargo test -p vortex-tensor centroids::tests::sweep -- --ignored --nocapture` to
/// compare distortion and wasted-code counts when revisiting the default configuration.
#[test]
#[ignore = "exploratory sweep; run with --ignored --nocapture"]
fn sweep_centroid_init() {
// `1e9` saturates the seed spread past 1.0, reproducing the legacy `[-1, 1]` choice.
let configs: &[(&str, InitSpread)] = &[
("legacy [-1,1]", InitSpread::Fixed(1e9)),
("fixed 2.5s", InitSpread::Fixed(2.5)),
("fixed 3.0s", InitSpread::Fixed(3.0)),
("fixed 3.5s (vLLM)", InitSpread::Fixed(3.5)),
("sqrt 1.00*sqrt(b) [default]", DEFAULT_INIT_SPREAD),
("sqrt 1.05*sqrt(b)", InitSpread::SqrtRate(1.05)),
("sqrt 1.10*sqrt(b)", InitSpread::SqrtRate(1.10)),
(
"sqrt 1.18*sqrt(b) [sqrt(2lnN)]",
InitSpread::SqrtRate(1.1774),
),
];
let dims = [128u32, 1024, 2048];
let bits_list = [4u8, 5, 6, 7, 8];

for &(name, init) in configs {
println!("\n=== {name} ===");
println!(
"{:>6} {:>5} {:>12} {:>9} {:>7}",
"dim", "bits", "norm_mse", "x_bound", "wasted"
);
for &dimension in &dims {
for &bit_width in &bits_list {
let centroids = max_lloyd_centroids_with(dimension, bit_width, init);
let q = measure_quantizer(dimension, bit_width, &centroids);
println!(
"{dimension:>6} {bit_width:>5} {:>12.3e} {:>9.2} {:>7}",
q.normalized_mse, q.ratio_to_bound, q.wasted
);
}
}
}
}
}
15 changes: 13 additions & 2 deletions vortex-turboquant/src/centroids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,27 @@ impl HalfIntExponent {
/// The probability distribution function is:
/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]`
/// where `C_d` is the normalizing constant.
///
/// Centroids are seeded uniformly on `±sqrt(bit_width) * sigma` (`sigma = 1/sqrt(dimension)`)
/// rather than across the full support `[-1, 1]`, which would strand most of them in the
/// near-zero-mass tails where the zero-denominator guard in [`mean_between_centroids`] freezes them.
/// This must stay identical to `vortex-tensor`'s canonical centroid code (which carries the
/// supporting sweep); the cross-crate parity test enforces it.
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer<f32> {
debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width));
let num_centroids = 1usize << bit_width;

// For the marginal distribution on [-1, 1], we use the exponent (d-3)/2.
let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3);

// Initialize centroids uniformly on [-1, 1].
// The coordinate marginal concentrates around 0 with this standard deviation.
let sigma = 1.0 / f64::from(dimension).sqrt();
let init_half = (f64::from(bit_width).sqrt() * sigma).min(1.0);

// Initialize centroids uniformly on [-init_half, init_half], where the mass lives, so no cell
// starts in a zero-mass region and freezes.
let mut centroids: Vec<f64> = (0..num_centroids)
.map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64))
.map(|idx| -init_half + (2.0 * (idx as f64) + 1.0) * init_half / (num_centroids as f64))
.collect();

let mut boundaries: Vec<f64> = vec![0.0; num_centroids + 1];
Expand Down
Loading