Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions neo/io/nwbio.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,9 @@ def _write_signal(self, nwbfile, signal, electrodes):
additional_metadata["conversion"] = conversion
else:
units = signal.units
if hasattr(signal, 'proxy_for') and signal.proxy_for in [AnalogSignal,
IrregularlySampledSignal]:
signal = signal.load()
if isinstance(signal, AnalogSignal):
sampling_rate = signal.sampling_rate.rescale("Hz")
tS = timeseries_class(
Expand Down Expand Up @@ -597,20 +600,26 @@ def _write_signal(self, nwbfile, signal, electrodes):
return tS

def _write_spiketrain(self, nwbfile, spiketrain):
segment = spiketrain.segment
if hasattr(spiketrain, 'proxy_for') and spiketrain.proxy_for is SpikeTrain:
spiketrain = spiketrain.load()
nwbfile.add_unit(spike_times=spiketrain.rescale('s').magnitude,
obs_intervals=[[float(spiketrain.t_start.rescale('s')),
float(spiketrain.t_stop.rescale('s'))]],
_name=spiketrain.name,
# _description=spiketrain.description,
segment=spiketrain.segment.name,
block=spiketrain.segment.block.name)
segment=segment.name,
block=segment.block.name)
# todo: handle annotations (using add_unit_column()?)
# todo: handle Neo Units
# todo: handle spike waveforms, if any (see SpikeEventSeries)
return nwbfile.units

def _write_event(self, nwbfile, event):
hierarchy = {'block': event.segment.block.name, 'segment': event.segment.name}
segment = event.segment
if hasattr(event, 'proxy_for') and event.proxy_for == Event:
event = event.load()
hierarchy = {'block': segment.block.name, 'segment': segment.name}
tS_evt = AnnotationSeries(
name=event.name,
data=event.labels,
Expand All @@ -621,13 +630,16 @@ def _write_event(self, nwbfile, event):
return tS_evt

def _write_epoch(self, nwbfile, epoch):
segment = epoch.segment
if hasattr(epoch, 'proxy_for') and epoch.proxy_for == Epoch:
epoch = epoch.load()
for t_start, duration, label in zip(epoch.rescale('s').magnitude,
epoch.durations.rescale('s').magnitude,
epoch.labels):
nwbfile.add_epoch(t_start, t_start + duration, [label], [],
_name=epoch.name,
segment=epoch.segment.name,
block=epoch.segment.block.name)
segment=segment.name,
block=segment.block.name)
return nwbfile.epochs


Expand Down
59 changes: 59 additions & 0 deletions neo/test/iotest/test_nwbio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from neo.core import AnalogSignal, SpikeTrain, Event, Epoch, IrregularlySampledSignal, Segment, \
Block

from neo.rawio.examplerawio import ExampleRawIO
from neo.io.proxyobjects import (AnalogSignalProxy, SpikeTrainProxy, EventProxy, EpochProxy)

try:
import pynwb
from neo.io.nwbio import NWBIO
Expand Down Expand Up @@ -250,6 +253,62 @@ def test_roundtrip_with_annotations(self):

os.remove(test_file_name)

def test_write_proxy_objects(self):
test_file_name = self.local_test_dir / "test_round_trip_with_annotations.nwb"

# generate dummy IO as basis for ProxyObjects
self.proxy_reader = ExampleRawIO(filename='my_filename.fake')
self.proxy_reader.parse_header()

# generate test structure with proxy objects
original_block = Block(name='myblock', session_start_time=datetime.now().astimezone(),
session_description=str(test_file_name),
identifier=str(test_file_name))
seg = Segment(name='mysegment')
original_block.segments.append(seg)

# create proxy objects
proxy_anasig = AnalogSignalProxy(rawio=self.proxy_reader, stream_index=0,
inner_stream_channels=None, block_index=0, seg_index=0,)
proxy_anasig.segment = seg
seg.analogsignals.append(proxy_anasig)

proxy_sptr = SpikeTrainProxy(rawio=self.proxy_reader, spike_channel_index=0, block_index=0,
seg_index=0)
proxy_sptr.segment = seg
seg.spiketrains.append(proxy_sptr)

proxy_event = EventProxy(rawio=self.proxy_reader, event_channel_index=0, block_index=0,
seg_index=0)
proxy_event.segment = seg
seg.events.append(proxy_event)

proxy_epoch = EpochProxy(rawio=self.proxy_reader, event_channel_index=1, block_index=0,
seg_index=0)
proxy_epoch.segment = seg
seg.epochs.append(proxy_epoch)

original_block.create_relationship()

iow = NWBIO(filename=test_file_name, mode='w')

# writing data via proxyobjects
iow.write_all_blocks([original_block])

# checking written data
ior = NWBIO(filename=test_file_name, mode='r')
retrieved_block = ior.read_all_blocks()[0]

for original_segment, retrieved_segment in zip(original_block.segments,
retrieved_block.segments):
assert_array_equal(original_segment.analogsignals[0].load().magnitude,
retrieved_segment.analogsignals[0].magnitude)
assert_array_equal(original_segment.spiketrains[0].load().magnitude,
retrieved_segment.spiketrains[0].magnitude)
assert_array_equal(original_segment.events[0].load().magnitude,
retrieved_segment.events[0].magnitude)
assert_array_equal(original_segment.epochs[0].load().magnitude,
retrieved_segment.epochs[0].magnitude)

if __name__ == "__main__":
if HAVE_PYNWB:
Expand Down