Skip to content

Commit 39d2272

Browse files
sharifhsn authored and Claude committed
PERF: speed up spatio-temporal cluster permutation tests
Precompute sum-of-squares for sign-flip t-tests (s²=1 invariant) and replace Python BFS with SciPy connected_components in _get_clusters_st. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 49a6cd3 commit 39d2272

2 files changed

Lines changed: 104 additions & 49 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Speed up :func:`mne.stats.spatio_temporal_cluster_1samp_test` and related permutation cluster functions via precomputed sum-of-squares for sign-flip t-tests and SciPy connected-components clustering (~5x), by :newcontrib:`Sharif Haason`.

mne/stats/cluster_level.py

Lines changed: 103 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _masked_sum_power(x, c, t_power):
121121

122122
@jit()
def _sum_cluster_data(data, tstep):
    """Return the signed per-sample cluster contribution.

    ``np.sign`` already yields 0 wherever ``data == 0``, so scaling the
    sign by the time-step length gives ±tstep for supra-threshold samples
    and 0 elsewhere in a single multiply.
    """
    signs = np.sign(data)
    return signs * tstep
125125

126126

127127
def _get_clusters_spatial(s, neighbors):
@@ -258,33 +258,62 @@ def _get_clusters_st_multistep(keepers, neighbors, max_step=1):
258258

259259

260260
def _get_clusters_st(x_in, neighbors, max_step=1):
261-
"""Choose the most efficient version."""
261+
"""Find spatio-temporal clusters via SciPy connected components.
262+
263+
Builds a sparse adjacency graph over only the supra-threshold vertices
264+
(spatial edges from the neighbor lists, temporal edges between the same
265+
source at adjacent time steps) and labels clusters with
266+
``scipy.sparse.csgraph.connected_components``.
267+
"""
262268
n_src = len(neighbors)
263-
n_times = x_in.size // n_src
264-
cl_goods = np.where(x_in)[0]
265-
if len(cl_goods) > 0:
266-
keepers = [np.array([], dtype=int)] * n_times
267-
row, col = np.unravel_index(cl_goods, (n_times, n_src))
268-
lims = [0]
269-
if isinstance(row, int):
270-
row = [row]
271-
col = [col]
272-
else:
273-
order = np.argsort(row)
274-
row = row[order]
275-
col = col[order]
276-
lims += (np.where(np.diff(row) > 0)[0] + 1).tolist()
277-
lims.append(len(row))
278-
279-
for start, end in zip(lims[:-1], lims[1:]):
280-
keepers[row[start]] = np.sort(col[start:end])
281-
if max_step == 1:
282-
return _get_clusters_st_1step(keepers, neighbors)
283-
else:
284-
return _get_clusters_st_multistep(keepers, neighbors, max_step)
285-
else:
269+
n_total = len(x_in)
270+
active = np.where(x_in)[0]
271+
if len(active) == 0:
286272
return []
287273

274+
# Convert neighbor lists to CSR for vectorized expansion
275+
lengths = np.array([len(a) for a in neighbors])
276+
indptr = np.zeros(n_src + 1, dtype=np.intp)
277+
np.cumsum(lengths, out=indptr[1:])
278+
indices = np.concatenate(neighbors).astype(np.intp)
279+
280+
active_t, active_s = np.divmod(active, n_src)
281+
282+
# Spatial edges: vectorized CSR neighbor expansion
283+
neighbor_counts = indptr[active_s + 1] - indptr[active_s]
284+
src_flat = np.repeat(active, neighbor_counts)
285+
src_t = np.repeat(active_t, neighbor_counts)
286+
starts = indptr[active_s]
287+
offsets = np.arange(int(np.sum(neighbor_counts))) - np.repeat(
288+
np.cumsum(neighbor_counts) - neighbor_counts, neighbor_counts
289+
)
290+
nb_s = indices[np.repeat(starts, neighbor_counts) + offsets]
291+
nb_flat = src_t * n_src + nb_s
292+
mask = x_in[nb_flat]
293+
rows = [src_flat[mask]]
294+
cols = [nb_flat[mask]]
295+
296+
# Temporal edges: same source, adjacent time steps
297+
for step in range(1, max_step + 1):
298+
mask_t = active_t >= step
299+
later = active[mask_t]
300+
earlier = later - step * n_src
301+
both = x_in[earlier]
302+
rows.extend([later[both], earlier[both]])
303+
cols.extend([earlier[both], later[both]])
304+
305+
# Self-loops so isolated active vertices get their own component
306+
rows.append(active)
307+
cols.append(active)
308+
row = np.concatenate(rows)
309+
col = np.concatenate(cols)
310+
adj = sparse.coo_array((np.ones(len(row)), (row, col)), shape=(n_total, n_total))
311+
_, labels = connected_components(adj)
312+
313+
# Build cluster list directly from component labels
314+
cluster_labels = labels[active]
315+
return [active[cluster_labels == id_] for id_ in np.unique(cluster_labels)]
316+
288317

289318
def _get_components(x_in, adjacency, return_list=True):
290319
"""Get connected components from a mask and a adjacency matrix."""
@@ -745,41 +774,63 @@ def _do_1samp_permutations(
745774
# allocate space for output
746775
max_cluster_sums = np.empty(len(orders), dtype=np.double)
747776

777+
# For sign-flips s²=1, so sum(X²) is constant across permutations.
778+
# Precompute once and derive t-statistics via algebra instead of
779+
# calling stat_fun each iteration.
780+
use_fast_ttest = stat_fun is ttest_1samp_no_p
781+
if use_fast_ttest:
782+
sum_sq = np.sum(X**2, axis=0)
783+
sqrt_n_nm1 = np.sqrt(n_samp * (n_samp - 1))
784+
inv_n = 1.0 / n_samp
785+
neg_n = -float(n_samp)
786+
748787
if buffer_size is not None:
749788
# allocate a buffer so we don't need to allocate memory in loop
750789
X_flip_buffer = np.empty((n_samp, buffer_size), dtype=X.dtype)
751790

752791
for seed_idx, order in enumerate(orders):
753792
assert isinstance(order, np.ndarray)
754-
# new surrogate data with specified sign flip
755793
assert order.size == n_samp # should be guaranteed by parent
756-
signs = 2 * order[:, None].astype(int) - 1
757-
if not np.all(np.equal(np.abs(signs), 1)):
758-
raise ValueError("signs from rng must be +/- 1")
759794

760-
if buffer_size is None:
761-
# be careful about non-writable memmap (GH#1507)
762-
if X.flags.writeable:
763-
X *= signs
764-
# Recompute statistic on randomized data
765-
t_obs_surr = stat_fun(X)
766-
# Set X back to previous state (trade memory eff. for CPU use)
767-
X *= signs
768-
else:
769-
t_obs_surr = stat_fun(X * signs)
795+
if use_fast_ttest:
796+
signs = 2.0 * order - 1.0 # (n_samp,) ±1
797+
dot = signs @ X # (n_vars,)
798+
mean_s = dot * inv_n
799+
denom_sq = np.maximum(sum_sq + mean_s * mean_s * neg_n, 0.0)
800+
t_obs_surr = np.where(
801+
denom_sq > 0, mean_s / np.sqrt(denom_sq) * sqrt_n_nm1, 0.0
802+
)
770803
else:
771-
# only sign-flip a small data buffer, so we need less memory
772-
t_obs_surr = np.empty(n_vars, dtype=X.dtype)
804+
# new surrogate data with specified sign flip
805+
signs = 2 * order[:, None].astype(int) - 1
806+
if not np.all(np.equal(np.abs(signs), 1)):
807+
raise ValueError("signs from rng must be +/- 1")
808+
809+
if buffer_size is None:
810+
# be careful about non-writable memmap (GH#1507)
811+
if X.flags.writeable:
812+
X *= signs
813+
# Recompute statistic on randomized data
814+
t_obs_surr = stat_fun(X)
815+
# Set X back to previous state (trade memory eff. for CPU use)
816+
X *= signs
817+
else:
818+
t_obs_surr = stat_fun(X * signs)
819+
else:
820+
# only sign-flip a small data buffer, so we need less memory
821+
t_obs_surr = np.empty(n_vars, dtype=X.dtype)
773822

774-
for pos in range(0, n_vars, buffer_size):
775-
# number of variables for this loop
776-
n_var_loop = min(pos + buffer_size, n_vars) - pos
823+
for pos in range(0, n_vars, buffer_size):
824+
# number of variables for this loop
825+
n_var_loop = min(pos + buffer_size, n_vars) - pos
777826

778-
X_flip_buffer[:, :n_var_loop] = signs * X[:, pos : pos + n_var_loop]
827+
X_flip_buffer[:, :n_var_loop] = (
828+
signs * X[:, pos : pos + n_var_loop]
829+
)
779830

780-
# apply stat_fun and store result
781-
tmp = stat_fun(X_flip_buffer)
782-
t_obs_surr[pos : pos + n_var_loop] = tmp[:n_var_loop]
831+
# apply stat_fun and store result
832+
tmp = stat_fun(X_flip_buffer)
833+
t_obs_surr[pos : pos + n_var_loop] = tmp[:n_var_loop]
783834

784835
# The stat should have the same shape as the samples for no adj.
785836
if adjacency is None:
@@ -953,7 +1004,10 @@ def _permutation_cluster_test(
9531004
logger.info(f"stat_fun(H1): min={np.min(t_obs)} max={np.max(t_obs)}")
9541005

9551006
# test if stat_fun treats variables independently
956-
if buffer_size is not None:
1007+
# (skip for built-in stat functions which are known to be independent)
1008+
if buffer_size is not None and (
1009+
stat_fun is not ttest_1samp_no_p and stat_fun is not f_oneway
1010+
):
9571011
t_obs_buffer = np.zeros_like(t_obs)
9581012
for pos in range(0, n_tests, buffer_size):
9591013
t_obs_buffer[pos : pos + buffer_size] = stat_fun(

0 commit comments

Comments
 (0)