Skip to content

Commit 16d4089

Browse files
committed
Fix t_starts not propagated to save memory.
1 parent 4539550 commit 16d4089

2 files changed

Lines changed: 8 additions & 9 deletions

File tree

src/spikeinterface/core/baserecording.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,11 +545,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
545545
if kwargs.get("sharedmem", True):
546546
from .numpyextractors import SharedMemoryRecording
547547

548-
cached = SharedMemoryRecording.from_recording(self, **job_kwargs)
548+
cached = SharedMemoryRecording.from_recording(self, t_starts=t_starts, **job_kwargs)
549549
else:
550550
from spikeinterface.core import NumpyRecording
551551

552-
cached = NumpyRecording.from_recording(self, **job_kwargs)
552+
cached = NumpyRecording.from_recording(self, t_starts=t_starts, **job_kwargs)
553553

554554
elif format == "zarr":
555555
from .zarrextractors import ZarrRecordingExtractor

src/spikeinterface/core/numpyextractors.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
8585
}
8686

8787
@staticmethod
88-
def from_recording(source_recording, **job_kwargs):
88+
def from_recording(source_recording, t_starts=None, **job_kwargs):
8989
traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)
9090
if shms[0] is not None:
9191
# if the computation was done in parallel then traces_list is shared array
@@ -95,13 +95,14 @@ def from_recording(source_recording, **job_kwargs):
9595
for shm in shms:
9696
shm.close()
9797
shm.unlink()
98-
# TODO later : propagte t_starts ?
98+
9999
recording = NumpyRecording(
100100
traces_list,
101101
source_recording.get_sampling_frequency(),
102-
t_starts=None,
102+
t_starts=t_starts,
103103
channel_ids=source_recording.channel_ids,
104104
)
105+
return recording
105106

106107

107108
class NumpyRecordingSegment(BaseRecordingSegment):
@@ -211,18 +212,16 @@ def __del__(self):
211212
shm.unlink()
212213

213214
@staticmethod
214-
def from_recording(source_recording, **job_kwargs):
215+
def from_recording(source_recording, t_starts=None, **job_kwargs):
215216
traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)
216217

217-
# TODO later : propagte t_starts ?
218-
219218
recording = SharedMemoryRecording(
220219
shm_names=[shm.name for shm in shms],
221220
shape_list=[traces.shape for traces in traces_list],
222221
dtype=source_recording.dtype,
223222
sampling_frequency=source_recording.sampling_frequency,
224223
channel_ids=source_recording.channel_ids,
225-
t_starts=None,
224+
t_starts=t_starts,
226225
main_shm_owner=True,
227226
)
228227

0 commit comments

Comments
 (0)