Skip to content

Commit e0cc001

Browse files
committed
fix(embedded): race-free len, full alias docstring, faster brute force
- hnsw.rs: `__len__` now reads the count from the underlying HnswIndex instead of `next_id`. The previous derivation could observe phantom IDs under concurrent fit/__len__ — `fit` bumps next_id under its own mutex, releases it, then acquires the index lock for the actual inserts (with the GIL released around the build). A concurrent `__len__` between those two windows would see the bumped counter before the inserts landed. Locking `inner` makes the count reflect committed inserts only.
- hnsw.rs: `__repr__` now reports the same len source plus a `<busy>` marker when the index lock is contended (via `try_lock`). Stops a debug REPL from blocking on a concurrent build.
- hnsw.rs: docstring for the `metric` constructor argument now lists every accepted alias (cosine/angular, euclidean/l2, dot/dot_product/ip/inner_product, manhattan/l1) instead of a partial list. The error message and the parser were already exhaustive; the docstring just needed to catch up.
- tests/unit/test_hnsw.py: the brute-force top-k helper uses `np.argpartition` (O(N)) instead of `np.argsort` (O(N log N)) — only the SET of nearest k matters for the recall metric, not the order inside it.
- tests/unit/test_hnsw.py: the recall test allocates the 10K × 16 float matrix directly as float32 via `standard_normal(dtype=np.float32)` instead of allocating float64 then `.astype(np.float32)`. Halves the peak memory for that test.
1 parent a0eb616 commit e0cc001

2 files changed

Lines changed: 37 additions & 19 deletions

File tree

coordinode-embedded/src/hnsw.rs

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,14 @@ impl Hnsw {
6464
///
6565
/// # Arguments
6666
/// * `dim` — embedding dimension (must match the vectors passed to `fit` / `knn_query`).
67-
/// * `metric` — distance metric. One of `cosine` / `angular`, `euclidean` / `l2`,
68-
/// `dot` / `inner_product`, `manhattan` / `l1`. Names mirror ann-benchmarks
69-
/// conventions so existing harnesses pass their `space` argument unchanged.
67+
/// * `metric` — distance metric. Accepted aliases (all case-insensitive):
68+
/// - cosine similarity: `cosine`, `angular`
69+
/// - Euclidean (L2): `euclidean`, `l2`
70+
/// - dot product: `dot`, `dot_product`, `ip`, `inner_product`
71+
/// - Manhattan (L1): `manhattan`, `l1`
72+
///
73+
/// Spellings track ann-benchmarks and VectorDBBench so existing
74+
/// harnesses pass their `space` argument unchanged.
7075
/// * `M` — max connections per element per layer (HNSW spec). Default 16.
7176
/// * `ef_construction` — candidate list size during build. Default 200.
7277
/// * `max_elements` — hint to pre-allocate node storage. Default 1_000_000.
@@ -208,23 +213,31 @@ impl Hnsw {
208213

209214
/// Number of vectors indexed.
210215
fn __len__(&self) -> PyResult<usize> {
211-
// `next_id` is monotonically incremented per insert, so it doubles
212-
// as the count without us reaching into HnswIndex internals.
213-
let next = self
214-
.next_id
216+
// Read the count from the HnswIndex itself, NOT from `next_id`.
217+
// `next_id` is bumped under its own mutex before the inserts happen
218+
// under `inner`; with the GIL released around the build, a concurrent
219+
// `__len__` call would otherwise observe phantom IDs that haven't
220+
// actually landed in the index. Locking `inner` makes the count
221+
// reflect committed inserts only.
222+
let index = self
223+
.inner
215224
.lock()
216-
.map_err(|e| PyRuntimeError::new_err(format!("next_id lock poisoned: {e}")))?;
217-
Ok(*next as usize)
225+
.map_err(|e| PyRuntimeError::new_err(format!("index lock poisoned: {e}")))?;
226+
Ok(index.len())
218227
}
219228

220229
fn __repr__(&self) -> String {
221-
// `__len__` surfaces a poisoned mutex as RuntimeError; `__repr__` can't
222-
// raise (Python expects it to always return a string) so we emit a
223-
// visible marker instead of silently reporting len=0. Hiding a poisoned
224-
// lock would mask real concurrency bugs during debugging.
225-
let len_repr = match self.next_id.lock() {
226-
Ok(g) => g.to_string(),
227-
Err(_) => "<poisoned>".to_owned(),
230+
// `__len__` surfaces a poisoned mutex as RuntimeError; `__repr__`
231+
// can't raise (Python expects it to always return a string), so a
232+
// poison is rendered as a visible marker rather than a silent
233+
// `len=0` that would mask real concurrency bugs during debugging.
234+
// `try_lock` is intentional: `__repr__` is routinely invoked from a
235+
// debugger or REPL and must not block waiting on a concurrent build
236+
// that holds the lock — we'd rather show `<busy>` than stall the REPL.
237+
let len_repr = match self.inner.try_lock() {
238+
Ok(idx) => idx.len().to_string(),
239+
Err(std::sync::TryLockError::WouldBlock) => "<busy>".to_owned(),
240+
Err(std::sync::TryLockError::Poisoned(_)) => "<poisoned>".to_owned(),
228241
};
229242
format!("Hnsw(dim={}, len={len_repr})", self.dim)
230243
}

tests/unit/test_hnsw.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414

1515

1616
def _brute_force_topk(X, q, k: int):
17+
# argpartition gives the top-k indices in O(N), vs argsort's O(N log N).
18+
# We only need the SET of nearest k; ordering inside the set doesn't
19+
# matter for the recall metric.
1720
dists = ((X - q) ** 2).sum(axis=1)
18-
return set(np.argsort(dists)[:k].tolist())
21+
return set(np.argpartition(dists, k)[:k].tolist())
1922

2023

2124
def test_metric_parsing_and_dim_validation() -> None:
@@ -59,8 +62,10 @@ def test_recall_at_10_geq_0_95() -> None:
5962
we hold queries out of the training set).
6063
"""
6164
rng = np.random.default_rng(42)
62-
X = rng.standard_normal((10_000, 16)).astype(np.float32)
63-
queries = rng.standard_normal((50, 16)).astype(np.float32)
65+
# `dtype=` on standard_normal skips the float64-then-astype round-trip,
66+
# halving the allocation for this 10K × 16 matrix.
67+
X = rng.standard_normal((10_000, 16), dtype=np.float32)
68+
queries = rng.standard_normal((50, 16), dtype=np.float32)
6469

6570
idx = ce.Hnsw(dim=16, metric="euclidean", M=16, ef_construction=200)
6671
idx.fit(X)

0 commit comments

Comments
 (0)