Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions reference_hierarchical_nvfp4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""NVFP4 reference: same *recipe* as TE (S_enc, S_dec fp8, bsi, FP4) except
``S_enc = (fp8_max*fp4_max)/amax`` uses each **1x64** window's max instead of
per-tensor global amax. See ``fp8_e4m3_utils_np`` and ``core_nvfp4.cuh``.

- PyTorch: symbols below (requires ``torch`` + ``numpy`` for E4M3).
- CPU: ``hierarchical_nvfp4_ref_numpy`` (``numpy`` only), or run
``python reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py``.
"""

from .hierarchical_nvfp4_ref import (
COARSE,
FINE,
HierarchicalNVFP4Colwise,
HierarchicalNVFP4Rowwise,
dequantize_colwise,
dequantize_rowwise,
fp4_e2m1_grid_torch,
quantize_columnwise_1x64_1x16,
quantize_rowwise_1x64_1x16,
reference_matmul_tn,
roundtrip_error,
)

# Explicit public API for the package; mirrors the re-exports above from
# hierarchical_nvfp4_ref (the torch-based reference implementation).
__all__ = [
    "COARSE",
    "FINE",
    "HierarchicalNVFP4Colwise",
    "HierarchicalNVFP4Rowwise",
    "dequantize_colwise",
    "dequantize_rowwise",
    "fp4_e2m1_grid_torch",
    "quantize_columnwise_1x64_1x16",
    "quantize_rowwise_1x64_1x16",
    "reference_matmul_tn",
    "roundtrip_error",
]
116 changes: 116 additions & 0 deletions reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Compare: TE-style NVFP4 (single S_enc from *global* amax) vs
# reference 1x64 (S_enc from max|x| in each 1x64 K-window = per-row for K=64).
# Matrix shape (M, K) = (64, 64), rowwise along K, 1x16 blocks * 4 per row.
#
# Run (no torch): python3 reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py

from __future__ import annotations

import os
import sys

import numpy as np

# Load sibling ref modules by path (avoid package __init__ -> torch);
# importing the package itself would pull in torch, which this CPU-only
# script deliberately does not require.
_REF_DIR = os.path.dirname(os.path.abspath(__file__))
if _REF_DIR not in sys.path:
    sys.path.insert(0, _REF_DIR)

from fp8_e4m3_utils_np import (
TINY,
compute_S_dec_f32_before_cast_te,
compute_S_enc_from_amax_1x64_like_te,
e4m3_u8_to_f32,
f32_to_e4m3_u8,
)
from hierarchical_nvfp4_ref_numpy import (
FINE,
FP4_E2M1_GRID,
dequantize_rowwise,
quantize_rowwise_1x64_1x16,
)

# Test matrix shape (rows, cols); rowwise along K, so each row is one 1x64
# window made of FINE-sized (1x16) blocks.
M = 64
K = 64


def _round_nearest_fp4(x: np.ndarray) -> np.ndarray:
    """Snap every element of *x* to the closest value on the FP4 E2M1 grid."""
    dist = np.abs(x[..., None] - FP4_E2M1_GRID)
    nearest_idx = np.argmin(dist, axis=-1).astype(np.int64)
    return FP4_E2M1_GRID[nearest_idx]


def te_nvfp4_rowwise_global_senc(
    x: np.ndarray, eps: float = 1e-12
) -> tuple[np.ndarray, float, np.ndarray, np.ndarray]:
    """
    TE-style NVFP4 rowwise quantize/dequantize round trip.

    Same math as the 1x64 reference, except the encode scale S_enc is derived
    once from the *global* amax of the whole tensor (2688 / global_amax)
    rather than per 1x64 window.

    Args:
        x: (M, K) input, converted to float32.
        eps: floor applied to a 1x16 block amax so the decode scale is never
            computed from zero.

    Returns:
        x_recon: (M, K) float32 reconstruction after FP4 quantize + dequantize.
        global_amax: max |x| over the whole tensor.
        w_pre_fp4: scaled values just before FP4 grid rounding.
        S_dec_u8: (M, ceil(K/16)) per-block decode scales as e4m3 codes (uint8).
    """
    x = np.asarray(x, np.float32)
    m, k = x.shape
    g_amax = float(np.max(np.abs(x)))
    # Single encode scale for the entire tensor: 2688 / global_amax.
    S_g = compute_S_enc_from_amax_1x64_like_te(g_amax)
    n16 = (k + FINE - 1) // FINE  # number of 1x16 blocks per row (ceil)
    w = np.empty_like(x)
    s_dec = np.empty((m, n16), dtype=np.uint8)
    for r in range(m):
        t16b = 0
        while t16b * FINE < k:
            lo, hi = t16b * FINE, min((t16b + 1) * FINE, k)
            segx = x[r, lo:hi]
            bamax = float(np.max(np.abs(segx)))
            if bamax < eps:
                bamax = float(eps)  # all-zero block: avoid zero decode scale
            # Unquantized S_dec = block_amax * S_enc / 6, then round-trip it
            # through e4m3 so the scale used below is exactly the stored one.
            raw = compute_S_dec_f32_before_cast_te(bamax, S_g)
            u = f32_to_e4m3_u8(np.array([raw], dtype=np.float32).reshape(1))
            s_dec[r, t16b] = u.reshape(-1)[0]
            s_d = max(float(e4m3_u8_to_f32(s_dec[r : r + 1, t16b : t16b + 1]).reshape(-1)[0]), TINY)
            # Per-block pre-FP4 multiplier: encode scale divided by the
            # (e4m3-rounded) decode scale.
            bsi = S_g / s_d
            w[r, lo:hi] = segx * bsi
            t16b += 1
    q = _round_nearest_fp4(w)
    # Expand per-block decode scales to per-element, then undo the encode scale.
    t16g = (np.arange(k) // FINE).astype(np.int64)
    sde = e4m3_u8_to_f32(s_dec[:, t16g].astype(np.uint8))
    sde = np.maximum(sde, TINY)
    x_recon = q * (sde / S_g)
    return x_recon.astype(np.float32), g_amax, w, s_dec


def main() -> None:
    """Print error stats: 1x64-window S_enc reference vs global-S_enc TE style."""
    rng = np.random.default_rng(2026)
    # Heavy-tailed, non-uniform input: one boosted row plus a shifted patch,
    # so per-window and global amax values diverge.
    x = rng.standard_normal((M, K)).astype(np.float32)
    x *= 0.35
    x[7, :] *= 4.0
    x[0:8, 12:20] += 0.4

    recon_local = dequantize_rowwise(quantize_rowwise_1x64_1x16(x))
    recon_global, amax_global, _, _ = te_nvfp4_rowwise_global_senc(x)

    err_local = np.abs(recon_local - x)
    err_global = np.abs(recon_global - x)
    gap = np.abs(recon_local - recon_global)

    print("=== 64x64 数值对比 (rowwise, K=4×16) ===")
    print("global_amax =", amax_global)
    print("S_enc 现网(全局) = 2688 / global_amax =", compute_S_enc_from_amax_1x64_like_te(amax_global))
    print("---")
    print("量化再反归一 vs 原张量: max abs err [1x64 S_enc 参考] :", float(np.max(err_local)))
    print("量化再反归一 vs 原张量: max abs err [全局 S_enc] :", float(np.max(err_global)))
    print("RMS 误差 vs 原张量 [1x64]:", float(np.sqrt(np.mean(err_local**2))))
    print("RMS 误差 vs 原张量 [全局]:", float(np.sqrt(np.mean(err_global**2))))
    print("---")
    print("两种重建之间的 max abs 差 |recon_1x64 - recon_global| :", float(np.max(gap)))
    print("RMS( recon_1x64 - recon_global ) :", float(np.sqrt(np.mean(gap**2))))
    frob_x = float(np.linalg.norm(x, "fro"))
    if frob_x > 0:
        print("||x||_F =", frob_x)
        print(
            "Fro 相对: ||recon_1x64 - recon_global||_F / ||x||_F =",
            float(np.linalg.norm(gap, "fro") / frob_x),
        )


# Allow running the comparison directly as a standalone script.
if __name__ == "__main__":
    main()
71 changes: 71 additions & 0 deletions reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# FP8 E4M3: decode uint8->float32 and encode f32->nearest uint8 (all 256 codes).
# Matches OCP/FP8 E4M3 for reference roundtrip; close to CUDA static_cast<fp8e4m3>(f32).

from __future__ import annotations

import numpy as np

# Match transformer_engine::detail::TypeExtrema (common.h) / nvfp4.cu kernels
FP8_E4M3_FMAX: float = 448.0  # largest finite fp8 E4M3 (FN variant) magnitude
FP4_E2M1_FMAX: float = 6.0  # largest FP4 E2M1 magnitude
S_ENC_NUMER: float = FP8_E4M3_FMAX * FP4_E2M1_FMAX  # 2688
FLT_MAX: float = 3.402823466e38  # C FLT_MAX; upper clamp for S_enc
TINY: float = 1.17549435e-38  # C FLT_MIN; floor for decode scales


def _decode_e4m3_byte(b: int) -> float:
u = b & 0xFF
sign = u >> 7
exp = (u >> 3) & 0x0F
man = u & 0x07
if exp == 0:
if man == 0:
return 0.0
v = (man / 8.0) * (2.0 ** (-6))
else:
v = (1.0 + man / 8.0) * (2.0 ** (exp - 7))
return -v if sign else v


# Precompute 256 float32 values for all E4M3 codes
_E4M3_TABLE: np.ndarray = np.array([_decode_e4m3_byte(i) for i in range(256)], dtype=np.float32)


def f32_to_e4m3_u8(x: np.ndarray) -> np.ndarray:
"""Round each element to nearest fp8e4m3 (by L_inf on 256 codes). x can be any shape."""
x = np.asarray(x, dtype=np.float32)
flat = x.ravel()[:, None]
d = np.abs(flat - _E4M3_TABLE[None, :])
out = np.argmin(d, axis=1).astype(np.uint8)
return out.reshape(x.shape)


def e4m3_u8_to_f32(b: np.ndarray) -> np.ndarray:
    """Look up the float32 value of each E4M3 code in *b* (any shape)."""
    codes = np.asarray(b, dtype=np.int32) & 0xFF
    return _E4M3_TABLE[codes].astype(np.float32)


def compute_S_enc_from_amax_1x64_like_te(amax: float, eps: float = TINY) -> float:
    """
    Encode scale S_enc = 2688 / amax, clamped to FLT_MAX.

    Same as compute_global_encode_scaling_factor_FP4 in core_nvfp4.cuh, but
    *amax* here is the 1x64 local max (replacing the per-tensor global amax
    in this reference). Non-positive or non-finite amax yields 1.0.
    """
    value = float(amax)
    if value <= 0.0 or not np.isfinite(value):
        return 1.0
    scale = S_ENC_NUMER / max(value, eps)
    return float(min(scale, FLT_MAX))


def compute_S_dec_f32_before_cast_te(block_amax: float, S_enc: float) -> float:
    """
    Unquantized decode scale S_dec = block_amax * (S_enc / 6), before the
    e4m3 cast (compute_decoding_scaling_factor / quantization_SF in
    core_nvfp4.cuh). The product is formed in float64 and rounded to
    float32 exactly once.
    """
    inv_fp4_max = 1.0 / FP4_E2M1_FMAX
    product = block_amax * (S_enc * inv_fp4_max)
    return float(np.float32(product))


def f32_e4m3_f32(x: float) -> float:
    """Round-trip one scalar: f32 -> fp8e4m3 code -> f32 (numpy)."""
    code = f32_to_e4m3_u8(np.array([x], dtype=np.float32))
    value = e4m3_u8_to_f32(code)
    return float(value[0])
Loading