90 changes: 70 additions & 20 deletions doc/modules/curation.rst
@@ -206,7 +206,7 @@ Curation format
SpikeInterface internally supports a JSON-based manual curation format.
When manual curation is necessary, modifying a dataset in place is bad practice.
Instead, to ensure the reproducibility of spike sorting pipelines, we have introduced a simple, JSON-based manual curation format.
At the moment, this format defines: manual labelling, removal, merging, splitting, and removal of spikes from a unit.
The file can be kept alongside the output of a sorter and applied to the result to obtain a "clean" result.

This format has two parts:
@@ -216,21 +216,26 @@ This format has two parts:
* "format_version" : format specification
* "unit_ids" : the list of unit_ds
* "label_definitions" : list of label categories and possible labels per category.
  If a unit can only have one label, the category can be set to *exclusive=True*; if several labels can be used at once, it can be set to *exclusive=False*.

* **manual output** curation with the following keys:

* "manual_labels"
* "merge_unit_groups"
* "removed_units"
* "merges"
* "removed"
* "splits"
* "discard_spikes"

The first three keys ("manual_labels", "merges" and "removed") act at the unit level: they label, merge or remove whole units. The final two
("splits" and "discard_spikes") act at the spike level: they define which spikes from a unit are split into a new unit, or which
spikes from a unit are to be discarded. Note that all spike indices are given with respect to the original analyzer.

Here is a simple example of the format:

.. code-block:: json

    {
        "format_version": "3",
        "unit_ids": [
            "u1",
            "u2",
@@ -266,25 +266,31 @@
"manual_labels": [
{
"unit_id": "u1",
"quality": [
"good"
]
"labels": {
"quality": [
"good"
]
}
},
{
"unit_id": "u2",
"quality": [
"noise"
],
"putative_type": [
"excitatory",
"pyramidal"
]
"labels": {
"quality": [
"noise"
],
"putative_type": [
"excitatory",
"pyramidal"
]
}
},
{
"unit_id": "u3",
"putative_type": [
"inhibitory"
]
"labels": {
"putative_type": [
"inhibitory"
]
}
}
],
"merge_unit_groups": [
@@ -301,9 +312,48 @@
"removed_units": [
"u31",
"u42"
],
"splits": [
{
"unit_id": "u1",
"mode": "indices",
"indices": [
[
10,
20,
30
]
],
"new_unit_ids": [
"u1-1",
"u1-2"
]
}
],
"discard_spikes": [
{
"unit_id": "u10",
"indices": [
56,
57,
59,
60
]
},
{
"unit_id": "u14",
"indices": [
123,
321
]
}
]
}

Note that you cannot split and merge a unit at the same time.
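
As a rough illustration of the spike-level actions, here is how an "indices"-mode split can be thought of in plain NumPy. This is a sketch only: the spike count is hypothetical, and the assumed semantics (each listed index group becomes one new unit, remaining spikes form the last one) should be checked against the library.

```python
import numpy as np

# Hypothetical spike train of unit "u1": 50 spikes, indexed 0..49 in the
# original analyzer, matching the "splits" entry of the example above.
n_spikes = 50
split_indices = [np.array([10, 20, 30])]  # one group of spike indices

# Assumed semantics: each listed group becomes a new unit ("u1-1") and the
# remaining spikes form the last new unit ("u1-2").
all_indices = np.arange(n_spikes)
listed = np.concatenate(split_indices)
remainder = np.setdiff1d(all_indices, listed)

new_units = {"u1-1": split_indices[0], "u1-2": remainder}
```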

We do not expect users to create their own curation JSON files. Instead, our internal curation algorithms will output
results which can be easily transformed into the format. We also hope that external packages can use our format.

The curation format can be loaded into a dictionary and directly applied to
a ``BaseSorting`` or ``SortingAnalyzer`` object using the :py:func:`~spikeinterface.curation.apply_curation` function.
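
For example, a minimal sketch of that workflow (the dictionary contents are illustrative, and the final ``apply_curation`` call assumes a ``sorting_analyzer`` already exists in your session, so it is shown commented out):

```python
import json
# from spikeinterface.curation import apply_curation

# A minimal curation dictionary (illustrative values); in practice it would be
# read from a JSON file kept alongside the sorter output.
curation = {
    "format_version": "3",
    "unit_ids": ["u1", "u2", "u3"],
    "label_definitions": {},
    "manual_labels": [],
    "removed_units": ["u3"],
}

# Round-trip through JSON, as when saving and reloading the file.
curation_loaded = json.loads(json.dumps(curation))

# Applying it requires an existing analyzer in your session, e.g.:
# clean_analyzer = apply_curation(sorting_analyzer, curation_loaded)
```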
114 changes: 78 additions & 36 deletions src/spikeinterface/curation/curation_format.py
@@ -3,10 +3,10 @@
from pathlib import Path
import json
import numpy as np

from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting, apply_splits_to_sorting
from spikeinterface.curation.curation_model import CurationModel
from spikeinterface.core.sorting_tools import generate_unit_ids_for_split


def validate_curation_dict(curation_dict: dict):
@@ -153,9 +153,9 @@ def apply_curation(
    Steps are done in this order:

    1. Apply labels using curation_dict["manual_labels"]
    2. Remove whole units using curation_dict["removed"]
    3. Apply splits using curation_dict["splits"] and remove spikes from units using curation_dict["discard_spikes"]
    4. Apply merges using curation_dict["merges"]

    A new Sorting or SortingAnalyzer (in memory) is returned.
    The user (an adult) has the responsibility to save it somewhere (or not).
@@ -218,7 +218,80 @@
    else:
        curated_sorting_or_analyzer = sorting_or_analyzer

    # 3. Split units and discard spikes from units.
    # These are done at the same time, otherwise a lot of spike-index shuffling would be needed.
    # Strategy: when splitting, put the discarded spikes in an extra new unit, then remove that unit at the end.
    if len(curation_model.splits) > 0 or len(curation_model.discard_spikes) > 0:
        split_spikes_unit_ids = [split.unit_id for split in curation_model.splits]
        discard_spikes_unit_ids = [discard_spike.unit_id for discard_spike in curation_model.discard_spikes]

        split_units = {}

        sorting = (
            curated_sorting_or_analyzer if isinstance(sorting_or_analyzer, BaseSorting) else sorting_or_analyzer.sorting
        )

        for unit_id in curation_model.unit_ids:

            if unit_id in split_spikes_unit_ids:
                split_spikes_arg = np.where(np.array(split_spikes_unit_ids) == unit_id)[0][0]
                split = curation_model.splits[split_spikes_arg]
                split_units[unit_id] = split.get_full_spike_indices(sorting)

            # If the unit is not split, but does contain spikes to discard, make an initial "split"
            # unit containing the full spike train.
            elif unit_id in discard_spikes_unit_ids:
                split_units[unit_id] = [sorting.get_unit_spike_train(unit_id)]

            # Now find all spikes which are in discard_spikes and remove them from the units-to-split.
            # Put the discarded spikes in their own split-unit, at the end of the list of split units.
            if unit_id in discard_spikes_unit_ids:
                discard_spikes_arg = np.where(np.array(discard_spikes_unit_ids) == unit_id)[0][0]
                discard_spikes = np.array(curation_model.discard_spikes[discard_spikes_arg].indices)

                split_units_with_discard = []
                for split_spike_train in split_units[unit_id]:
                    split_spike_train_cleaned = np.setdiff1d(split_spike_train, discard_spikes, assume_unique=True)
                    split_units_with_discard.append(split_spike_train_cleaned)
                split_units_with_discard.append(discard_spikes)
                split_units[unit_id] = split_units_with_discard

        split_new_unit_ids = [s.new_unit_ids for s in curation_model.splits if s.new_unit_ids is not None]
        unit_ids_to_discard = []

        # We need to know which units to remove later, so we need control of the new unit ids here.
        if len(split_new_unit_ids) == 0:
            split_new_unit_ids = None
            new_unit_ids = generate_unit_ids_for_split(
                sorting.unit_ids, split_units, new_unit_ids=None, new_id_strategy=new_id_strategy
            )
            for old_unit_id, new_unit_id_list in zip(split_units.keys(), new_unit_ids):
                if old_unit_id in discard_spikes_unit_ids:
                    unit_ids_to_discard.append(new_unit_id_list[-1])

        if isinstance(sorting_or_analyzer, BaseSorting):
            curated_sorting_or_analyzer = apply_splits_to_sorting(
                curated_sorting_or_analyzer,
                split_units,
                new_unit_ids=split_new_unit_ids,
            )
        else:
            curated_sorting_or_analyzer = curated_sorting_or_analyzer.split_units(
                split_units,
                new_id_strategy=new_id_strategy,
                new_unit_ids=split_new_unit_ids,
                format="memory",
                verbose=verbose,
            )
        if len(unit_ids_to_discard) > 0:
            curated_sorting_or_analyzer = curated_sorting_or_analyzer.remove_units(unit_ids_to_discard)

    # 4. Merge units
    if len(curation_model.merges) > 0:
        merge_unit_groups = [m.unit_ids for m in curation_model.merges]
        merge_new_unit_ids = [m.new_unit_id for m in curation_model.merges if m.new_unit_id is not None]
@@ -246,37 +319,6 @@ def apply_curation(
                **job_kwargs,
            )

    return curated_sorting_or_analyzer
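
The discard-via-split strategy used above can be sketched in plain NumPy, independent of the library's objects. The spike indices are hypothetical; the sketch only shows the bookkeeping, not the real `SortingAnalyzer` machinery.

```python
import numpy as np

# A unit's spike indices and the spikes to discard (hypothetical values).
spike_train = np.arange(100)
discard_spikes = np.array([56, 57, 59, 60])

# Step 1: treat the whole spike train as a single "split" part, then strip
# the discarded spikes from every part.
parts = [spike_train]
cleaned = [np.setdiff1d(part, discard_spikes, assume_unique=True) for part in parts]

# Step 2: append the discarded spikes as an extra split part; after the split
# is applied, the unit created from this last part is removed.
cleaned.append(discard_spikes)

n_kept = sum(len(part) for part in cleaned[:-1])  # spikes surviving the discard
```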

