-
-
Notifications
You must be signed in to change notification settings - Fork 1.5k
ENH: Support concatenate_epochs() for EpochsTFR
#13745
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
687063b
ba82699
952219f
4e4ea34
56c5f85
3a8bd89
83d63b2
009a8a7
d748de0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Add support for :func:`mne.concatenate_epochs` with :class:`~mne.time_frequency.EpochsTFR` instances, by ``aman-coder03``. (:gh:`13745`) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -4700,6 +4700,78 @@ def _concatenate_epochs( | |||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _concatenate_epochs_tfr(epochs_list, add_offset=True): | ||||||||||||||||||||||
| """Concatenate a list of EpochsTFR instances.""" | ||||||||||||||||||||||
| for ii, ep in enumerate(epochs_list): | ||||||||||||||||||||||
| if type(ep).__name__ != "EpochsTFR": | ||||||||||||||||||||||
| raise TypeError( | ||||||||||||||||||||||
| f"epochs_list[{ii}] must be an instance of EpochsTFR, got {type(ep)}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| ref = epochs_list[0] | ||||||||||||||||||||||
| for ii, ep in enumerate(epochs_list[1:], 1): | ||||||||||||||||||||||
| if not np.array_equal(ep.freqs, ref.freqs): | ||||||||||||||||||||||
| raise ValueError(f"epochs_list[{ii}] freqs do not match epochs_list[0]") | ||||||||||||||||||||||
| if not np.array_equal(ep.times, ref.times): | ||||||||||||||||||||||
| raise ValueError(f"epochs_list[{ii}] times do not match epochs_list[0]") | ||||||||||||||||||||||
| _ensure_infos_match(ep.info, ref.info, f"epochs_list[{ii}]") | ||||||||||||||||||||||
| if ep.baseline != ref.baseline: | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"epochs_list[{ii}] baseline {ep.baseline!r} does not match " | ||||||||||||||||||||||
| f"epochs_list[0] baseline {ref.baseline!r}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| if ep.method != ref.method: | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"epochs_list[{ii}] method {ep.method!r} does not match " | ||||||||||||||||||||||
| f"epochs_list[0] method {ref.method!r}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| if not np.array_equal(ep.weights, ref.weights): | ||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
| raise ValueError(f"epochs_list[{ii}] weights do not match epochs_list[0]") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| data = np.concatenate([ep.data for ep in epochs_list], axis=0) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| shift = len(ref.times) | ||||||||||||||||||||||
| events_offset = ref.events[-1, 0] + shift | ||||||||||||||||||||||
| all_events = [epochs_list[0].events.copy()] | ||||||||||||||||||||||
| for ep in epochs_list[1:]: | ||||||||||||||||||||||
| evs = ep.events.copy() | ||||||||||||||||||||||
| if add_offset: | ||||||||||||||||||||||
| evs[:, 0] += events_offset | ||||||||||||||||||||||
| events_offset += int(np.max(ep.events[:, 0])) + shift | ||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
| all_events.append(evs) | ||||||||||||||||||||||
| events = np.concatenate(all_events, axis=0) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| event_id = deepcopy(ref.event_id) | ||||||||||||||||||||||
| for ep in epochs_list[1:]: | ||||||||||||||||||||||
| event_id.update(ep.event_id) | ||||||||||||||||||||||
|
Comment on lines
+4743
to
+4745
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not safe. See corresponding code in existing Lines 4620 to 4629 in 692ecce
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| selection = np.concatenate([ep.selection for ep in epochs_list]) | ||||||||||||||||||||||
| drop_log = sum([ep.drop_log for ep in epochs_list], ()) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| metadatas = [ep.metadata for ep in epochs_list] | ||||||||||||||||||||||
| n_have = sum(m is not None for m in metadatas) | ||||||||||||||||||||||
| if n_have == 0: | ||||||||||||||||||||||
| metadata = None | ||||||||||||||||||||||
| elif n_have != len(metadatas): | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"{n_have} of {len(metadatas)} EpochsTFR instances have metadata, " | ||||||||||||||||||||||
| "all or none must have metadata" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| pd = _check_pandas_installed(strict=False) | ||||||||||||||||||||||
| metadata = pd.concat(metadatas) if pd is not False else sum(metadatas, list()) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| state = ref.__getstate__() | ||||||||||||||||||||||
| state["data"] = data | ||||||||||||||||||||||
| state["events"] = events | ||||||||||||||||||||||
| state["event_id"] = event_id | ||||||||||||||||||||||
| state["selection"] = selection | ||||||||||||||||||||||
| state["drop_log"] = drop_log | ||||||||||||||||||||||
| state["metadata"] = metadata | ||||||||||||||||||||||
| out = type(epochs_list[0]).__new__(type(epochs_list[0])) | ||||||||||||||||||||||
| out.__setstate__(state) | ||||||||||||||||||||||
| return out | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @verbose | ||||||||||||||||||||||
| def concatenate_epochs( | ||||||||||||||||||||||
| epochs_list, add_offset=True, *, on_mismatch="raise", verbose=None | ||||||||||||||||||||||
|
|
@@ -4732,6 +4804,8 @@ def concatenate_epochs( | |||||||||||||||||||||
| ----- | ||||||||||||||||||||||
| .. versionadded:: 0.9.0 | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| if epochs_list and type(epochs_list[0]).__name__ == "EpochsTFR": | ||||||||||||||||||||||
| return _concatenate_epochs_tfr(epochs_list, add_offset=add_offset) | ||||||||||||||||||||||
| ( | ||||||||||||||||||||||
| info, | ||||||||||||||||||||||
| data, | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.