ggml: ARM NEON dequant kernel for turbo4 (vqtbl4q_u8 4-bit PolarQuant)#16
Open
WillowOneVision wants to merge 1 commit into
Conversation
First aarch64 NEON SIMD implementation of TurboQuant turbo4_0 dequantization. Reference implementations existed for Metal, CUDA, Vulkan and scalar C; this adds the ARM NEON path (ARMv8.0+ baseline, vqtbl4q_u8). Strategy: pre-scale the 16-entry CENTROIDS_4BIT * norm into a 64-byte LUT held in 4x uint8x16_t, then use vqtbl4q_u8 for SIMD nibble->fp32 lookup. Auto-enabled at compile time via __ARM_NEON + __aarch64__; disable for debug with -DGGML_TURBO_NEON_DISABLE. Validation: - Bit-exact: 10,000 random blocks x 128 elements = 1,280,000 fp32 values, 0 bit-mismatches vs scalar reference. IEEE 754 deterministic since pre-scaled LUT produces the same (centroid * norm) fp32 product. - Microbench Cortex-A76 (Raspberry Pi 5/16): 2.01x speedup over -O3 scalar, 3.00 -> 6.04 GB/s out, robust 1.89-2.14x across working sets 128 -> 65,536 blocks (8.7 KB -> 4.4 MB, spans L1/L2/DRAM). - End-to-end Pi16 llama-server (Gemma E4B + turbo4 KV): +1.9-3.3% tok/s on text generation (modest because dequant is small fraction of total inference cost; matches Amdahl ceiling).
9 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
First aarch64 NEON SIMD implementation of the TurboQuant turbo4_0 dequantization kernel. Reference impls already cover Metal (Apple Silicon), CUDA (NVIDIA), Vulkan (cross-GPU) and a portable scalar C path. This adds the ARM NEON path, which is the dominant SIMD on edge devices: Raspberry Pi 4/5, Apple M-series, AWS Graviton, Cortex-X Android.
Algorithm
For each block (
QK_TURBO4 = 128elements):normfp16 → fp32, broadcast to afloat32x4_t.CENTROIDS_4BITtable bynorm→ 16 pre-scaled fp32 values (64 bytes) held in 4×uint8x16_t(=uint8x16x4_t).qs.vand_u8+vshr_n_u8, interleave viavzip_u8.vshl_n_u8.vqtbl1q_u8against{0,0,0,0,1,1,1,1,…}and adding{0,1,2,3,…}stride.vqtbl4q_u8(lut, idx)— SIMD lookup into the 64-byte pre-scaled LUT.float32x4_tper iteration.Total per block: ~36 vector ops vs ~768 scalar ops (load + shift + mask + scalar gather + scalar mul + store × 128 elements).
Bit-exact guarantee
Pre-scaling the LUT (
centroid * normonce per LUT entry) instead of per element (centroid * normper dequant) produces the same fp32 product. IEEE 754 fp32 multiplication is deterministic; the multiplication operands are bit-identical between scalar and NEON paths.vqtbl4q_u8is a pure byte permutation (no arithmetic) so the lookup itself is also bit-preserving.A standalone validator (
neon-turbo4-poc/validate.c) confirms this empirically:Build:
gcc -O3 -march=armv8.2-a+fp16 -ffp-contract=off(FMA contraction disabled to prevent compiler-introduced precision divergence between paths).Performance
Microbench (standalone kernel, Cortex-A76 / Raspberry Pi 16 GB)
Robust 1.89-2.14× speedup across L1/L2/DRAM working sets. NEON path is memory-bound near 6 GB/s output bandwidth (Cortex-A76 LPDDR4X-4267 single-thread ceiling). 2× rather than 5-10× because gcc
-O3auto-vectorizes partial scalar arithmetic but cannot vectorize the 16-entry fp32 LUT gather across nibble indices.End-to-end llama-server (Gemma E4B + turbo4 KV cache)
CV ~2-3% per cell, 3 trials each, temperature=0, cache_prompt=false. Modest end-to-end gain is expected — dequant is a small fraction of total inference cost, so per Amdahl 2× on ~5% fraction ≈ ~2.5% wall-clock.
Verification on the binary
The kernel inlines into
dequantize_row_turbo4_0.objdump -d build/bin/libggml-base.soconfirms expected instructions:Build dispatch
#if defined(__ARM_NEON) && defined(__aarch64__) && !defined(GGML_TURBO_NEON_DISABLE). Sets#define GGML_TURBO_NEON 1.-DGGML_TURBO_NEON_DISABLE(used in the end-to-end measurements above).__ARM_NEON+__aarch64__predefined macros.Limitations / follow-ups (not in this PR)
turbo4_04-bit PolarQuant branch (TURBO4_USE_4BIT=1, the default) is NEON-accelerated.turbo3_0(3-bit, 8 centroids) andturbo2_0(2-bit, 4 centroids) paths still fall through to scalar. Both are analogous, would be follow-up PRs.quantize_row_turbo4_0_ref(encode path) unchanged; called only at model load + KV-cache-write, not on the inference hot path.turbo_cpu_fwht(128-element Walsh-Hadamard butterfly) still scalar. NEON-izing would require a 7-stagevtrn/vrevbutterfly; separate scope.Test plan
#else.