Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Speed up :func:`mne.time_frequency.psd_array_welch` and related Welch PSD methods by ~25x for epoched data by batching spectrogram calls instead of per-channel dispatch, by :newcontrib:`Sharif Haason`.
21 changes: 15 additions & 6 deletions mne/time_frequency/psd.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,23 @@ def _decomp_aggregate_mask(epoch, func, average, freq_sl):

def _spect_func(epoch, func, freq_sl, average, *, output="power"):
"""Aux function."""
# Decide if we should split this to save memory or not, since doing
# multiple calls will incur some performance overhead. Eventually we might
# want to write (really, go back to) our own spectrogram implementation
# that, if possible, averages after each transform, but this will incur
# a lot of overhead because of the many Python calls required.
# Process in chunks to balance vectorization (scipy.signal.spectrogram
# handles multi-row input efficiently) against memory usage.
kwargs = dict(func=func, average=average, freq_sl=freq_sl)
if epoch.nbytes > 10e6:
spect = np.apply_along_axis(_decomp_aggregate_mask, -1, epoch, **kwargs)
# Process in chunks of rows instead of one-by-one. Each chunk is
# passed to spectrogram as a 2D array, which is much faster than
# calling spectrogram per-row via np.apply_along_axis.
n_rows = epoch.shape[0]
# Target ~10 MB per chunk (same threshold as the original code)
row_bytes = epoch[0].nbytes
chunk_size = max(1, int(10e6 / row_bytes))
parts = []
for start in range(0, n_rows, chunk_size):
parts.append(
_decomp_aggregate_mask(epoch[start : start + chunk_size], **kwargs)
)
spect = np.concatenate(parts, axis=0)
else:
spect = _decomp_aggregate_mask(epoch, **kwargs)
return spect
Expand Down
Loading