@@ -121,7 +121,7 @@ def _masked_sum_power(x, c, t_power):
121121
@jit()
def _sum_cluster_data(data, tstep):
    """Convert cluster statistic values into signed time-step weights.

    ``np.sign`` already yields 0 wherever ``data`` is 0, so each element of
    the result is ``+tstep``, ``-tstep`` or ``0`` depending on the sign of
    the corresponding ``data`` entry.
    """
    return tstep * np.sign(data)
125125
126126
127127def _get_clusters_spatial (s , neighbors ):
@@ -258,33 +258,62 @@ def _get_clusters_st_multistep(keepers, neighbors, max_step=1):
258258
259259
260260def _get_clusters_st (x_in , neighbors , max_step = 1 ):
261- """Choose the most efficient version."""
261+ """Find spatio-temporal clusters via SciPy connected components.
262+
263+ Builds a sparse adjacency graph over only the supra-threshold vertices
264+ (spatial edges from the neighbor lists, temporal edges between the same
265+ source at adjacent time steps) and labels clusters with
266+ ``scipy.sparse.csgraph.connected_components``.
267+ """
262268 n_src = len (neighbors )
263- n_times = x_in .size // n_src
264- cl_goods = np .where (x_in )[0 ]
265- if len (cl_goods ) > 0 :
266- keepers = [np .array ([], dtype = int )] * n_times
267- row , col = np .unravel_index (cl_goods , (n_times , n_src ))
268- lims = [0 ]
269- if isinstance (row , int ):
270- row = [row ]
271- col = [col ]
272- else :
273- order = np .argsort (row )
274- row = row [order ]
275- col = col [order ]
276- lims += (np .where (np .diff (row ) > 0 )[0 ] + 1 ).tolist ()
277- lims .append (len (row ))
278-
279- for start , end in zip (lims [:- 1 ], lims [1 :]):
280- keepers [row [start ]] = np .sort (col [start :end ])
281- if max_step == 1 :
282- return _get_clusters_st_1step (keepers , neighbors )
283- else :
284- return _get_clusters_st_multistep (keepers , neighbors , max_step )
285- else :
269+ n_total = len (x_in )
270+ active = np .where (x_in )[0 ]
271+ if len (active ) == 0 :
286272 return []
287273
274+ # Convert neighbor lists to CSR for vectorized expansion
275+ lengths = np .array ([len (a ) for a in neighbors ])
276+ indptr = np .zeros (n_src + 1 , dtype = np .intp )
277+ np .cumsum (lengths , out = indptr [1 :])
278+ indices = np .concatenate (neighbors ).astype (np .intp )
279+
280+ active_t , active_s = np .divmod (active , n_src )
281+
282+ # Spatial edges: vectorized CSR neighbor expansion
283+ neighbor_counts = indptr [active_s + 1 ] - indptr [active_s ]
284+ src_flat = np .repeat (active , neighbor_counts )
285+ src_t = np .repeat (active_t , neighbor_counts )
286+ starts = indptr [active_s ]
287+ offsets = np .arange (int (np .sum (neighbor_counts ))) - np .repeat (
288+ np .cumsum (neighbor_counts ) - neighbor_counts , neighbor_counts
289+ )
290+ nb_s = indices [np .repeat (starts , neighbor_counts ) + offsets ]
291+ nb_flat = src_t * n_src + nb_s
292+ mask = x_in [nb_flat ]
293+ rows = [src_flat [mask ]]
294+ cols = [nb_flat [mask ]]
295+
296+ # Temporal edges: same source, adjacent time steps
297+ for step in range (1 , max_step + 1 ):
298+ mask_t = active_t >= step
299+ later = active [mask_t ]
300+ earlier = later - step * n_src
301+ both = x_in [earlier ]
302+ rows .extend ([later [both ], earlier [both ]])
303+ cols .extend ([earlier [both ], later [both ]])
304+
305+ # Self-loops so isolated active vertices get their own component
306+ rows .append (active )
307+ cols .append (active )
308+ row = np .concatenate (rows )
309+ col = np .concatenate (cols )
310+ adj = sparse .coo_array ((np .ones (len (row )), (row , col )), shape = (n_total , n_total ))
311+ _ , labels = connected_components (adj )
312+
313+ # Build cluster list directly from component labels
314+ cluster_labels = labels [active ]
315+ return [active [cluster_labels == id_ ] for id_ in np .unique (cluster_labels )]
316+
288317
289318def _get_components (x_in , adjacency , return_list = True ):
290319 """Get connected components from a mask and a adjacency matrix."""
@@ -745,41 +774,63 @@ def _do_1samp_permutations(
745774 # allocate space for output
746775 max_cluster_sums = np .empty (len (orders ), dtype = np .double )
747776
777+ # For sign-flips s²=1, so sum(X²) is constant across permutations.
778+ # Precompute once and derive t-statistics via algebra instead of
779+ # calling stat_fun each iteration.
780+ use_fast_ttest = stat_fun is ttest_1samp_no_p
781+ if use_fast_ttest :
782+ sum_sq = np .sum (X ** 2 , axis = 0 )
783+ sqrt_n_nm1 = np .sqrt (n_samp * (n_samp - 1 ))
784+ inv_n = 1.0 / n_samp
785+ neg_n = - float (n_samp )
786+
748787 if buffer_size is not None :
749788 # allocate a buffer so we don't need to allocate memory in loop
750789 X_flip_buffer = np .empty ((n_samp , buffer_size ), dtype = X .dtype )
751790
752791 for seed_idx , order in enumerate (orders ):
753792 assert isinstance (order , np .ndarray )
754- # new surrogate data with specified sign flip
755793 assert order .size == n_samp # should be guaranteed by parent
756- signs = 2 * order [:, None ].astype (int ) - 1
757- if not np .all (np .equal (np .abs (signs ), 1 )):
758- raise ValueError ("signs from rng must be +/- 1" )
759794
760- if buffer_size is None :
761- # be careful about non-writable memmap (GH#1507)
762- if X .flags .writeable :
763- X *= signs
764- # Recompute statistic on randomized data
765- t_obs_surr = stat_fun (X )
766- # Set X back to previous state (trade memory eff. for CPU use)
767- X *= signs
768- else :
769- t_obs_surr = stat_fun (X * signs )
795+ if use_fast_ttest :
796+ signs = 2.0 * order - 1.0 # (n_samp,) ±1
797+ dot = signs @ X # (n_vars,)
798+ mean_s = dot * inv_n
799+ denom_sq = np .maximum (sum_sq + mean_s * mean_s * neg_n , 0.0 )
800+ t_obs_surr = np .where (
801+ denom_sq > 0 , mean_s / np .sqrt (denom_sq ) * sqrt_n_nm1 , 0.0
802+ )
770803 else :
771- # only sign-flip a small data buffer, so we need less memory
772- t_obs_surr = np .empty (n_vars , dtype = X .dtype )
804+ # new surrogate data with specified sign flip
805+ signs = 2 * order [:, None ].astype (int ) - 1
806+ if not np .all (np .equal (np .abs (signs ), 1 )):
807+ raise ValueError ("signs from rng must be +/- 1" )
808+
809+ if buffer_size is None :
810+ # be careful about non-writable memmap (GH#1507)
811+ if X .flags .writeable :
812+ X *= signs
813+ # Recompute statistic on randomized data
814+ t_obs_surr = stat_fun (X )
815+ # Set X back to previous state (trade memory eff. for CPU use)
816+ X *= signs
817+ else :
818+ t_obs_surr = stat_fun (X * signs )
819+ else :
820+ # only sign-flip a small data buffer, so we need less memory
821+ t_obs_surr = np .empty (n_vars , dtype = X .dtype )
773822
774- for pos in range (0 , n_vars , buffer_size ):
775- # number of variables for this loop
776- n_var_loop = min (pos + buffer_size , n_vars ) - pos
823+ for pos in range (0 , n_vars , buffer_size ):
824+ # number of variables for this loop
825+ n_var_loop = min (pos + buffer_size , n_vars ) - pos
777826
778- X_flip_buffer [:, :n_var_loop ] = signs * X [:, pos : pos + n_var_loop ]
827+ X_flip_buffer [:, :n_var_loop ] = (
828+ signs * X [:, pos : pos + n_var_loop ]
829+ )
779830
780- # apply stat_fun and store result
781- tmp = stat_fun (X_flip_buffer )
782- t_obs_surr [pos : pos + n_var_loop ] = tmp [:n_var_loop ]
831+ # apply stat_fun and store result
832+ tmp = stat_fun (X_flip_buffer )
833+ t_obs_surr [pos : pos + n_var_loop ] = tmp [:n_var_loop ]
783834
784835 # The stat should have the same shape as the samples for no adj.
785836 if adjacency is None :
@@ -953,7 +1004,10 @@ def _permutation_cluster_test(
9531004 logger .info (f"stat_fun(H1): min={ np .min (t_obs )} max={ np .max (t_obs )} " )
9541005
9551006 # test if stat_fun treats variables independently
956- if buffer_size is not None :
1007+ # (skip for built-in stat functions which are known to be independent)
1008+ if buffer_size is not None and (
1009+ stat_fun is not ttest_1samp_no_p and stat_fun is not f_oneway
1010+ ):
9571011 t_obs_buffer = np .zeros_like (t_obs )
9581012 for pos in range (0 , n_tests , buffer_size ):
9591013 t_obs_buffer [pos : pos + buffer_size ] = stat_fun (
0 commit comments