@@ -124,98 +124,6 @@ def _sum_cluster_data(data, tstep):
124124 return np .sign (data ) * tstep
125125
126126
127- @jit (inline = "always" )
128- def _union (a_pos , b_pos , parent , rank ):
129- """Find roots with path compression and union by rank."""
130- ra = a_pos
131- while parent [ra ] != ra :
132- parent [ra ] = parent [parent [ra ]]
133- ra = parent [ra ]
134- rb = b_pos
135- while parent [rb ] != rb :
136- parent [rb ] = parent [parent [rb ]]
137- rb = parent [rb ]
138- if ra != rb :
139- if rank [ra ] < rank [rb ]:
140- parent [ra ] = rb
141- elif rank [ra ] > rank [rb ]:
142- parent [rb ] = ra
143- else :
144- parent [rb ] = ra
145- rank [ra ] += 1
146-
147-
148- @jit ()
149- def _st_fused_ccl (
150- active_idx , n_active , flat_to_active , adj_indptr , adj_indices , n_src , max_step
151- ):
152- """Label connected components among supra-threshold vertices via union-find.
153-
154- Replaces the Python BFS in ``_get_clusters_st`` with a single-pass
155- union-find (disjoint-set) algorithm over spatial and temporal neighbors.
156- Data is organized as ``n_times x n_src``; spatial adjacency is stored in
157- CSR format (``adj_indptr``/``adj_indices``), and temporal neighbors are
158- the same source vertex at up to ``max_step`` earlier time points.
159-
160- Each active vertex starts as its own component. A linear scan unions each
161- vertex with its active spatial and temporal neighbors. Path compression
162- and union-by-rank keep the amortized cost per union nearly O(1), making
163- the full pass O(n * alpha(n)) where alpha is the inverse Ackermann
164- function. The main practical speedup comes from running entirely inside
165- a single Numba-compiled function, eliminating the per-vertex
166- Python/Numba boundary crossings of the BFS approach.
167- """
168- # Union-find / disjoint-set forest:
169- # https://en.wikipedia.org/wiki/Disjoint-set_data_structure
170- # build flat→active mapping
171- for i in range (n_active ):
172- flat_to_active [active_idx [i ]] = i
173-
174- parent = np .arange (n_active )
175- rank = np .zeros (n_active , dtype = np .int32 )
176-
177- for a_pos in range (n_active ):
178- flat_i = active_idx [a_pos ]
179- t_i = flat_i // n_src
180- s_i = flat_i - t_i * n_src
181-
182- # spatial neighbors
183- for j_ptr in range (adj_indptr [s_i ], adj_indptr [s_i + 1 ]):
184- s_j = adj_indices [j_ptr ]
185- flat_j = t_i * n_src + s_j
186- b_pos = flat_to_active [flat_j ]
187- if b_pos >= 0 :
188- _union (a_pos , b_pos , parent , rank )
189-
190- # temporal neighbors (same vertex, previous time steps)
191- for step in range (1 , max_step + 1 ):
192- if t_i >= step :
193- flat_j = (t_i - step ) * n_src + s_i
194- b_pos = flat_to_active [flat_j ]
195- if b_pos >= 0 :
196- _union (a_pos , b_pos , parent , rank )
197-
198- # final path compression + relabel to 0..n_components-1
199- label_map = - np .ones (n_active , dtype = np .intp )
200- next_label = np .intp (0 )
201- components = np .empty (n_active , dtype = np .intp )
202- for i in range (n_active ):
203- a = i
204- while parent [a ] != a :
205- a = parent [a ]
206- parent [i ] = a
207- if label_map [a ] == - 1 :
208- label_map [a ] = next_label
209- next_label += 1
210- components [i ] = label_map [a ]
211-
212- # clean up flat_to_active for next call
213- for i in range (n_active ):
214- flat_to_active [active_idx [i ]] = - 1
215-
216- return components
217-
218-
219127def _get_clusters_spatial (s , neighbors ):
220128 """Form spatial clusters using neighbor lists.
221129
@@ -747,21 +655,11 @@ def _find_clusters_1dir(
747655 return [], np .atleast_1d (np .array ([]))
748656 clusters = []
749657 else :
750- if has_numba :
751- _flat_map = - np .ones (len (x_in ), dtype = np .intp )
752- components = _st_fused_ccl (
753- active_idx ,
754- n_active ,
755- _flat_map ,
756- _indptr ,
757- _indices ,
758- _n_src ,
759- max_step ,
760- )
761- else :
762- components = _get_clusters_st_scipy (
763- x_in , (_indptr , _indices , _n_src ), max_step
764- )
658+ # SciPy connected-components; for a Numba union-find
659+ # alternative see https://github.com/sharifhsn/mne-python/blob/999ea49d9f180cea87dc3d522e530b51fba0dcc5/mne/stats/cluster_level.py#L122-L220
660+ components = _get_clusters_st_scipy (
661+ x_in , (_indptr , _indices , _n_src ), max_step
662+ )
765663 if _sums_only :
766664 if t_power == 1 :
767665 sums = np .bincount (components , weights = x [active_idx ])
0 commit comments