
ggml: ARM NEON dequant kernel for turbo4 (vqtbl4q_u8 4-bit PolarQuant) #16

Open
WillowOneVision wants to merge 1 commit into AtomicBot-ai:feature/turboquant-kv-cache from WillowOneVision:cecil/neon-turbo4-dequant


Summary

First aarch64 NEON SIMD implementation of the TurboQuant turbo4_0 dequantization kernel. Reference implementations already cover Metal (Apple Silicon), CUDA (NVIDIA), Vulkan (cross-GPU) and a portable scalar C path. This adds the ARM NEON path; NEON is the dominant SIMD instruction set on ARM hardware such as Raspberry Pi 4/5, Apple M-series, AWS Graviton and Cortex-X Android devices.

Algorithm

For each block (QK_TURBO4 = 128 elements):

  1. Load norm fp16 → fp32, broadcast to a float32x4_t.
  2. Multiply the 16-entry CENTROIDS_4BIT table by norm → 16 pre-scaled fp32 values (64 bytes) held in 4× uint8x16_t (= uint8x16x4_t).
  3. Per inner iteration (16 iters per block, 8 elements each):
    • Load 4 packed bytes (8 nibbles) from qs.
    • Extract low/high nibbles via vand_u8 + vshr_n_u8, interleave via vzip_u8.
    • Multiply nibbles by 4 (= LUT byte offset) via vshl_n_u8.
    • Build each full 16-byte index vector (one per four output floats) by broadcasting the scaled nibbles to four lanes each via vqtbl1q_u8 against {0,0,0,0,1,1,1,1,…} and adding the per-byte offsets {0,1,2,3,0,1,2,3,…}.
    • Apply vqtbl4q_u8(lut, idx) — SIMD lookup into the 64-byte pre-scaled LUT.
    • Store as two float32x4_t per iteration.

Total per block: ~36 vector ops vs ~768 scalar ops (6 ops per element: load, shift, mask, scalar gather, scalar mul, store; × 128 elements).
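The index construction in step 3 is easier to follow in scalar form. The sketch below is portable C, not the NEON kernel from this PR: `tbl64` is a hypothetical stand-in for the byte-lookup behavior of vqtbl4q_u8, `CENTROIDS_DEMO` holds placeholder values rather than the real CENTROIDS_4BIT table, and the low-nibble-first element order is an assumption.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define QK_TURBO4 128  /* elements per block, as stated above */

/* Placeholder centroids: NOT the real CENTROIDS_4BIT values. */
static const float CENTROIDS_DEMO[16] = {
    -2.0f, -1.5f, -1.1f, -0.8f, -0.6f, -0.4f, -0.2f, -0.05f,
     0.05f, 0.2f,  0.4f,  0.6f,  0.8f,  1.1f,  1.5f,  2.0f
};

/* Scalar stand-in for vqtbl4q_u8: out[i] = lut[idx[i]], or 0 if idx[i] >= 64. */
static void tbl64(const uint8_t lut[64], const uint8_t idx[16], uint8_t out[16]) {
    for (int i = 0; i < 16; i++) {
        out[i] = idx[i] < 64 ? lut[idx[i]] : 0;
    }
}

/* Dequantize one block: qs holds 64 packed bytes, y receives 128 floats. */
static void dequant_block_lut(const uint8_t *qs, float norm, float *y) {
    assert(qs != NULL && y != NULL);

    /* Step 2: pre-scale the centroids once, then view them as 64 raw bytes. */
    float scaled[16];
    uint8_t lut[64];
    for (int i = 0; i < 16; i++) scaled[i] = CENTROIDS_DEMO[i] * norm;
    memcpy(lut, scaled, sizeof lut);

    /* Step 3: 16 iterations, 4 packed bytes -> 8 floats each. */
    for (int it = 0; it < 16; it++) {
        uint8_t nib[8];
        for (int b = 0; b < 4; b++) {
            uint8_t byte = qs[it * 4 + b];
            nib[2 * b]     = byte & 0x0F;  /* low nibble first (assumed order) */
            nib[2 * b + 1] = byte >> 4;
        }
        /* Two 16-byte lookups per iteration, each yielding 4 floats. */
        for (int half = 0; half < 2; half++) {
            uint8_t idx[16], out[16];
            for (int i = 0; i < 16; i++) {
                /* nibble * 4 selects the float; i % 4 selects the byte within it */
                idx[i] = (uint8_t)(nib[half * 4 + i / 4] * 4 + (i % 4));
            }
            tbl64(lut, idx, out);
            memcpy(&y[it * 8 + half * 4], out, 16);
        }
    }
}
```

Because the lookup only moves bytes, each output float is exactly one of the 16 pre-scaled products, which is the property the bit-exactness argument relies on.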

Bit-exact guarantee

Pre-scaling the LUT (centroid * norm once per LUT entry) instead of per element (centroid * norm per dequant) produces the same fp32 product. IEEE 754 fp32 multiplication is deterministic; the multiplication operands are bit-identical between scalar and NEON paths. vqtbl4q_u8 is a pure byte permutation (no arithmetic) so the lookup itself is also bit-preserving.

A standalone validator (neon-turbo4-poc/validate.c) confirms this empirically:

N = 10,000 random blocks × 128 elements = 1,280,000 fp32 values, 0 bit-mismatches vs scalar reference.

Build: gcc -O3 -march=armv8.2-a+fp16 -ffp-contract=off (FMA contraction disabled to prevent compiler-introduced precision divergence between paths).
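For readers who want to reproduce the idea of the check, here is a minimal sketch of a bit-level comparison between the per-element product and the pre-scaled LUT. It is not neon-turbo4-poc/validate.c: the centroid values, the LCG, the norm range and all function names are illustrative assumptions.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Placeholder centroids: NOT the real CENTROIDS_4BIT values. */
static const float CENTROIDS_DEMO[16] = {
    -2.0f, -1.5f, -1.1f, -0.8f, -0.6f, -0.4f, -0.2f, -0.05f,
     0.05f, 0.2f,  0.4f,  0.6f,  0.8f,  1.1f,  1.5f,  2.0f
};

/* Reinterpret a float's bits without violating aliasing rules. */
static uint32_t f32_bits(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof u);
    return u;
}

/* Small LCG so a run is reproducible without depending on rand(). */
static uint32_t lcg_next(uint32_t *state) {
    *state = *state * 1664525u + 1013904223u;
    return *state;
}

/* Counts bit-level mismatches over n_blocks random blocks of 128 elements. */
static long count_bit_mismatches(int n_blocks, uint32_t seed) {
    assert(n_blocks >= 0);
    long mismatches = 0;
    for (int b = 0; b < n_blocks; b++) {
        /* Random norm in (0, 4]; the range is an arbitrary choice for this sketch. */
        float norm = (float)((lcg_next(&seed) >> 8) + 1) / 4194304.0f;

        float lut[16];  /* pre-scaled path: one multiply per LUT entry */
        for (int i = 0; i < 16; i++) lut[i] = CENTROIDS_DEMO[i] * norm;

        for (int e = 0; e < 128; e++) {
            unsigned n = (lcg_next(&seed) >> 28) & 0xF;  /* random 4-bit index */
            float ref = CENTROIDS_DEMO[n] * norm;        /* per-element path */
            if (f32_bits(ref) != f32_bits(lut[n])) mismatches++;
        }
    }
    return mismatches;
}
```

Calling `count_bit_mismatches(10000, seed)` mirrors the 10,000-block run size. Comparing raw bit patterns rather than using `==` also distinguishes +0.0 from -0.0 and treats identical NaN encodings as equal.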

Performance

Microbench (standalone kernel, Cortex-A76 / Raspberry Pi 5, 16 GB)

| Working set | Scalar ns/block | NEON ns/block | Scalar GB/s out | NEON GB/s out | Speedup |
|---|---|---|---|---|---|
| 128 blocks (8.7 KB) | 171.37 | 80.17 | 2.99 | 6.39 | 2.14× |
| 512 blocks (35 KB) | 164.21 | 83.23 | 3.12 | 6.15 | 1.97× |
| 2048 blocks (140 KB) | 163.76 | 80.87 | 3.13 | 6.33 | 2.03× |
| 4096 blocks (280 KB) | 161.97 | 80.92 | 3.16 | 6.33 | 2.00× |
| 16384 blocks (1.1 MB) | 161.38 | 82.00 | 3.17 | 6.24 | 1.97× |
| 65536 blocks (4.4 MB) | 160.80 | 84.88 | 3.18 | 6.03 | 1.89× |

Robust 1.89-2.14× speedup across L1/L2/DRAM working sets. NEON path is memory-bound near 6 GB/s output bandwidth (Cortex-A76 LPDDR4X-4267 single-thread ceiling). The speedup is 2× rather than 5-10× because gcc -O3 already auto-vectorizes part of the scalar arithmetic; what it cannot vectorize is the 16-entry fp32 LUT gather across nibble indices.
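The "GB/s out" and speedup columns appear to follow directly from the ns/block timings, assuming each block writes 128 fp32 values (512 bytes) and that GB means 10^9 bytes. A small sketch of that arithmetic, with hypothetical helper names:

```c
#include <assert.h>

/* Output bytes per block: 128 fp32 values (assumed basis of the GB/s column). */
#define OUT_BYTES_PER_BLOCK (128 * 4)

/* Bytes per nanosecond is numerically equal to GB/s (10^9 bytes per second). */
static double out_gbps(double ns_per_block) {
    assert(ns_per_block > 0.0);
    return OUT_BYTES_PER_BLOCK / ns_per_block;
}

/* Ratio of scalar time to NEON time for the same working set. */
static double speedup(double scalar_ns, double neon_ns) {
    assert(neon_ns > 0.0);
    return scalar_ns / neon_ns;
}
```

For the first row, `out_gbps(171.37)` and `out_gbps(80.17)` round to 2.99 and 6.39, and `speedup(171.37, 80.17)` rounds to 2.14, matching the reported figures.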

End-to-end llama-server (Gemma E4B + turbo4 KV cache)

| Prompt | Tokens | Scalar AVG | NEON AVG | Δ |
|---|---|---|---|---|
| long_200 (attention paragraph, 200 max) | 126 generated | 2.10 tok/s | 2.17 tok/s | +3.3% |
| med_150 (Bayesian 5 properties, 150 max) | 69 generated | 2.14 tok/s | 2.18 tok/s | +1.9% |

Coefficient of variation (CV) is ~2-3% per cell; 3 trials each, temperature=0, cache_prompt=false. A modest end-to-end gain is expected: dequant is a small fraction of total inference cost, so by Amdahl's law a 2× speedup on a ~5% fraction yields roughly 2.5% less wall-clock time.
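The Amdahl estimate can be written out explicitly. Taking the dequant fraction as p ≈ 0.05 (the rough figure quoted above, not a measured profile) and the kernel speedup as s = 2:

$$
S = \frac{1}{(1 - p) + p/s} = \frac{1}{0.95 + 0.05/2} = \frac{1}{0.975} \approx 1.026
$$

That is about 2.5% less wall-clock time, or equivalently about 2.6% more tok/s, which sits inside the observed +1.9% to +3.3% range.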

Verification on the binary

The kernel inlines into dequantize_row_turbo4_0. objdump -d build/bin/libggml-base.so confirms expected instructions:

4f9f90db    fmul    v27.4s, v6.4s, v31.s[0]       ; LUT × norm pre-scale (×4)
4e0103da    tbl     v26.16b, {v30.16b}, v1.16b    ; vqtbl1q_u8 broadcast
4e1a629a    tbl     v26.16b, {v20.16b-v23.16b}, v26.16b   ; vqtbl4q_u8 64-byte LUT lookup

Build dispatch

  • Default on aarch64 + NEON: NEON path active via #if defined(__ARM_NEON) && defined(__aarch64__) && !defined(GGML_TURBO_NEON_DISABLE). Sets #define GGML_TURBO_NEON 1.
  • Other targets (x86, RISC-V, scalar-only ARM): scalar fallback unchanged, no behavior change.
  • Explicit disable for debug or A/B: -DGGML_TURBO_NEON_DISABLE (used in the end-to-end measurements above).
  • No CMake change — compile-time autodetection via __ARM_NEON + __aarch64__ predefined macros.
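A minimal sketch of the dispatch described in the list above. The preprocessor condition and the macro names come from this description; the helper `turbo4_uses_neon` is hypothetical and exists only to make the selected path observable.

```c
/* Compile-time autodetection, using the condition quoted above. */
#if defined(__ARM_NEON) && defined(__aarch64__) && !defined(GGML_TURBO_NEON_DISABLE)
#define GGML_TURBO_NEON 1
#endif

/* Hypothetical helper: returns 1 if the NEON path was compiled in, else 0. */
static int turbo4_uses_neon(void) {
#if defined(GGML_TURBO_NEON)
    return 1;  /* aarch64 + NEON, not explicitly disabled */
#else
    return 0;  /* x86, RISC-V, scalar-only ARM, or -DGGML_TURBO_NEON_DISABLE */
#endif
}
```

Building the same file with and without `-DGGML_TURBO_NEON_DISABLE` on an aarch64 host gives the two configurations needed for an A/B comparison without touching the build system.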

Limitations / follow-ups (not in this PR)

  • Only the turbo4_0 4-bit PolarQuant branch (TURBO4_USE_4BIT=1, the default) is NEON-accelerated. The turbo3_0 (3-bit, 8 centroids) and turbo2_0 (2-bit, 4 centroids) paths still fall through to scalar. Both are analogous and would be follow-up PRs.
  • quantize_row_turbo4_0_ref (encode path) unchanged; called only at model load + KV-cache-write, not on the inference hot path.
  • turbo_cpu_fwht (128-element Walsh-Hadamard butterfly) still scalar. NEON-izing would require a 7-stage vtrn/vrev butterfly; separate scope.
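For context on the "7-stage" remark: a 128-element Walsh-Hadamard transform has log2(128) = 7 butterfly stages. The sketch below is a generic, unnormalized in-place FWHT in scalar C. It is not the code of turbo_cpu_fwht, whose normalization and any sign handling are not shown in this PR; the function name is hypothetical.

```c
#include <assert.h>
#include <stddef.h>

/* In-place, unnormalized fast Walsh-Hadamard transform; n must be a power of two. */
static void fwht_inplace(float *x, size_t n) {
    assert(x != NULL && n != 0 && (n & (n - 1)) == 0);
    for (size_t h = 1; h < n; h <<= 1) {            /* log2(n) stages: 7 for n = 128 */
        for (size_t i = 0; i < n; i += 2 * h) {
            for (size_t j = i; j < i + h; j++) {
                float a = x[j];
                float b = x[j + h];
                x[j]     = a + b;  /* butterfly sum */
                x[j + h] = a - b;  /* butterfly difference */
            }
        }
    }
}
```

Applying the unnormalized transform twice returns the input scaled by n, which makes a convenient self-check. A NEON version would need lane shuffles for the stages where the butterfly distance is smaller than the vector width.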

Test plan

  • Bit-exact: 10,000 random blocks, 0 mismatches vs scalar.
  • Microbench: 1.89-2.14× speedup across L1/L2/DRAM working sets.
  • End-to-end: +1.9-3.3% tok/s on llama-server (Gemma E4B + turbo4 KV cache), 3 trials per prompt.
  • Disassembly: tbl (from vqtbl1q_u8 / vqtbl4q_u8) and fmul confirmed in the compiled binary.
  • Zero-cost fallback on non-aarch64: scalar code path unchanged behind #else.

First aarch64 NEON SIMD implementation of TurboQuant turbo4_0 dequantization.

Reference implementations existed for Metal, CUDA, Vulkan and scalar C; this
adds the ARM NEON path (ARMv8.0+ baseline, vqtbl4q_u8). Strategy: pre-scale
the 16-entry CENTROIDS_4BIT * norm into a 64-byte LUT held in 4x uint8x16_t,
then use vqtbl4q_u8 for SIMD nibble->fp32 lookup. Auto-enabled at compile
time via __ARM_NEON + __aarch64__; disable for debug with
-DGGML_TURBO_NEON_DISABLE.

Validation:
- Bit-exact: 10,000 random blocks x 128 elements = 1,280,000 fp32 values,
  0 bit-mismatches vs scalar reference. IEEE 754 deterministic since
  pre-scaled LUT produces the same (centroid * norm) fp32 product.
- Microbench Cortex-A76 (Raspberry Pi 5/16): 2.01x speedup over -O3 scalar,
  3.00 -> 6.04 GB/s out, robust 1.89-2.14x across working sets 128 -> 65,536
  blocks (8.7 KB -> 4.4 MB, spans L1/L2/DRAM).
- End-to-end Pi16 llama-server (Gemma E4B + turbo4 KV): +1.9-3.3% tok/s
  on text generation (modest because dequant is small fraction of total
  inference cost; matches Amdahl ceiling).