FIR phase-shift + extract apply_raised_cosine_taper from get_chunk_with_margin #4563
Open
galenlynch wants to merge 2 commits into SpikeInterface:main
Conversation

Contributor (Author)
Test failures are unrelated to this PR

Member
@samuelgarcia @oliche can you take a look at this implementation?
Adds a sinc-FIR alternative to the FFT-based PhaseShift path, and factors
the FFT-specific raised-cosine taper out of get_chunk_with_margin so
bounded-support FIR consumers don't have to pay for it.
Changes
-------
- PhaseShiftRecording(method="fft"|"fir", n_taps=32, output_dtype=None)
- method="fir": 32-tap Kaiser-windowed sinc, numba-jit with prange over time
- Per-channel kernels cached once per segment
- Int16-native fast path: reads int16 directly, writes float32 when
output_dtype=np.float32 (skips round-back-to-int16)
- FIR margin = n_taps // 2 (16 for 32-tap default), vs FFT path's 40 ms
- Default "fft" preserves existing behavior
- apply_raised_cosine_taper(data, margin, *, inplace=True) public function
in spikeinterface.core.time_series_tools
- get_chunk_with_margin(window_on_margin=True) emits DeprecationWarning
and delegates to apply_raised_cosine_taper; old behavior preserved
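As a rough sketch of the extracted helper (a hypothetical re-implementation from
the description above, not the PR's code; the exact window formula may differ):

```python
import numpy as np

def apply_raised_cosine_taper(data, margin, *, inplace=True):
    # Multiply the first and last `margin` samples of a (n_samples, n_channels)
    # float chunk by a raised-cosine ramp so the FFT path sees smoothly
    # vanishing edges. The in-place `*=` needs a float dtype, which is why
    # the taper is kept out of the generic chunk utility.
    if not inplace:
        data = data.copy()
    if margin > 0:
        ramp = (1.0 - np.cos(np.pi * np.arange(margin) / margin)) / 2.0  # 0 -> ~1
        data[:margin] *= ramp[:, None]
        data[-margin:] *= ramp[::-1, None]
    return data
```

Bounded-support FIR consumers skip the taper entirely and only pay the
n_taps // 2 margin.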
Whittaker-Shannon justification
-------------------------------
Both FFT and FIR paths are numerical realisations of the same ideal
sinc-interpolation implied by the sampling theorem. FFT does it spectrally
(phase ramp = DFT of full sinc kernel); FIR does it in time against a
Kaiser-windowed 32-tap truncation of the same sinc. Approximations: sinc
truncation (32 taps captures >99% of kernel energy for any d ∈ [0,1)) and
Kaiser windowing (-80 dB stopband, two orders of magnitude below NP's
~50 dB SNR). Measured 0.19% spike-band RMS vs FFT on real NP 2.0 data.
Performance (1M × 384 int16 AIND pipeline, PS → HP → CMR, 24-core host)
-----------------------------------------------------------------------
                               stock FFT   FIR       speedup
single-call get_traces         82.1 s      6.28 s    13.1×
same, f32 propagated           85.8 s      4.40 s    19.5×
TimeSeriesChunkExecutor n=24    4.98 s     1.62 s     3.1× on top of CRE n=24

Memory (same pipeline, n_jobs=24, chunk=1s):
FFT peak RSS: 10.8 GB    FIR peak RSS: 3.9 GB (2.8× less)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This PR introduces a faster time-domain approach to phase-shifting Neuropixels recordings, which can be opted into alongside the existing FFT approach. The time-domain approach achieves >99% accuracy in a fraction of the time.
Attempting to view my data as Kilosort4 would see it, using direct calls to `get_traces` outside of `TimeSeriesChunkExecutor`, was very slow. Profiling showed that the vast majority of the time (98%) was being spent phase shifting, at least on a 384 × 1M chunk of data. This is perhaps an unrealistic length of data, but even with more standard ~1 second chunks (384 × 30k), the vast majority of the preprocessing time was spent phase shifting. This led me to investigate why it was so slow.

The current approach to phase shifting does an rFFT of the entire recording chunk, multiplies the Fourier-transformed data by a complex exponential, and then converts back to the time domain with irFFT. This is correct, but it scales as O(n log n), requires padding to avoid wrap-around, and is very expensive per sample. However, since the signal is band-limited, we can instead use Whittaker–Shannon interpolation and convolve the time-domain signal rather than multiplying the entire frequency-domain signal.
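For reference, the spectral approach described above can be sketched in a few lines of numpy. This is a minimal illustration, not the SpikeInterface implementation; the function name and sign convention are my assumptions:

```python
import numpy as np

def phase_shift_fft(traces, shifts):
    # Spectral fractional delay: rFFT each channel, multiply by a
    # per-channel complex phase ramp, then irFFT back to the time domain.
    # traces: (n_samples, n_channels); shifts: per-channel delay in samples.
    n = traces.shape[0]
    f = np.fft.rfftfreq(n)[:, None]            # frequency in cycles/sample
    spectrum = np.fft.rfft(traces, axis=0)
    ramp = np.exp(-2j * np.pi * f * shifts[None, :])
    return np.fft.irfft(spectrum * ramp, n, axis=0)
```

Because the multiplication touches every frequency bin of the whole chunk, the cost is O(n log n) per channel, plus the padding the production code needs to avoid circular wrap-around.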
A fractional-sample delay is, by definition, a sinc interpolation at the desired offset. Whittaker–Shannon states that any signal bandlimited to Nyquist is exactly reconstructed from its samples via

$$x(t) = \sum_{k=-\infty}^{\infty} x[k]\,\mathrm{sinc}(t - k),$$

so delaying channel $c$ by a fractional $d_c$ samples is

$$y_c[n] = x_c(n - d_c) = \sum_{k=-\infty}^{\infty} x_c[k]\,\mathrm{sinc}(n - d_c - k),$$

that is, convolution with the ideal fractional-delay kernel $h_{d_c}[k] = \mathrm{sinc}(k - d_c)$.

The existing FFT path realises this convolution spectrally: multiplying by $e^{-i\,2\pi f\,d_c}$ in the frequency domain is the DFT of that same infinite sinc kernel.
We can instead do the time-domain convolution with a finite-impulse-response (FIR) sinc filter. Truncating the sinc kernel allows massive performance gains while sacrificing minimal accuracy. In this PR, the FIR path realises the phase shift in the time domain as explicit linear convolution against a Kaiser-windowed, 32-tap truncation of the same sinc. The operation is otherwise identical; the only approximations are the sinc truncation (32 taps capture >99% of the kernel energy for any d ∈ [0, 1)) and the Kaiser windowing (~-80 dB stopband).
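To make the approximation concrete, here is a sketch of building such a Kaiser-windowed 32-tap fractional-delay kernel and comparing it against the exact spectral phase ramp on a bandlimited test signal. The names and parameters (e.g. a Kaiser β ≈ 7.86 for roughly -80 dB stopband) are my assumptions, not the PR's actual values:

```python
import numpy as np

def fractional_delay_kernel(d, n_taps=32, beta=7.857):
    # Kaiser-windowed truncation of h[k] = sinc(k - d), centered so the
    # filter's integer group delay is n_taps // 2 samples.
    center = n_taps // 2
    k = np.arange(n_taps)
    h = np.sinc(k - center - d) * np.kaiser(n_taps, beta)
    return h / h.sum()                      # normalise to unity DC gain

rng = np.random.default_rng(0)
n = 4096
# Bandlimited test signal: white noise lowpassed to 0.2 cycles/sample.
spec = np.fft.rfft(rng.standard_normal(n))
spec[np.fft.rfftfreq(n) > 0.2] = 0.0
x = np.fft.irfft(spec, n)

d = 0.37                                    # fractional delay in samples
h = fractional_delay_kernel(d)
center = len(h) // 2
y_fir = np.convolve(x, h)[center:center + n]        # delay by d, linear conv
y_fft = np.fft.irfft(np.fft.rfft(x) *               # exact spectral delay
                     np.exp(-2j * np.pi * np.fft.rfftfreq(n) * d), n)

sl = slice(64, -64)                         # ignore linear-conv edge effects
rel_err = np.linalg.norm(y_fir[sl] - y_fft[sl]) / np.linalg.norm(y_fft[sl])
```

In this configuration the relative RMS error is sub-percent, in the same ballpark as the spike-band figure reported in this PR, while the FIR only needs an `n_taps // 2` margin on each side of the chunk.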
For the 384 × 1M sample case, changing the interpolation to an FIR filter instead of multiplying the FFT sped up the entire `PhaseShiftRecording → HighpassFilterRecording(300 Hz) → CommonReferenceRecording` preprocessing pipeline by 5.4×. When combined with the companion PR #4564, this increased to a 13.1× speedup. Notably, FIR interpolation uses ~2.8× less peak memory than the FFT approach. However, these benchmarks take advantage of numba-level parallelism that I added to the FIR approach. Isolating only the PhaseShiftRecording component shows that the algorithmic improvement alone produces 10× faster phase shifting.

Here I am testing different configurations of `TimeSeriesChunkExecutor` with different levels of 'outer' parallelism (`n_jobs`) and 'inner' parallelism (numba threads). The CRE number is the `n_jobs` setting. The algorithm alone beats the best outer parallelism alone (10× vs 4.7×). The algorithmic change breaks through a ceiling that `TimeSeriesChunkExecutor` on stock FFT can't: outer parallelism can only distribute the same FFT work across workers, not reduce the total work done. Algorithm + outer alone already reaches 48.7×; adding inner (numba default vs 1 thread) takes it from 48.7× to 54.1×, only ~10% more. Inner parallelism has diminishing returns once outer saturates the cores.

Correctness
But how accurate is this truncated 32-tap sinc interpolation? The tests in this PR measure exactly that, and on both synthetic and real Neuropixels data the difference between the truncated sinc and the infinite sinc is ~0.2%:

- < 1% on synthetic data
- < 0.5% on real NP 2.0 data
- `test_phase_shift` (FFT chunked-vs-full identity): `error_mean / rms < 0.001`
- All existing PhaseShift tests pass unchanged.
Changes

1. `PhaseShiftRecording(method="fft"|"fir", n_taps=32, output_dtype=None)`

File: `src/spikeinterface/preprocessing/phase_shift.py`

- New `method` kwarg (default `"fft"` for backward compatibility).
- `method="fir"` uses a 32-tap Kaiser-windowed sinc FIR, implemented as numba-jit kernels with `prange` over time.
- FIR margin is only `n_taps // 2` samples (16 for the 32-tap default), not the 40 ms the FFT path needs.
- `n_taps` is configurable (default 32, validated as even and ≥ 2).

int16-native input reader (always on for int16 parents)
When the parent recording's dtype is int16, the FIR path dispatches to an int16-input numba kernel that reads int16 samples directly and accumulates in float32. No explicit input cast — the promotion happens per-element inside the kernel's convolution loop, avoiding a full int16 → float32 buffer materialisation. Active automatically for any int16 parent; no opt-in required.
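A minimal sketch of what such an int16-input kernel can look like (hypothetical names and signatures, not the PR's actual code), with a pure-Python fallback so the sketch runs even without numba. The promotion to float32 happens per element inside the convolution loop, so no float copy of the input chunk is ever materialised:

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:                      # pure-Python fallback for the sketch
    njit = lambda **kw: (lambda f: f)
    prange = range

@njit(parallel=True, cache=True)
def fir_shift_int16(traces, kernels, out):
    # traces:  (n_samples, n_channels) int16, including the filter margin
    # kernels: (n_channels, n_taps) float32 per-channel delay kernels
    # out:     (n_samples - n_taps + 1, n_channels) float32
    n_taps = kernels.shape[1]
    for i in prange(out.shape[0]):       # parallel over time
        for c in range(out.shape[1]):
            acc = np.float32(0.0)
            for k in range(n_taps):
                # int16 sample promoted per element; no input buffer cast
                acc += np.float32(traces[i + k, c]) * kernels[c, k]
            out[i, c] = acc

rng = np.random.default_rng(0)
traces = rng.integers(-1000, 1000, size=(256, 4)).astype(np.int16)
kernels = rng.standard_normal((4, 32)).astype(np.float32)
out = np.empty((traces.shape[0] - 32 + 1, 4), dtype=np.float32)
fir_shift_int16(traces, kernels, out)
```

With numba present, `prange` dispatches the time loop to the process-global thread pool; without it, the same code runs (slowly) as plain Python, which is enough to see the semantics.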
`output_dtype=np.float32`: skip the output round-back

Independently, the phase-shift output stage can optionally skip its round-to-int16 cast:

- Default (`output_dtype=None`): FIR internally produces float32 samples, then rounds and casts back to the parent dtype (e.g., int16). This preserves the int16 contract for downstream stages that expect it.
- `output_dtype=np.float32`: FIR writes float32 directly, and `PhaseShiftRecording` advertises float32 as its output dtype. Downstream stages that inherit (`dtype=None`) consume float32 and skip their own int16 round-backs; the full pipeline stays in floating point.
- If downstream stages explicitly set `dtype=np.int16`, they will cast back to int16 regardless of what phase shift advertises, reinstating the round-back at their own output boundary.
- `output_dtype=np.float32` is therefore fully effective only when the caller builds a downstream chain that inherits dtype from phase shift (or explicitly sets `dtype=np.float32` on HP/CMR etc.).

2. `apply_raised_cosine_taper` extract

File: `src/spikeinterface/core/time_series_tools.py`

- `apply_raised_cosine_taper(data, margin, *, inplace=True)` exposes the raised-cosine window that was previously inlined in `get_chunk_with_margin(window_on_margin=True)`.
- `window_on_margin=True` continues to work but is deprecated: it emits a `DeprecationWarning` and delegates to `apply_raised_cosine_taper`.
- The FFT `PhaseShiftRecording` path is updated to call `get_chunk_with_margin(window_on_margin=False)` and then `apply_raised_cosine_taper` explicitly. Output is bit-for-bit equivalent to pre-refactor behavior (regression test added).
- `get_chunk_with_margin` was unusable for bounded-support filters both because the taper was redundant and because the in-place `*=` against a float taper fails on int-typed chunks. Separating the concern makes the utility method-agnostic.

Performance (reproducible)
Here are some other relevant benchmarks for this PR.
Benchmark script: `benchmarks/preprocessing/bench_perf.py`. Synthetic NumpyRecording, 1M × 384 int16, `PS → HP @ 300 Hz → CMR` pipeline, measured on a 24-core x86_64 host (SI 0.103 dev, numpy 2.1, scipy 1.14, numba 0.60).

End-to-end pipeline: direct `get_traces()` (no CRE)

Scope: the full `PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecording` pipeline, a single `get_traces()` call on the whole recording (no `TimeSeriesChunkExecutor` chunking), numba threads at their default (all cores). Measured both for this PR alone (FIR phase shift + stock BP/CMR) and combined with the companion `n_workers` PR (on a development branch with both PRs applied).

This PR alone delivers the bulk of the direct-`get_traces()` win because the stock phase-shift FFT dominated the pipeline; FIR demotes it from ~68 s to ~1 s, exposing band pass (~8 s) as the new bottleneck, which the companion PR's `n_workers` kwargs then unlock. The int16-preserved path pays ~1.5× more time than the f32-propagated one because each stage round-trips through float internally and casts back.

FIR × CRE outer parallelism
Scope: the full pipeline end-to-end, 1M × 384 int16, `chunk_duration="1s"`, stock BP/CMR (no companion-PR `n_workers` kwargs). Only the CRE outer `n_jobs` and the phase-shift `method` vary; numba threads are left at their default (all cores), which is the value FIR uses for its internal `prange`.

"int16 preserved" rows use the int16-throughout pipeline (each stage rounds back to int16 at its output boundary); the int16-reading FIR kernel is active internally when the input is int16, but the final cast reinstates the int16 contract. "f32 throughout" rows flip every stage's `dtype` to `np.float32` so intermediate buffers stay in floating point.

FIR stacks cleanly with `TimeSeriesChunkExecutor`. Users already running at `n_jobs ≈ core_count` on stock get ~3× more speedup from the algorithmic change alone (4.98 s → 1.62 s).

Peak RSS scaling (chunk=1s, 1M × 384 int16, thread engine)
Scope: the full pipeline end-to-end. Numba threads follow the Compatibility section's recommendation (numba default on low-`n_jobs` configs where cores are free; `NUMBA_NUM_THREADS=1` when `n_jobs ≈ core_count` to avoid oversubscription).

At `chunk=10s, n_jobs=24`: stock 13.75 GB, FIR 7.13 GB (1.93× less). The gap widens with `n_jobs` because FIR's numba thread pool is allocated once process-wide, while stock's scipy FFT scratch buffers are per-call and aren't shared across worker threads.

get_traces() with entire filter chain
Scope: the full `PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecording` pipeline end-to-end, 1M × 384 int16.

This PR alone (FIR phase shift, stock BP/CMR):

- `get_traces()`, int16 preserved
- `get_traces()`, f32 propagated
- `TimeSeriesChunkExecutor` (n_jobs=24, thread), int16 preserved

Combined with the companion PR (adds `n_workers` on BP and CMR):

- `get_traces()`, int16 preserved
- `get_traces()`, f32 propagated
- `TimeSeriesChunkExecutor` (n_jobs=24, thread), int16 preserved

Combined numbers require both PRs merged. FIR also uses ~2.8× less peak RSS than FFT at the same parallelism (at `n_jobs=24, chunk=1s`: 10.8 GB → 3.9 GB) because the numba thread pool is shared across workers while scipy's FFT scratch buffers aren't.

Compatibility
- `method="fft"` is the default; existing callers get existing behavior bit-for-bit.
- `get_chunk_with_margin(window_on_margin=True)` still works and emits a `DeprecationWarning` pointing callers at `apply_raised_cosine_taper`.
- The `_kwargs` dict is updated for the new phase-shift kwargs; `save()`/`load()` round-trip correctly.
- `numba` is already a soft dep of SI's Kilosort path; the FIR kernels import it lazily and raise a clear error with install instructions if missing.
- No `n_workers` kwarg on phase shift. FIR parallelism is internal to the numba kernel (`prange` over time), dispatched to numba's process-global thread pool. Tune via the `NUMBA_NUM_THREADS` env var or `numba.set_num_threads()`, the standard numba mechanism. Suggested settings: numba default (all cores) for direct `get_traces()` callers and for `n_jobs=1` under CRE; `NUMBA_NUM_THREADS=1` when `n_jobs ≈ core_count` to avoid oversubscription. Not set at library level (follows the scipy/sklearn convention).

Companion PR
An independent companion PR #4564 adds `n_workers` kwargs on `FilterRecording` and `CommonReferenceRecording` with per-caller-thread inner pools. It is most valuable for direct `get_traces()` callers, where the BP/CMR parallelism compounds on top of this PR's FIR to reach 13× (int16) / 19.5× (f32) pipeline speedup. Under CRE the companion kwargs add a smaller additional gain without causing shared-pool queueing. The two PRs have no code dependency and can land in either order.