Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/remove_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,12 @@ def get_traces(self, start_frame, end_frame, channel_indices):
if pad is None:
traces[trig, :] = 0
else:
if trig - pad[0] > 0 and trig + pad[1] < end_frame - start_frame:
if trig - pad[0] >= 0 and trig + pad[1] < end_frame - start_frame:
traces[trig - pad[0] : trig + pad[1] + 1, :] = 0
elif trig - pad[0] <= 0 and trig + pad[1] >= end_frame - start_frame:
traces[:] = 0
elif trig - pad[0] <= 0:
traces[: trig + pad[1], :] = 0
traces[: trig + pad[1] + 1, :] = 0
elif trig + pad[1] >= end_frame - start_frame:
traces[trig - pad[0] :, :] = 0
elif self.mode in ["linear", "cubic"]:
Expand Down
42 changes: 42 additions & 0 deletions src/spikeinterface/preprocessing/tests/test_remove_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,47 @@ def test_remove_artifacts():
)


def test_remove_artifacts_zeros_boundary():
"""
Regression test for off-by-one bug in zeros mode boundary handling.

When get_traces is called with a window wider than the artifact (so the artifact
start aligns with or precedes the window start), the last sample of the artifact
period must also be zeroed. Previously, ``traces[:trig + pad[1]]`` was used
instead of ``traces[:trig + pad[1] + 1]``, leaving the very last artifact sample
un-zeroed.
"""
rec = generate_recording(durations=[10.0])
rec.annotate(is_filtered=True)

trigger = 15000
ms = 10
fs = rec.get_sampling_frequency()
ms_frames = int(ms * fs / 1000)

rec_rmart = remove_artifacts(rec, [trigger], ms_before=ms, ms_after=ms)

# Request a window that starts exactly at trigger - ms_frames (so trig - pad[0] == 0)
# but ends *beyond* trigger + ms_frames. This exercises the boundary branch where
# `trig - pad[0] <= 0` but the artifact end lies strictly inside the chunk.
extra = 10
traces = rec_rmart.get_traces(
start_frame=trigger - ms_frames,
end_frame=trigger + ms_frames + extra,
)

# The artifact window [trigger - ms_frames, trigger + ms_frames] must be all zeros,
# including the last sample at index 2 * ms_frames (previously missed).
zeroed_artifact_traces = traces[: 2 * ms_frames + 1, :]
assert not np.any(zeroed_artifact_traces), (
"Last sample of artifact window was not zeroed (off-by-one boundary bug)"
)

# Samples beyond the artifact window should be non-zero (from the underlying recording)
beyond_artifact = traces[2 * ms_frames + 1 :, :]
assert np.any(beyond_artifact), "Samples beyond the artifact window should not be zero"


if __name__ == "__main__":
test_remove_artifacts()
test_remove_artifacts_zeros_boundary()