-
Notifications
You must be signed in to change notification settings - Fork 237
fix bug spatial filter #4175 #4286
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -207,6 +207,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): | |
| traces = traces.copy() | ||
|
|
||
| # apply AGC and keep the gains | ||
| traces = traces.astype(np.float32) | ||
| if self.window is not None: | ||
| traces, agc_gains = agc(traces, window=self.window) | ||
| else: | ||
|
|
@@ -255,7 +256,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): | |
| # ----------------------------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def agc(traces, window, epsilon=1e-8): | ||
| def agc(traces, window, epsilon=None): | ||
| """ | ||
| Automatic gain control | ||
| w_agc, gain = agc(w, window_length=.5, si=.002, epsilon=1e-8) | ||
|
|
@@ -268,13 +269,15 @@ def agc(traces, window, epsilon=1e-8): | |
| """ | ||
| import scipy.signal | ||
|
|
||
| gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0) | ||
| # default value for epsilon is relative to the rms, loosely matching the IBL 1e-8 for an input in Volts | ||
| if epsilon is None: | ||
| epsilon = np.std(traces - np.mean(traces)) * 0.003 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe the |
||
|
|
||
| gain += (np.sum(gain, axis=0) * epsilon / traces.shape[0])[np.newaxis, :] | ||
| gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0) | ||
|
|
||
| dead_channels = np.sum(gain, axis=0) == 0 | ||
|
|
||
| traces[:, ~dead_channels] = traces[:, ~dead_channels] / gain[:, ~dead_channels] | ||
| traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilon, gain[:, ~dead_channels]) | ||
|
|
||
| return traces, gain | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |||||
| import numpy as np | ||||||
| from copy import deepcopy | ||||||
|
|
||||||
| import spikeinterface as si | ||||||
| import spikeinterface.full as si | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why switch from core to full. This is super heavy and probably not necessary right? This is a bit of a poorly explained part of our importing. import spikeinterface as siis equivalent to import spikeinterface.core as siwhich is a faster import. |
||||||
| import spikeinterface.preprocessing as spre | ||||||
| import spikeinterface.extractors as se | ||||||
| from spikeinterface.core import generate_recording | ||||||
|
|
@@ -24,7 +24,7 @@ | |||||
|
|
||||||
|
|
||||||
| @pytest.mark.skipif( | ||||||
| importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB, | ||||||
| importlib.util.find_spec("ibldsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB, | ||||||
| reason="Only local. Requires ibl-neuropixel install", | ||||||
| ) | ||||||
| @pytest.mark.parametrize("lagc", [False, 1, 300]) | ||||||
|
|
@@ -51,32 +51,28 @@ def test_highpass_spatial_filter_real_data(lagc): | |||||
| use DEBUG = true to visualise. | ||||||
|
|
||||||
| """ | ||||||
| import spikeglx | ||||||
| import neurodsp.voltage as voltage | ||||||
| import ibldsp.voltage | ||||||
| import neuropixel | ||||||
|
|
||||||
| options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None) | ||||||
| print(options) | ||||||
|
|
||||||
| ibl_data, si_recording = get_ibl_si_data() | ||||||
|
|
||||||
| si_filtered, _ = run_si_highpass_filter(si_recording, **options) | ||||||
| local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") | ||||||
| si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap") | ||||||
| si_recording = spre.astype(si_recording, "float") | ||||||
| recording_ps = si.phase_shift(si_recording) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. following @zm711 comment about not importing full! |
||||||
| recording_hp = si.highpass_filter(recording_ps, freq_min=300, filter_order=3) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| recording_hps = si.highpass_spatial_filter(recording_hp) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| raw = si_recording.get_traces().astype(np.float32).T * neuropixel.S2V_AP | ||||||
| si_filtered = recording_hps.get_traces().astype(np.float32).T * neuropixel.S2V_AP | ||||||
|
|
||||||
| ibl_filtered = run_ibl_highpass_filter(ibl_data.copy(), **options) | ||||||
| destripe = ibldsp.voltage.destripe(raw, fs=30_000, neuropixel_version=1) | ||||||
|
|
||||||
| if DEBUG: | ||||||
| fig, axs = plt.subplots(ncols=4) | ||||||
| axs[0].imshow(si_recording.get_traces(return_in_uV=True)) | ||||||
| axs[0].set_title("SI Raw") | ||||||
| axs[1].imshow(ibl_data.T) | ||||||
| axs[1].set_title("IBL Raw") | ||||||
| axs[2].imshow(si_filtered) | ||||||
| axs[2].set_title("SI Filtered ") | ||||||
| axs[3].imshow(ibl_filtered) | ||||||
| axs[3].set_title("IBL Filtered") | ||||||
| from viewephys.gui import viewephys | ||||||
|
|
||||||
| eqc = {} | ||||||
| eqc["si_filtered"] = viewephys(si_filtered, fs=30_000, title="si_filtered") | ||||||
| eqc["ibl_filtered"] = viewephys(destripe, fs=30_000, title="ibl_filtered") | ||||||
|
|
||||||
| assert np.allclose( | ||||||
| si_filtered, ibl_filtered * 1e6, atol=1e-01, rtol=0 | ||||||
| ) # the differences are entired due to scaling on data load. | ||||||
| np.testing.assert_allclose(si_filtered[12:120, 300:800], destripe[12:120, 300:800], atol=1e-05, rtol=0) | ||||||
|
|
||||||
|
|
||||||
| @pytest.mark.parametrize("ntr_pad", [None, 0, 31]) | ||||||
|
|
@@ -140,24 +136,6 @@ def test_dtype_stability(dtype): | |||||
| # ---------------------------------------------------------------------------------------------------------------------- | ||||||
|
|
||||||
|
|
||||||
| def get_ibl_si_data(): | ||||||
| """ | ||||||
| Set fixture to session to ensure origional data is not changed. | ||||||
| """ | ||||||
| import spikeglx | ||||||
|
|
||||||
| local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") | ||||||
| ibl_recording = spikeglx.Reader( | ||||||
| local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True | ||||||
| ) | ||||||
| ibl_data = ibl_recording.read(slice(None), slice(None), sync=False)[:, :-1].T # cut sync channel | ||||||
|
|
||||||
| si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap") | ||||||
| si_recording = spre.astype(si_recording, dtype="float32") | ||||||
|
|
||||||
| return ibl_data, si_recording | ||||||
|
|
||||||
|
|
||||||
| def process_args_for_si(si_recording, lagc): | ||||||
| """""" | ||||||
| if isinstance(lagc, bool) and not lagc: | ||||||
|
|
@@ -215,9 +193,10 @@ def run_si_highpass_filter(si_recording, ntr_pad, ntr_tap, lagc, butter_kwargs, | |||||
|
|
||||||
|
|
||||||
| def run_ibl_highpass_filter(ibl_data, ntr_pad, ntr_tap, lagc, butter_kwargs): | ||||||
| butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc) | ||||||
| import ibldsp.voltage | ||||||
|
|
||||||
| ibl_filtered = voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T | ||||||
| butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc) | ||||||
| ibl_filtered = ibldsp.voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T | ||||||
|
|
||||||
| return ibl_filtered | ||||||
|
|
||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this? :)
One can either pass the root folder or a direct path the the cbin file with the
cbin_file_patharg