Skip to content

Commit eed58ad

Browse files
sharifhsnclaude
andcommitted
PERF: Remove Numba union-find kernel in favor of SciPy fallback
Remove _union and _st_fused_ccl Numba functions. End-to-end benchmarks show only ~10% speedup over the SciPy connected-components path, which doesn't justify ~100 lines of Numba-specific code. The SciPy path provides ~5x speedup over the original BFS without requiring Numba. A permalink to the removed Numba kernel is preserved in a code comment for future reference. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8b66a74 commit eed58ad

1 file changed

Lines changed: 5 additions & 107 deletions

File tree

mne/stats/cluster_level.py

Lines changed: 5 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -124,98 +124,6 @@ def _sum_cluster_data(data, tstep):
124124
return np.sign(data) * tstep
125125

126126

127-
@jit(inline="always")
128-
def _union(a_pos, b_pos, parent, rank):
129-
"""Find roots with path compression and union by rank."""
130-
ra = a_pos
131-
while parent[ra] != ra:
132-
parent[ra] = parent[parent[ra]]
133-
ra = parent[ra]
134-
rb = b_pos
135-
while parent[rb] != rb:
136-
parent[rb] = parent[parent[rb]]
137-
rb = parent[rb]
138-
if ra != rb:
139-
if rank[ra] < rank[rb]:
140-
parent[ra] = rb
141-
elif rank[ra] > rank[rb]:
142-
parent[rb] = ra
143-
else:
144-
parent[rb] = ra
145-
rank[ra] += 1
146-
147-
148-
@jit()
149-
def _st_fused_ccl(
150-
active_idx, n_active, flat_to_active, adj_indptr, adj_indices, n_src, max_step
151-
):
152-
"""Label connected components among supra-threshold vertices via union-find.
153-
154-
Replaces the Python BFS in ``_get_clusters_st`` with a single-pass
155-
union-find (disjoint-set) algorithm over spatial and temporal neighbors.
156-
Data is organized as ``n_times x n_src``; spatial adjacency is stored in
157-
CSR format (``adj_indptr``/``adj_indices``), and temporal neighbors are
158-
the same source vertex at up to ``max_step`` earlier time points.
159-
160-
Each active vertex starts as its own component. A linear scan unions each
161-
vertex with its active spatial and temporal neighbors. Path compression
162-
and union-by-rank keep the amortized cost per union nearly O(1), making
163-
the full pass O(n * alpha(n)) where alpha is the inverse Ackermann
164-
function. The main practical speedup comes from running entirely inside
165-
a single Numba-compiled function, eliminating the per-vertex
166-
Python/Numba boundary crossings of the BFS approach.
167-
"""
168-
# Union-find / disjoint-set forest:
169-
# https://en.wikipedia.org/wiki/Disjoint-set_data_structure
170-
# build flat→active mapping
171-
for i in range(n_active):
172-
flat_to_active[active_idx[i]] = i
173-
174-
parent = np.arange(n_active)
175-
rank = np.zeros(n_active, dtype=np.int32)
176-
177-
for a_pos in range(n_active):
178-
flat_i = active_idx[a_pos]
179-
t_i = flat_i // n_src
180-
s_i = flat_i - t_i * n_src
181-
182-
# spatial neighbors
183-
for j_ptr in range(adj_indptr[s_i], adj_indptr[s_i + 1]):
184-
s_j = adj_indices[j_ptr]
185-
flat_j = t_i * n_src + s_j
186-
b_pos = flat_to_active[flat_j]
187-
if b_pos >= 0:
188-
_union(a_pos, b_pos, parent, rank)
189-
190-
# temporal neighbors (same vertex, previous time steps)
191-
for step in range(1, max_step + 1):
192-
if t_i >= step:
193-
flat_j = (t_i - step) * n_src + s_i
194-
b_pos = flat_to_active[flat_j]
195-
if b_pos >= 0:
196-
_union(a_pos, b_pos, parent, rank)
197-
198-
# final path compression + relabel to 0..n_components-1
199-
label_map = -np.ones(n_active, dtype=np.intp)
200-
next_label = np.intp(0)
201-
components = np.empty(n_active, dtype=np.intp)
202-
for i in range(n_active):
203-
a = i
204-
while parent[a] != a:
205-
a = parent[a]
206-
parent[i] = a
207-
if label_map[a] == -1:
208-
label_map[a] = next_label
209-
next_label += 1
210-
components[i] = label_map[a]
211-
212-
# clean up flat_to_active for next call
213-
for i in range(n_active):
214-
flat_to_active[active_idx[i]] = -1
215-
216-
return components
217-
218-
219127
def _get_clusters_spatial(s, neighbors):
220128
"""Form spatial clusters using neighbor lists.
221129
@@ -747,21 +655,11 @@ def _find_clusters_1dir(
747655
return [], np.atleast_1d(np.array([]))
748656
clusters = []
749657
else:
750-
if has_numba:
751-
_flat_map = -np.ones(len(x_in), dtype=np.intp)
752-
components = _st_fused_ccl(
753-
active_idx,
754-
n_active,
755-
_flat_map,
756-
_indptr,
757-
_indices,
758-
_n_src,
759-
max_step,
760-
)
761-
else:
762-
components = _get_clusters_st_scipy(
763-
x_in, (_indptr, _indices, _n_src), max_step
764-
)
658+
# SciPy connected-components; for a Numba union-find
659+
# alternative see https://github.com/sharifhsn/mne-python/blob/999ea49d9f180cea87dc3d522e530b51fba0dcc5/mne/stats/cluster_level.py#L122-L220
660+
components = _get_clusters_st_scipy(
661+
x_in, (_indptr, _indices, _n_src), max_step
662+
)
765663
if _sums_only:
766664
if t_power == 1:
767665
sums = np.bincount(components, weights=x[active_idx])

0 commit comments

Comments
 (0)