
FIR phase-shift + extract apply_raised_cosine_taper from get_chunk_with_margin#4563

Open
galenlynch wants to merge 2 commits into SpikeInterface:main from galenlynch:perf/phase-shift-fir

Conversation


@galenlynch galenlynch commented Apr 24, 2026

This PR introduces a faster time-domain approach to phase shifting Neuropixels recordings, which can be opted into alongside the existing FFT approach. The time-domain approach achieves >99% accuracy in a fraction of the time.

Attempting to view my data as Kilosort4 would see it, using direct calls to get_traces outside of TimeSeriesChunkExecutor, was very slow. Profiling showed that the vast majority of the time (98%) was spent phase shifting, at least on a 384x1M chunk of data. This is perhaps an unrealistic chunk length, but even with more standard ~1 second chunks (384x30k), the vast majority of the preprocessing time was spent phase shifting. This led me to investigate why it was so slow.

The current approach to phase shifting takes an rFFT of the entire recording chunk, multiplies the Fourier-transformed data by a complex exponential, and then converts back into the time domain with an irFFT. This is correct, but it scales as O(n log n), requires padding to avoid wrap-around, and is very expensive per sample. However, since the signal is band-limited, we can instead use Whittaker–Shannon interpolation and convolve the time-domain signal rather than multiplying the entire frequency-domain signal.

A fractional-sample delay is, by definition, a sinc interpolation at the desired offset. Whittaker–Shannon states that any signal bandlimited to Nyquist is exactly reconstructed from its samples via

$$x(t) = \sum_n x[n] \cdot \mathrm{sinc}(t - n)$$

so delaying channel $c$ by a fractional $d_c$ samples is

$$y_c[n] = \sum_k x_c[n-k] \cdot \mathrm{sinc}(k - d_c)$$

— convolution with the ideal fractional-delay kernel $h_{d_c}[k] = \mathrm{sinc}(k - d_c)$.

The existing FFT path realises this convolution spectrally: multiplying by $e^{i\,2\pi f\,d_c}$ in the frequency domain is the DFT of that same infinite sinc kernel.

We can instead do a time-domain convolution with a finite impulse response (FIR) sinc filter. Truncating the sinc kernel allows massive performance gains while sacrificing a minimal amount of accuracy. In this PR, the FIR path realises the phase shift in the time domain as an explicit linear convolution against a Kaiser-windowed, 32-tap truncation of the same sinc. The operation is otherwise identical; the only approximations are:

  1. Sinc truncation. For bandlimited input the sinc tails decay quickly, so 32 taps capture > 99% of the kernel energy for any $d \in [0, 1)$. Longer kernels trade compute for accuracy: 16 taps ≈ 0.8% RMS, 32 taps ≈ 0.19% RMS, 64 taps < 0.05% RMS vs the FFT reference on real NP 2.0 data.
  2. Kaiser windowing (β = 8.6). Rectangular truncation would convolve the frequency response with a Dirichlet kernel → Gibbs-phenomenon ripples. Kaiser trades a small main-lobe widening for ≈ −80 dB stopband attenuation — two orders of magnitude below the ≈ 50 dB SNR of a 12-bit acquisition system. Windowing error is physically unmeasurable in ephys.
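To make the construction concrete, here is a minimal sketch of a Kaiser-windowed truncated-sinc fractional-delay kernel. This is not the PR's actual code: the helper name, the DC-gain normalisation, and the test tone are illustrative assumptions.

```python
import numpy as np

def fractional_delay_kernel(d, n_taps=32, beta=8.6):
    # Hypothetical helper mirroring the kernel described above.
    # Tap offsets k span [-n_taps//2, n_taps//2), so the full convolution is
    # delayed by n_taps//2 + d samples before trimming.
    k = np.arange(n_taps) - n_taps // 2
    h = np.sinc(k - d)              # truncated ideal fractional-delay sinc
    h *= np.kaiser(n_taps, beta)    # Kaiser window tames truncation ripple
    return h / h.sum()              # normalise DC gain (exact scheme may differ)

# Delay a bandlimited tone by 0.3 samples and compare to the analytic shift
fs, f0, d = 30_000.0, 1_000.0, 0.3
t = np.arange(4096) / fs
x = np.sin(2 * np.pi * f0 * t)
h = fractional_delay_kernel(d)
center = 16  # n_taps // 2
y = np.convolve(x, h)[center : center + t.size]
ref = np.sin(2 * np.pi * f0 * (t - d / fs))
core = slice(64, -64)  # skip edge transients of the finite chunk
rms_err = np.sqrt(np.mean((y[core] - ref[core]) ** 2) / np.mean(ref[core] ** 2))
```

For signals well inside the passband, the relative RMS error of this 32-tap kernel is well under 1%, consistent with the figures quoted above.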

For the 384x1M sample case, changing the interpolation to a FIR filter instead of multiplying the FFT sped up the entire PhaseShiftRecording → HighpassFilterRecording(300 Hz) → CommonReferenceRecording preprocessing pipeline by 5.4x. When combined with the companion PR #4564, this increased to 13.1x speedup. Notably, FIR interpolation uses ~2.8x less peak memory than the FFT approach. However, these benchmarks take advantage of numba-level parallelism that I added to the FIR approach.

Isolating only the PhaseShiftRecording component shows that the algorithmic improvement alone produces 10x faster phase shifting. Here I am testing different configurations of TimeSeriesChunkExecutor with different levels of 'outer' parallelism (n_jobs) and 'inner' parallelism (numba threads); the CRE number in the table is the n_jobs setting.

| Config | Time | vs baseline | Parallelism axis |
|---|---|---|---|
| FFT, CRE n=1 (baseline) | 15.59 s | 1.00× | |
| FFT, CRE n=8 thread | 3.33 s | 4.68× | outer only |
| FIR, CRE n=1, numba 1-thread | 1.54 s | 10.09× | algorithm alone |
| FIR, CRE n=1, numba default | 0.73 s | 21.25× | algo + inner |
| FIR, CRE n=8 thread, numba 1-thread | 0.32 s | 48.72× | algo + outer |
| FIR, CRE n=8 thread, numba default | 0.29 s | 54.15× | algo + inner + outer |

Algorithm alone beats best-outer-parallelism-alone (10× vs 4.7×). The algorithmic change breaks through a ceiling that TimeSeriesChunkExecutor on stock FFT can't — outer parallelism can only distribute the same FFT work across workers, not change the total work done. Algorithm + outer alone already reaches 48.7×; adding inner (numba default vs 1 thread) takes it from 48.7× to 54.1× — only ~10% more. Inner parallelism has diminishing returns once outer saturates cores.

Correctness

But how accurate is this truncated 32-tap sinc interpolation? The tests in this PR measure exactly that: on both synthetic and real Neuropixels data, the difference between the truncated sinc and the infinite sinc is ~0.2%.

| Path | Check | Result |
|---|---|---|
| FIR vs FFT | signal-band RMS < 1% on synthetic | Pass at ~0.2% |
| FIR vs FFT | spike-band RMS < 0.5% on real NP 2.0 data | ~0.19% measured |
| Existing test_phase_shift (FFT chunked-vs-full identity) | error_mean / rms < 0.001 | Pass (regression guard on taper refactor) |

All existing PhaseShift tests pass unchanged.

Changes

1. PhaseShiftRecording(method="fft"|"fir", n_taps=32, output_dtype=None)

File: src/spikeinterface/preprocessing/phase_shift.py

  • New method kwarg (default "fft" for backward compatibility).
  • method="fir" uses a 32-tap Kaiser-windowed sinc FIR, implemented as numba-jit kernels with prange over time.
  • Per-channel kernels are precomputed once per segment (sample shifts are fixed for the recording's lifetime); by contrast, the FFT path recomputes an effective kernel per chunk.
  • FIR margin is n_taps // 2 samples (16 for the 32-tap default), not the 40 ms the FFT path needs.
  • n_taps configurable (default 32, validated as even and ≥ 2).
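The precompute-once design can be sketched as follows. This is an illustration, not the PR's implementation: function names are hypothetical, the channel loop stands in for the numba prange kernel, and the per-channel DC normalisation is an assumption.

```python
import numpy as np

def build_kernel_bank(sample_shifts, n_taps=32, beta=8.6):
    # One kernel per channel; shifts are fixed for the recording's lifetime,
    # so this runs once per segment rather than once per chunk.
    k = np.arange(n_taps) - n_taps // 2
    bank = np.sinc(k[None, :] - np.asarray(sample_shifts)[:, None])
    bank *= np.kaiser(n_taps, beta)[None, :]
    return bank / bank.sum(axis=1, keepdims=True)  # unit DC gain per channel

def fir_phase_shift(traces, kernels):
    # Margin needed on each side is n_taps // 2 samples (16 for the default).
    n_samples, n_channels = traces.shape
    center = kernels.shape[1] // 2
    out = np.empty_like(traces, dtype=np.float32)
    for c in range(n_channels):  # the PR parallelises with numba prange instead
        out[:, c] = np.convolve(traces[:, c], kernels[c])[center : center + n_samples]
    return out

shifts = np.linspace(0.0, 1.0, 4, endpoint=False)  # 4 channels, fractional shifts
kernels = build_kernel_bank(shifts)
traces = np.random.default_rng(0).standard_normal((1000, 4)).astype(np.float32)
shifted = fir_phase_shift(traces, kernels)
```

A channel with zero shift reduces to a unit-impulse kernel, so its trace passes through unchanged.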

int16-native input reader (always on for int16 parents)

When the parent recording's dtype is int16, the FIR path dispatches to an int16-input numba kernel that reads int16 samples directly and accumulates in float32. No explicit input cast — the promotion happens per-element inside the kernel's convolution loop, avoiding a full int16 → float32 buffer materialisation. Active automatically for any int16 parent; no opt-in required.
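A pure-NumPy stand-in for that int16-input kernel, assuming the behavior described above (the real code is a numba kernel; this vectorised sketch still creates per-tap temporaries, which the numba version avoids):

```python
import numpy as np

def fir_int16_to_float32(trace_i16, kernel_f32):
    # int16 samples are promoted per tap during accumulation, so no full
    # float32 copy of the input buffer is materialised up front.
    n = trace_i16.size
    n_taps = kernel_f32.size
    center = n_taps // 2
    out = np.zeros(n, dtype=np.float32)
    for j in range(n_taps):
        w = np.float32(kernel_f32[j])
        shift = center - j           # out[i] += w * x[i + shift]
        if shift >= 0:
            out[: n - shift] += w * trace_i16[shift:]
        else:
            out[-shift:] += w * trace_i16[: n + shift]
    return out

rng = np.random.default_rng(1)
x = rng.integers(-2000, 2000, size=512, dtype=np.int16)
k = np.arange(32) - 16
kernel = (np.sinc(k - 0.3) * np.kaiser(32, 8.6)).astype(np.float32)
y = fir_int16_to_float32(x, kernel)
# Reference: explicit cast then convolve (the buffer the fast path avoids)
ref = np.convolve(x.astype(np.float32), kernel)[16 : 16 + x.size]
```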

output_dtype=np.float32 — skip the output round-back

Independently, the phase shift output stage can optionally skip its round-to-int16 cast:

  • Default (output_dtype=None): FIR internally produces float32 samples, then rounds + casts back to the parent dtype (e.g., int16). Preserves the int16 contract for downstream stages that expect it.
  • With output_dtype=np.float32: FIR writes float32 directly, and PhaseShiftRecording advertises float32 as its output dtype. Downstream stages that inherit (dtype=None) consume float32 and skip their own int16 round-backs; the full pipeline stays in floating-point.
  • Caveat: if downstream stages set explicit dtype=np.int16, they will cast back to int16 regardless of what phase shift advertises, reinstating the round-back at their own output boundary. output_dtype=np.float32 is fully effective only when the caller builds a downstream chain that inherits dtype from phase shift (or explicitly sets dtype=np.float32 on HP/CMR etc.).
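In sketch form, the two output modes described above amount to (illustrative values, not the PR's code):

```python
import numpy as np

fir_internal = np.array([[12.6, -3.2], [7.4, 812.0]], dtype=np.float32)

# output_dtype=None (default): round and cast back to the parent dtype,
# preserving the int16 contract for downstream stages.
as_int16 = np.round(fir_internal).astype(np.int16)

# output_dtype=np.float32: skip the round-back; downstream stages that
# inherit dtype=None consume float32 directly.
as_float32 = fir_internal
```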

2. apply_raised_cosine_taper extract

File: src/spikeinterface/core/time_series_tools.py

  • New public function apply_raised_cosine_taper(data, margin, *, inplace=True) exposes the raised-cosine window that was previously inlined in get_chunk_with_margin(window_on_margin=True).
  • window_on_margin=True continues to work but is deprecated: it emits a DeprecationWarning and delegates to apply_raised_cosine_taper.
  • The FFT-based PhaseShiftRecording path is updated to call get_chunk_with_margin(window_on_margin=False) and then apply_raised_cosine_taper explicitly. Output is bit-for-bit equivalent to pre-refactor behavior (regression test added).
  • Rationale: the taper is FFT-specific (suppresses spectral leakage from zero-padded boundaries). Before this refactor, get_chunk_with_margin was unusable for bounded-support filters both because the taper was redundant and because the in-place *= against a float taper fails on int-typed chunks. Separating the concern makes the utility method-agnostic.
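A minimal sketch matching the signature described above (the exact window shape used by SpikeInterface may differ in detail):

```python
import numpy as np

def apply_raised_cosine_taper(data, margin, *, inplace=True):
    # Raised-cosine (half-cosine) taper on the first and last `margin`
    # samples. Requires float data: the in-place *= against a float taper
    # fails on int-typed chunks, as the rationale above notes.
    if not inplace:
        data = data.copy()
    ramp = (1.0 - np.cos(np.linspace(0.0, np.pi, margin))) / 2.0  # 0 -> 1
    shape = (margin,) + (1,) * (data.ndim - 1)  # broadcast over channels
    data[:margin] *= ramp.reshape(shape)
    data[-margin:] *= ramp[::-1].reshape(shape)
    return data

chunk = np.ones((100, 4), dtype=np.float32)
tapered = apply_raised_cosine_taper(chunk, margin=10, inplace=False)
```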

Performance (reproducible)

Here are some other relevant benchmarks for this PR.

benchmarks/preprocessing/bench_perf.py — synthetic NumpyRecording, 1M × 384 int16, PS → HP @ 300 Hz → CMR pipeline, measured on a 24-core x86_64 host (SI 0.103 dev, numpy 2.1, scipy 1.14, numba 0.60).

End-to-end pipeline — direct get_traces() (no CRE)

Scope: full PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecording pipeline, single get_traces() call on the whole recording (no TimeSeriesChunkExecutor chunking). Numba threads at default (all cores).

This PR alone (FIR phase shift + stock BP/CMR):

| Pipeline | Stock (FFT, serial) | FIR | Speedup |
|---|---|---|---|
| int16 preserved | 83.64 s | 15.51 s | 5.39× |
| f32 propagated | 87.15 s | 19.57 s | 4.45× |

Combined with companion n_workers PR (measured on development branch with both PRs applied):

| Pipeline | Stock | FIR + parallel BP/CMR | Speedup |
|---|---|---|---|
| int16 preserved | 82.14 s | 6.28 s | 13.09× |
| f32 propagated | 85.77 s | 4.40 s | 19.48× |

This PR alone delivers the bulk of the direct-get_traces() win because stock phase shift FFT dominated the pipeline; FIR demotes it from ~68 s to ~1 s, exposing band pass (~8 s) as the new bottleneck, which the companion PR's n_workers kwargs then unlock. The int16-preserved path pays ~1.5× more time than f32-propagated because each stage round-trips through float internally and casts back.

FIR × CRE outer parallelism

Scope: full PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecording pipeline end-to-end, 1M × 384 int16, chunk_duration="1s", stock BP/CMR (no companion-PR n_workers kwargs). Only CRE outer n_jobs and phase shift method vary; numba threads are left at their default (all cores), which is the value FIR uses for its internal prange.

| Config | Time | Speedup |
|---|---|---|
| CRE n=1, stock (int16 preserved) | 32.41 s | 1.00× |
| CRE n=1, FIR (int16 preserved) | 5.81 s | 5.58× |
| CRE n=8 thread, stock (int16 preserved) | 5.73 s | 5.66× |
| CRE n=8 thread, FIR (int16 preserved) | 1.81 s | 17.88× |
| CRE n=24 thread, stock (int16 preserved) | 4.98 s | 6.51× |
| CRE n=24 thread, FIR (int16 preserved) | 1.62 s | 19.96× |
| CRE n=8 thread, FIR (f32 throughout) | 1.59 s | 20.39× |
| CRE n=24 thread, FIR (f32 throughout) | 1.57 s | 20.60× |

"int16 preserved" rows use the int16-throughout pipeline (each stage rounds back to int16 at its output boundary); the int16-reading FIR kernel is active internally when input is int16, but the final cast reinstates the int16 contract. "f32 throughout" rows flip every stage's dtype to np.float32 so intermediate buffers stay in floating-point.

FIR stacks cleanly with TimeSeriesChunkExecutor. Users already running at n_jobs ≈ core_count on stock get ~3× more speedup from the algorithmic change alone (4.98s → 1.62s).

Peak RSS scaling (chunk=1s, 1M × 384 int16, thread engine)

Scope: full PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecording pipeline end-to-end. Numba threads follow the Compatibility section's recommendation (numba default on low-n_jobs configs where cores are free; NUMBA_NUM_THREADS=1 when n_jobs ≈ core_count to avoid oversubscription):

| Config | Numba threads | Peak RSS (Δ over baseline) | Notes |
|---|---|---|---|
| CRE n=1, stock | (FFT doesn't use numba) | 0.48 GB | |
| CRE n=1, FIR | 8 | 0.49 GB | ~same as stock at n=1 |
| CRE n=4, stock | | 1.84 GB | |
| CRE n=8, stock | | 3.64 GB | |
| CRE n=24, stock | | 10.80 GB | linear: 0.45 GB × workers |
| CRE n=24, FIR | 1 | 3.89 GB | sub-linear: 2.78× less than stock |

At chunk=10s, n_jobs=24: stock 13.75 GB, FIR 7.13 GB (1.93× less). The gap widens with n_jobs because FIR's numba thread pool is allocated once process-wide; stock's scipy FFT scratch buffers are per-call and don't share across worker threads.

get_traces() with entire filter chain

Scope: full PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecording pipeline end-to-end, 1M × 384 int16

This PR alone (FIR phase-shift, stock BP/CMR)

| Scenario | Stock | FIR (this PR) | Speedup |
|---|---|---|---|
| Direct get_traces(), int16 preserved | 83.6 s | 15.5 s | 5.4× |
| Direct get_traces(), f32 propagated | 87.2 s | 19.6 s | 4.5× |
| TimeSeriesChunkExecutor(n_jobs=24, thread), int16 preserved | 4.98 s | 1.62 s | 3.1× on top of existing CRE |

Combined with companion PR (adds n_workers on BP and CMR)

| Scenario | Stock | Combined | Speedup |
|---|---|---|---|
| Direct get_traces(), int16 preserved | 82.1 s | 6.28 s | 13.1× |
| Direct get_traces(), f32 propagated | 85.8 s | 4.40 s | 19.5× |
| TimeSeriesChunkExecutor(n_jobs=24, thread), int16 preserved | 4.98 s | 1.57 s | 3.2× on top of existing CRE |

Combined numbers require both PRs merged. FIR also uses ~2.8× less peak RSS than FFT at the same parallelism (at n_jobs=24, chunk=1s: 10.8 GB → 3.9 GB) because the numba thread pool is shared across workers while scipy's FFT scratch buffers aren't.

Compatibility

  • No default behavior changes. method="fft" is the default; existing callers get existing behavior bit-for-bit.
  • Deprecation only. get_chunk_with_margin(window_on_margin=True) still works and emits a DeprecationWarning pointing callers at apply_raised_cosine_taper.
  • Round-trip dumpability. _kwargs dict updated for new phase shift kwargs; save() / load() round-trip correctly.
  • No new required deps. numba is already a soft dep of SI's Kilosort path; the FIR kernels import it lazily and raise a clear error with install instructions if missing.
  • No n_workers kwarg on phase shift. FIR parallelism is internal to the numba kernel (prange over time), dispatched to numba's process-global thread pool. Tune via NUMBA_NUM_THREADS env var or numba.set_num_threads() — the standard numba mechanism. Suggested settings: numba default (all cores) for direct get_traces() callers and for n_jobs=1 under CRE; NUMBA_NUM_THREADS=1 when n_jobs ≈ core_count to avoid oversubscription. Not set at library level (follows scipy/sklearn convention).
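The suggested thread settings above translate to the standard numba mechanisms, e.g. (illustrative configuration, not shipped by this PR):

```shell
# When n_jobs ≈ core_count under TimeSeriesChunkExecutor, pin the FIR
# kernel's inner pool to one thread to avoid oversubscription:
export NUMBA_NUM_THREADS=1
# (equivalently, from Python before first kernel call:
#  import numba; numba.set_num_threads(1))
```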

Companion PR

An independent companion PR #4564 adds n_workers kwargs on FilterRecording and CommonReferenceRecording with per-caller-thread inner pools. Most valuable for direct get_traces() callers, where the BP/CMR parallelism compounds on top of this PR's FIR to reach 13× (int16) / 19.5× (f32) pipeline speedup. Under CRE the companion kwargs add a smaller additional gain without causing shared-pool queueing. The two PRs have no code dependency and can land in either order.

@galenlynch
Contributor Author

Test failures are unrelated to this PR

@alejoe91 alejoe91 added the preprocessing Related to preprocessing module label May 4, 2026
@alejoe91
Member

alejoe91 commented May 4, 2026

@samuelgarcia @oliche can you take a look at this implementation?

galenlynch and others added 2 commits May 9, 2026 08:19
Adds a sinc-FIR alternative to the FFT-based PhaseShift path, and factors
the FFT-specific raised-cosine taper out of get_chunk_with_margin so
bounded-support FIR consumers don't have to pay for it.

Changes
-------
- PhaseShiftRecording(method="fft"|"fir", n_taps=32, output_dtype=None)
  - method="fir": 32-tap Kaiser-windowed sinc, numba-jit with prange over time
  - Per-channel kernels cached once per segment
  - Int16-native fast path: reads int16 directly, writes float32 when
    output_dtype=np.float32 (skips round-back-to-int16)
  - FIR margin = n_taps // 2 (16 for 32-tap default), vs FFT path's 40 ms
  - Default "fft" preserves existing behavior
- apply_raised_cosine_taper(data, margin, *, inplace=True) public function
  in spikeinterface.core.time_series_tools
- get_chunk_with_margin(window_on_margin=True) emits DeprecationWarning
  and delegates to apply_raised_cosine_taper; old behavior preserved

Whittaker-Shannon justification
-------------------------------
Both FFT and FIR paths are numerical realisations of the same ideal
sinc-interpolation implied by the sampling theorem.  FFT does it spectrally
(phase ramp = DFT of full sinc kernel); FIR does it in time against a
Kaiser-windowed 32-tap truncation of the same sinc.  Approximations: sinc
truncation (32 taps captures >99% of kernel energy for any d ∈ [0,1)) and
Kaiser windowing (-80 dB stopband, two orders of magnitude below NP's
~50 dB SNR).  Measured 0.19% spike-band RMS vs FFT on real NP 2.0 data.

Performance (1M × 384 int16 AIND pipeline, PS → HP → CMR, 24-core host)
-----------------------------------------------------------------------
                                 stock FFT   FIR         speedup
  single-call get_traces         82.1 s      6.28 s      13.1×
  same, f32 propagated           85.8 s      4.40 s      19.5×

  TimeSeriesChunkExecutor n=24   4.98 s      1.62 s      3.1× on top
                                                        of CRE n=24

Memory (same pipeline, n_jobs=24, chunk=1s):
  FFT peak RSS: 10.8 GB       FIR peak RSS: 3.9 GB (2.8× less)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@galenlynch galenlynch force-pushed the perf/phase-shift-fir branch from b3f3784 to 301355a on May 9, 2026 21:04