Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions neo/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
* :attr:`Spike2IO`
* :attr:`SpikeGadgetsIO`
* :attr:`SpikeGLXIO`
* :attr:`SpykingCircusIO`
* :attr:`StimfitIO`
* :attr:`TdtIO`
* :attr:`TiffIO`
Expand Down Expand Up @@ -230,6 +231,10 @@

.. autoattribute:: extensions

. autoclass:: SpykingCircusIO

.. autoattribute:: extensions

.. autoclass:: neo.io.StimfitIO

.. autoattribute:: extensions
Expand Down Expand Up @@ -313,12 +318,14 @@
from neo.io.spike2io import Spike2IO
from neo.io.spikegadgetsio import SpikeGadgetsIO
from neo.io.spikeglxio import SpikeGLXIO
from neo.io.spykingcircusio import SpykingCircusIO
from neo.io.stimfitio import StimfitIO
from neo.io.tdtio import TdtIO
from neo.io.tiffio import TiffIO
from neo.io.winedrio import WinEdrIO
from neo.io.winwcpio import WinWcpIO


iolist = [
AlphaOmegaIO,
AsciiImageIO,
Expand Down Expand Up @@ -363,6 +370,7 @@
Spike2IO,
SpikeGadgetsIO,
SpikeGLXIO,
SpykingCircusIO,
StimfitIO,
TdtIO,
TiffIO,
Expand Down
11 changes: 11 additions & 0 deletions neo/io/spykingcircusio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from neo.io.basefromrawio import BaseFromRaw
from neo.rawio.spykingcircusrawio import SpykingCircusRawIO


class SpykingCircusIO(SpykingCircusRawIO, BaseFromRaw):
__doc__ = SpykingCircusRawIO.__doc__
mode = 'dir'

def __init__(self, dirname):
SpykingCircusRawIO.__init__(self, dirname=dirname)
BaseFromRaw.__init__(self, dirname)
8 changes: 8 additions & 0 deletions neo/rawio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
* :attr:`Spike2RawIO`
* :attr:`SpikeGadgetsRawIO`
* :attr:`SpikeGLXRawIO`
* :attr:`SpykingCircusRawIO`
* :attr:`TdtRawIO`
* :attr:`WinEdrRawIO`
* :attr:`WinWcpRawIO`
Expand Down Expand Up @@ -137,6 +138,10 @@

.. autoattribute:: extensions

.. autoclass:: neo.rawio.SpykingCircusRawIO

.. autoattribute:: extensions

.. autoclass:: neo.rawio.TdtRawIO

.. autoattribute:: extensions
Expand Down Expand Up @@ -177,10 +182,12 @@
from neo.rawio.spike2rawio import Spike2RawIO
from neo.rawio.spikegadgetsrawio import SpikeGadgetsRawIO
from neo.rawio.spikeglxrawio import SpikeGLXRawIO
from neo.rawio.spykingcircusrawio import SpykingCircusRawIO
from neo.rawio.tdtrawio import TdtRawIO
from neo.rawio.winedrrawio import WinEdrRawIO
from neo.rawio.winwcprawio import WinWcpRawIO


rawiolist = [
AxographRawIO,
AxonaRawIO,
Expand All @@ -206,6 +213,7 @@
Spike2RawIO,
SpikeGadgetsRawIO,
SpikeGLXRawIO,
SpykingCircusRawIO,
TdtRawIO,
WinEdrRawIO,
WinWcpRawIO,
Expand Down
185 changes: 185 additions & 0 deletions neo/rawio/spykingcircusrawio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add some comments here to dscribe the reader and link to the project.


"""

from .baserawio import (BaseRawIO, _signal_channel_dtype, _spike_channel_dtype,
_event_channel_dtype)

import numpy as np

from pathlib import Path

try:
import h5py
HAVE_HDF5 = True
except ImportError:
HAVE_HDF5 = False


def _load_sample_rate(params_file):
sample_rate = None
with params_file.open('r') as f:
for r in f.readlines():
if 'sampling_rate' in r:
sample_rate = r.split('=')[-1]
if '#' in sample_rate:
sample_rate = sample_rate[:sample_rate.find('#')]
sample_rate = float(sample_rate)
return sample_rate


class SpykingCircusRawIO(BaseRawIO):
"""
RawIO reader to load results that have been obtained via SpyKING CIRCUS
http://spyking-circus.rtfd.org

You simply need to specify the output folder created by SpyKING CIRCUS where
the results have been stored.
"""
extensions = []
rawmode = 'one-dir'

def __init__(self, dirname=''):
BaseRawIO.__init__(self)
self.dirname = dirname

def _source_name(self):
return self.dirname

def _parse_header(self):
spykingcircus_folder = Path(self.dirname)
listfiles = spykingcircus_folder.iterdir()
results = None
sample_rate = None

parent_folder = None
result_folder = None
for f in listfiles:
if f.is_dir():
if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]):
parent_folder = spykingcircus_folder
result_folder = f

if parent_folder is None:
parent_folder = spykingcircus_folder.parent
for f in parent_folder.iterdir():
if f.is_dir():
if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]):
result_folder = spykingcircus_folder

assert isinstance(parent_folder, Path) and \
isinstance(result_folder, Path), "Not a valid spyking circus folder"

# load files
for f in result_folder.iterdir():
if 'result.hdf5' in str(f):
results = f
if 'result-merged.hdf5' in str(f):
results = f
break

# load params
for f in parent_folder.iterdir():
if f.suffix == '.params':
sample_rate = _load_sample_rate(f)
else:
raise Exception('Can not find the .params file')

if sample_rate is not None:
self._sampling_frequency = sample_rate

if results is None:
raise Exception(spykingcircus_folder, " is not a spyking circus folder")
f_results = h5py.File(results, 'r')

self._all_spikes = []
for temp in f_results['spiketimes'].keys():
self._all_spikes.append(np.array(f_results['spiketimes'][temp]).astype('int64'))

self._kwargs = {'folder_path': str(Path(spykingcircus_folder).absolute())}

sig_channels = []
sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)

unit_channels = []
for unit_index in range(len(self._all_spikes)):
unit_name = f'unit{unit_index} #{unit_index}'
unit_id = f'{unit_index}'
wf_units = ''
wf_gain = 0
wf_offset = 0.
wf_left_sweep = 0
wf_sampling_rate = 0
unit_channels.append((unit_name, unit_id, wf_units, wf_gain,
wf_offset, wf_left_sweep, wf_sampling_rate))
unit_channels = np.array(unit_channels, dtype=_spike_channel_dtype)

event_channels = []
event_channels = np.array(event_channels, dtype=_event_channel_dtype)

self.header = {}
self.header['nb_block'] = 1
self.header['nb_segment'] = [1]
self.header['signal_channels'] = sig_channels
self.header['unit_channels'] = unit_channels
self.header['event_channels'] = event_channels

self._duration = f_results['info']['duration'][0]
self._generate_minimal_annotations()

def _segment_t_start(self, block_index, seg_index):
return 0.

def _segment_t_stop(self, block_index, seg_index):
return self._duration

def _get_signal_size(self, block_index, seg_index, channel_indexes=None):
return None

def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
return None

def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
return None

def _spike_count(self, block_index, seg_index, unit_index):
nb_spikes = len(self._all_spikes[unit_index])
return nb_spikes

def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
assert block_index == 0
assert seg_index == 0

spike_timestamps = self._all_spikes[unit_index].copy()

if t_start is not None:
start_frame = int(t_start * self._sampling_rate)
spike_timestamps = spike_timestamps[spike_timestamps >= start_frame]
if t_stop is not None:
end_frame = int(t_stop * self._sampling_rate)
spike_timestamps = spike_timestamps[spike_timestamps < end_frame]

return spike_timestamps

def _rescale_spike_timestamp(self, spike_timestamps, dtype):
# must rescale to second a particular spike_timestamps
# with a fixed dtype so the user can choose the precisino he want.
spike_times = spike_timestamps.astype(dtype)
spike_times /= self._sampling_frequency # because 10kHz
return spike_times

def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
return None

def _event_count(self, block_index, seg_index, event_channel_index):
return None

def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
return None

def _rescale_event_timestamp(self, event_timestamps, dtype):
return None

def _rescale_epoch_duration(self, raw_duration, dtype):
return None
22 changes: 22 additions & 0 deletions neo/test/iotest/test_spykingcircus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Tests of neo.rawio.spykingcircusio
"""

import unittest

from neo.io.spykingcircusio import SpykingCircusIO
from neo.test.iotest.common_io_test import BaseTestIO


class TestSpykingCircusIO(BaseTestIO, unittest.TestCase):
files_to_download = [
'spykingcircus/spykingcircus_example0/recording.params',
'spykingcircus/spykingcircus_example0/recording/recording.result.hdf5',
'spykingcircus/spykingcircus_example0/recording/recording.result-merged.hdf5',
]
entities_to_test = ['spykingcircus']
ioclass = SpykingCircusIO


if __name__ == "__main__":
unittest.main()
22 changes: 22 additions & 0 deletions neo/test/rawiotest/test_spykingcircusrawio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Tests of neo.rawio.spykingcircusrawio
"""

import unittest

from neo.rawio.spykingcircusrawio import SpykingCircusRawIO
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO


class TestSpykingCircusRawIO(BaseTestRawIO, unittest.TestCase):
rawioclass = SpykingCircusRawIO
files_to_download = [
'spykingcircus/spykingcircus_example0/recording.params',
'spykingcircus/spykingcircus_example0/recording/recording.result.hdf5',
'spykingcircus/spykingcircus_example0/recording/recording.result-merged.hdf5',
]
entities_to_test = ['spykingcircus']


if __name__ == "__main__":
unittest.main()