Skip to content

feat: native i0 (modified Bessel function) and kaiser window#3156

Open
Vlor999 wants to merge 1 commit intoml-explore:mainfrom
Vlor999:feat/add-io-function
Open

feat: native i0 (modified Bessel function) and kaiser window#3156
Vlor999 wants to merge 1 commit intoml-explore:mainfrom
Vlor999:feat/add-io-function

Conversation

@Vlor999
Copy link
Contributor

@Vlor999 Vlor999 commented Feb 22, 2026

Description

This PR introduces native hardware-accelerated support for the modified Bessel function of the first kind, order zero (mlx.core.i0), and leverages it to implement the Kaiser window function (mlx.core.kaiser).

Implementing i0 natively allows kaiser to be fully compiled and evaluated on Apple Silicon GPUs without relying on slow CPU fallback loops or external libraries.

Implementation Details

1. Hardware-Accelerated i0 (Cephes Approximation)

Since $I_0(x)$ cannot be computed via simple composition, it is implemented as a core unary primitive using the industry-standard Cephes polynomial approximation (matching NumPy/SciPy):

  • CPU Backend (mlx/backend/cpu/simd/math.h): Implemented vectorized evaluation using Simd<T, N> and FMA instructions for high-throughput CPU processing.
  • GPU Backend (mlx/backend/metal/kernels/i0.h): Implemented a native Metal shader using metal::fma and metal::precise::exp/sqrt to ensure numerical stability matching the CPU.
  • Autograd / VJP (mlx/primitives.cpp): Implemented the analytic derivative. Since $\frac{d}{dx} I_0(x) = I_1(x)$, I implemented a composite $I_1(x)$ function directly in the jvp pass using the same two-domain Cephes approximation logic.

2. Kaiser Window (mlx.core.kaiser)

  • Implemented in mlx/ops.cpp as a composite function using the new i0 primitive.
  • Formula: $w(n) = I_0\left(\beta \sqrt{1 - \left(\frac{2n}{M-1} - 1\right)^2}\right) / I_0(\beta)$
  • Optimization: Uses scalar broadcasting for 1.0f to avoid unnecessary array allocations in the sqrt term. Handles edge cases (M=0, M=1).

3. Python Bindings

  • Exposed i0 and kaiser in python/src/ops.cpp with fully typed nb::sig and LaTeX docstrings.

Test Plan

Added comprehensive tests in python/tests/test_ops.py:

  1. NumPy Parity (test_i0): Checked mx.i0 against numpy.i0 across both polynomial domains (|x| <= 3.75 and |x| > 3.75).
  2. Hardware Parity (test_i0_cpu_gpu_parity): Verified that forcing mx.stream(mx.cpu) and mx.stream(mx.gpu) yields identical results within 1e-5 tolerance.
  3. Kaiser Parity (test_kaiser_general): Verified mx.kaiser(M, beta) against np.kaiser including edge cases (M=1) and symmetry.

All tests pass successfully.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@Vlor999 Vlor999 force-pushed the feat/add-io-function branch from a09f9d0 to 11b9b39 Compare February 24, 2026 20:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant