Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/spikeinterface/extractors/cbin_ibl.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def __init__(
):
from neo.rawio.spikeglxrawio import read_meta_file

if Path(folder_path).is_file():
folder_path = Path(folder_path).parent
Comment on lines +52 to +53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this? :)

One can either pass the root folder or a direct path the the cbin file with the cbin_file_path arg

try:
import mtscomp
except ImportError:
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class FilterRecording(BasePreprocessor):
def __init__(
self,
recording,
band=[300.0, 6000.0],
band=(300.0, 6000.0),
btype="bandpass",
filter_order=5,
ftype="butter",
Expand Down Expand Up @@ -370,7 +370,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms="auto", dtype=None, **f
def causal_filter(
recording,
direction="forward",
band=[300.0, 6000.0],
band=(300.0, 6000.0),
btype="bandpass",
filter_order=5,
ftype="butter",
Expand Down
11 changes: 7 additions & 4 deletions src/spikeinterface/preprocessing/highpass_spatial_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces = traces.copy()

# apply AGC and keep the gains
traces = traces.astype(np.float32)
if self.window is not None:
traces, agc_gains = agc(traces, window=self.window)
else:
Expand Down Expand Up @@ -255,7 +256,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
# -----------------------------------------------------------------------------------------------


def agc(traces, window, epsilon=1e-8):
def agc(traces, window, epsilon=None):
"""
Automatic gain control
w_agc, gain = agc(w, window_length=.5, si=.002, epsilon=1e-8)
Expand All @@ -268,13 +269,15 @@ def agc(traces, window, epsilon=1e-8):
"""
import scipy.signal

gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0)
# default value for epsilon is relative to the rms, loosely matching the IBL 1e-8 for an input in Volts
if epsilon is None:
epsilon = np.std(traces - np.mean(traces)) * 0.003
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe the 0.003 should also be an argument? epsilon_factor?


gain += (np.sum(gain, axis=0) * epsilon / traces.shape[0])[np.newaxis, :]
gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0)

dead_channels = np.sum(gain, axis=0) == 0

traces[:, ~dead_channels] = traces[:, ~dead_channels] / gain[:, ~dead_channels]
traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilon, gain[:, ~dead_channels])

return traces, gain

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from copy import deepcopy

import spikeinterface as si
import spikeinterface.full as si
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why switch from core to full. This is super heavy and probably not necessary right? This is a bit of a poorly explained part of our importing.

import spikeinterface as si

is equivalent to

import spikeinterface.core as si

which is a faster import.

import spikeinterface.preprocessing as spre
import spikeinterface.extractors as se
from spikeinterface.core import generate_recording
Expand All @@ -24,7 +24,7 @@


@pytest.mark.skipif(
importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
importlib.util.find_spec("ibldsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
reason="Only local. Requires ibl-neuropixel install",
)
@pytest.mark.parametrize("lagc", [False, 1, 300])
Expand All @@ -51,32 +51,28 @@ def test_highpass_spatial_filter_real_data(lagc):
use DEBUG = true to visualise.

"""
import spikeglx
import neurodsp.voltage as voltage
import ibldsp.voltage
import neuropixel

options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None)
print(options)

ibl_data, si_recording = get_ibl_si_data()

si_filtered, _ = run_si_highpass_filter(si_recording, **options)
local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
si_recording = spre.astype(si_recording, "float")
recording_ps = si.phase_shift(si_recording)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
recording_ps = si.phase_shift(si_recording)
recording_ps = spre.phase_shift(si_recording)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

following @zm711 comment about not importing full!

recording_hp = si.highpass_filter(recording_ps, freq_min=300, filter_order=3)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
recording_hp = si.highpass_filter(recording_ps, freq_min=300, filter_order=3)
recording_hp = spre.highpass_filter(recording_ps, freq_min=300, filter_order=3)

recording_hps = si.highpass_spatial_filter(recording_hp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
recording_hps = si.highpass_spatial_filter(recording_hp)
recording_hps = spre.highpass_spatial_filter(recording_hp)

raw = si_recording.get_traces().astype(np.float32).T * neuropixel.S2V_AP
si_filtered = recording_hps.get_traces().astype(np.float32).T * neuropixel.S2V_AP

ibl_filtered = run_ibl_highpass_filter(ibl_data.copy(), **options)
destripe = ibldsp.voltage.destripe(raw, fs=30_000, neuropixel_version=1)

if DEBUG:
fig, axs = plt.subplots(ncols=4)
axs[0].imshow(si_recording.get_traces(return_in_uV=True))
axs[0].set_title("SI Raw")
axs[1].imshow(ibl_data.T)
axs[1].set_title("IBL Raw")
axs[2].imshow(si_filtered)
axs[2].set_title("SI Filtered ")
axs[3].imshow(ibl_filtered)
axs[3].set_title("IBL Filtered")
from viewephys.gui import viewephys

eqc = {}
eqc["si_filtered"] = viewephys(si_filtered, fs=30_000, title="si_filtered")
eqc["ibl_filtered"] = viewephys(destripe, fs=30_000, title="ibl_filtered")

assert np.allclose(
si_filtered, ibl_filtered * 1e6, atol=1e-01, rtol=0
) # the differences are entired due to scaling on data load.
np.testing.assert_allclose(si_filtered[12:120, 300:800], destripe[12:120, 300:800], atol=1e-05, rtol=0)


@pytest.mark.parametrize("ntr_pad", [None, 0, 31])
Expand Down Expand Up @@ -140,24 +136,6 @@ def test_dtype_stability(dtype):
# ----------------------------------------------------------------------------------------------------------------------


def get_ibl_si_data():
"""
Set fixture to session to ensure origional data is not changed.
"""
import spikeglx

local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
ibl_recording = spikeglx.Reader(
local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True
)
ibl_data = ibl_recording.read(slice(None), slice(None), sync=False)[:, :-1].T # cut sync channel

si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
si_recording = spre.astype(si_recording, dtype="float32")

return ibl_data, si_recording


def process_args_for_si(si_recording, lagc):
""""""
if isinstance(lagc, bool) and not lagc:
Expand Down Expand Up @@ -215,9 +193,10 @@ def run_si_highpass_filter(si_recording, ntr_pad, ntr_tap, lagc, butter_kwargs,


def run_ibl_highpass_filter(ibl_data, ntr_pad, ntr_tap, lagc, butter_kwargs):
butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc)
import ibldsp.voltage

ibl_filtered = voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T
butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc)
ibl_filtered = ibldsp.voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T

return ibl_filtered

Expand Down