Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .circleci/requirements_testing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ https://github.com/nsdf/nsdf/archive/0.1.tar.gz
coverage
coveralls
pillow
tridesclous>=1.6.1
7 changes: 7 additions & 0 deletions neo/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
* :attr:`SpikeGLXIO`
* :attr:`StimfitIO`
* :attr:`TdtIO`
* :attr:`TridesclousIO`
* :attr:`TiffIO`
* :attr:`WinEdrIO`
* :attr:`WinWcpIO`
Expand Down Expand Up @@ -213,6 +214,10 @@

.. autoattribute:: extensions

.. autoclass:: neo.io.TridesclousIO

.. autoattribute:: extensions

.. autoclass:: neo.io.TiffIO

.. autoattribute:: extensions
Expand Down Expand Up @@ -286,6 +291,7 @@
from neo.io.stimfitio import StimfitIO
from neo.io.tdtio import TdtIO
from neo.io.tiffio import TiffIO
from neo.io.tridesclousio import TridesclousIO
from neo.io.winedrio import WinEdrIO
from neo.io.winwcpio import WinWcpIO

Expand Down Expand Up @@ -331,6 +337,7 @@
StimfitIO,
TdtIO,
TiffIO,
TridesclousIO,
WinEdrIO,
WinWcpIO
]
Expand Down
11 changes: 11 additions & 0 deletions neo/io/tridesclousio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from neo.io.basefromrawio import BaseFromRaw
from neo.rawio.tridesclousrawio import TridesclousRawIO


class TridesclousIO(TridesclousRawIO, BaseFromRaw):
__doc__ = TridesclousRawIO.__doc__
mode = 'dir'

def __init__(self, dirname):
TridesclousRawIO.__init__(self, dirname=dirname)
BaseFromRaw.__init__(self, dirname)
7 changes: 7 additions & 0 deletions neo/rawio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
* :attr:`Spike2RawIO`
* :attr:`SpikeGLXRawIO`
* :attr:`TdtRawIO`
* :attr:`TridesclousRawIO`
* :attr:`WinEdrRawIO`
* :attr:`WinWcpRawIO`

Expand Down Expand Up @@ -116,6 +117,10 @@

.. autoattribute:: extensions

.. autoclass:: neo.rawio.TridesclousRawIO

.. autoattribute:: extensions

.. autoclass:: neo.rawio.WinEdrRawIO

.. autoattribute:: extensions
Expand Down Expand Up @@ -148,6 +153,7 @@
from neo.rawio.spike2rawio import Spike2RawIO
from neo.rawio.spikeglxrawio import SpikeGLXRawIO
from neo.rawio.tdtrawio import TdtRawIO
from neo.rawio.tridesclousrawio import TridesclousRawIO
from neo.rawio.winedrrawio import WinEdrRawIO
from neo.rawio.winwcprawio import WinWcpRawIO

Expand All @@ -172,6 +178,7 @@
Spike2RawIO,
SpikeGLXRawIO,
TdtRawIO,
TridesclousRawIO,
WinEdrRawIO,
WinWcpRawIO,
]
Expand Down
155 changes: 155 additions & 0 deletions neo/rawio/tridesclousrawio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""


"""

from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype,
_event_channel_dtype)

import numpy as np

from pathlib import Path


class TridesclousRawIO(BaseRawIO):
"""

"""
extensions = []
rawmode = 'one-dir'

def __init__(self, dirname='', chan_grp=None):
BaseRawIO.__init__(self)
self.dirname = dirname
self.chan_grp = chan_grp

def _source_name(self):
return self.dirname

def _parse_header(self):
try:
import tridesclous as tdc
except ImportError:
print('tridesclous is not installed')

tdc_folder = Path(self.dirname)

tdc_dataio = tdc.DataIO(str(self.dirname))
chan_grp = self.chan_grp
if chan_grp is None:
# if chan_grp is not provided, take the first one if unique
chan_grps = list(tdc_dataio.channel_groups.keys())
assert len(chan_grps) == 1, 'There are several groups in the folder, specify chan_grp=...'
chan_grp = chan_grps[0]

self._sampling_rate = float(tdc_dataio.sample_rate)
catalogue = tdc_dataio.load_catalogue(name='initial', chan_grp=chan_grp)

labels = catalogue['clusters']['cluster_label']
labels = labels[labels >= 0]
self.unit_labels = labels

nb_segment = tdc_dataio.nb_segment

self._all_spikes = []
for seg_index in range(nb_segment):
self._all_spikes.append(tdc_dataio.get_spikes(seg_num=seg_index,
chan_grp=chan_grp, i_start=None, i_stop=None).copy())

self._sampling_rate = tdc_dataio.sample_rate
sr = self._sampling_rate

self._t_starts = [0.] * nb_segment
self._t_stops = [tdc_dataio.segment_shapes[s][0]/sr
for s in range(nb_segment)]

sig_channels = []
sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)

unit_channels = []
for unit_index, unit_label in enumerate(labels):
unit_name = f'unit{unit_index} #{unit_label}'
unit_id = f'{unit_label}'
wf_units = ''
wf_gain = 0
wf_offset = 0.
wf_left_sweep = 0
wf_sampling_rate = 0
unit_channels.append((unit_name, unit_id, wf_units, wf_gain,
wf_offset, wf_left_sweep, wf_sampling_rate))
unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)

event_channels = []
event_channels = np.array(event_channels, dtype=_event_channel_dtype)

# fille into header dict
# This is mandatory!!!!!
self.header = {}
self.header['nb_block'] = 1
self.header['nb_segment'] = [nb_segment]
self.header['signal_channels'] = sig_channels
self.header['unit_channels'] = unit_channels
self.header['event_channels'] = event_channels

self._generate_minimal_annotations()

def _segment_t_start(self, block_index, seg_index):
assert block_index == 0
return self._t_starts[seg_index]

def _segment_t_stop(self, block_index, seg_index):
assert block_index == 0
return self._t_stops[seg_index]

def _get_signal_size(self, block_index, seg_index, channel_indexes=None):
return None

def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
return None

def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
return None

def _spike_count(self, block_index, seg_index, unit_index):
assert block_index == 0
spikes = self._all_spikes[seg_index]
unit_label = self.unit_labels[unit_index]
mask = spikes['cluster_label'] == unit_label
nb_spikes = np.sum(mask)
return nb_spikes

def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
assert block_index == 0
unit_label = self.unit_labels[unit_index]
spikes = self._all_spikes[seg_index]
mask = spikes['cluster_label'] == unit_label
spike_timestamps = spikes['index'][mask].copy()

if t_start is not None:
start_frame = int(t_start * self._sampling_rate)
spike_timestamps = spike_timestamps[spike_timestamps >= start_frame]
if t_stop is not None:
end_frame = int(t_stop * self._sampling_rate)
spike_timestamps = spike_timestamps[spike_timestamps < end_frame]

return spike_timestamps

def _rescale_spike_timestamp(self, spike_timestamps, dtype):
spike_times = spike_timestamps.astype(dtype)
spike_times /= self._sampling_rate
return spike_times

def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
return None

def _event_count(self, block_index, seg_index, event_channel_index):
return None

def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
return None

def _rescale_event_timestamp(self, event_timestamps, dtype):
return None

def _rescale_epoch_duration(self, raw_duration, dtype):
return None
19 changes: 19 additions & 0 deletions neo/test/iotest/test_tridesclousio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""

"""

import unittest

from neo.io import TridesclousIO
from neo.test.iotest.common_io_test import BaseTestIO
from neo.test.rawiotest.test_tridesclousrawio import TestTrisdesclousRawIO


class TestTridesclousIO(BaseTestIO, unittest.TestCase):
files_to_test = TestTrisdesclousRawIO.entities_to_test
files_to_download = TestTrisdesclousRawIO.files_to_download
ioclass = TridesclousIO


if __name__ == "__main__":
unittest.main()
27 changes: 27 additions & 0 deletions neo/test/rawiotest/test_tridesclousrawio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
Tests of neo.rawio.spikeglxrawio
"""

import unittest

from neo.rawio.tridesclousrawio import TridesclousRawIO
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO


class TestTrisdesclousRawIO(BaseTestRawIO, unittest.TestCase):
rawioclass = TridesclousRawIO
files_to_download = [
'tdc_example0/info.json',
'tdc_example0/probe.prb',
'tdc_example0/channel_group_0/segment_0/arrays.json',
'tdc_example0/channel_group_0/segment_0/spikes.raw',
'tdc_example0/channel_group_0/segment_0/processed_signals.raw',
'tdc_example0/channel_group_0/catalogues/initial/arrays.json',
'tdc_example0/channel_group_0/catalogues/initial/catalogue.pickle',
'tdc_example0/channel_group_0/catalogues/initial/clusters.raw',
]
entities_to_test = ['tdc_example0']


if __name__ == "__main__":
unittest.main()