
Commit 0bfe6ad

Merge pull request #76 from vickiw973/feature/nvfp4_dual_gemm
Add nvfp4 dual_gemm example
2 parents 89da649 + 9832f6d commit 0bfe6ad

File tree

6 files changed, +1426 -0 lines changed
Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
import torch
from task import input_t, output_t
from utils import make_match_reference

# Scaling factor vector size
sf_vec_size = 16

# Helper function for ceiling division
def ceil_div(a, b):
    return (a + b - 1) // b

# Helper function to convert a scale factor tensor to blocked format
def to_blocked(input_matrix):
    rows, cols = input_matrix.shape

    # rows and cols are expected to be multiples of 128 and 4, respectively
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    padded = input_matrix
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()
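

# Illustrative usage sketch (editorial addition, not part of the committed file):
# for a single 128 x 4 tile of FP8 scale factors, to_blocked rearranges the values
# into 32 x 16 swizzled blocks and flattens them, matching how ref_kernel feeds
# scale factors to torch._scaled_mm. The helper name below is hypothetical.
def _example_to_blocked_usage():
    sf = torch.randint(-3, 3, (128, 4), dtype=torch.int8, device="cuda").to(torch.float8_e4m3fn)
    blocked = to_blocked(sf)  # 1-D tensor with 128 * 4 = 512 elements
    assert blocked.shape == (512,)
    return blocked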


def ref_kernel(
    data: input_t,
) -> output_t:
    """
    PyTorch reference implementation of NVFP4 block-scaled dual GEMM with SiLU activation,
    C = silu(A @ B1) * (A @ B2).
    """
    a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, _, _, _, c_ref = data

    # Get dimensions from the MxNxL layout
    m, n, l = c_ref.shape

    # Call torch._scaled_mm to compute the GEMM results
    ref1 = torch.empty(
        (l, m, n),
        dtype=torch.float32,
        device="cuda",
    ).permute(1, 2, 0)
    ref2 = torch.empty(
        (l, m, n),
        dtype=torch.float32,
        device="cuda",
    ).permute(1, 2, 0)
    for l_idx in range(l):
        # Convert the scale factor tensors to blocked format
        scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx])
        scale_b1 = to_blocked(sfb1_ref_cpu[:, :, l_idx])
        scale_b2 = to_blocked(sfb2_ref_cpu[:, :, l_idx])
        # (m, k) @ (n, k).T -> (m, n)
        res1 = torch._scaled_mm(
            a_ref[:, :, l_idx],
            b1_ref[:, :, l_idx].transpose(0, 1),
            scale_a.cuda(),
            scale_b1.cuda(),
            bias=None,
            out_dtype=torch.float32,
        )
        ref1[:, :, l_idx] = res1

        res2 = torch._scaled_mm(
            a_ref[:, :, l_idx],
            b2_ref[:, :, l_idx].transpose(0, 1),
            scale_a.cuda(),
            scale_b2.cuda(),
            bias=None,
            out_dtype=torch.float32,
        )
        ref2[:, :, l_idx] = res2
    # Apply SiLU to the first GEMM result and multiply by the second GEMM result
    c_ref = (torch.nn.functional.silu(ref1) * ref2).to(torch.float16)
    return c_ref
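

# Editorial note (assumption, not part of the committed file): the epilogue above is
# the SwiGLU-style gating silu(x1) * x2, where silu(x) = x * sigmoid(x). A minimal
# equivalent sketch, with a hypothetical helper name:
def _example_dual_gemm_epilogue(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    # Same result as torch.nn.functional.silu(x1) * x2
    return x1 * torch.sigmoid(x1) * x2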


def generate_input(
    m: int,
    n: int,
    k: int,
    l: int,
    seed: int,
):
    """
    Generate input tensors for NVFP4 block-scaled dual GEMM with SiLU activation,
    C = silu(A @ B1) * (A @ B2).

    Args:
        m: Number of rows in matrix A
        n: Number of columns in matrices B1 and B2
        k: Number of columns in A and rows of B1 and B2
        l: Batch size
        seed: Random seed for reproducibility

    Returns:
        Tuple of (a, b1, b2, scale_a, scale_b1, scale_b2, scale_a_permuted,
        scale_b1_permuted, scale_b2_permuted, c) where:
            a: [m, k, l] - Input matrix in torch.float4_e2m1fn_x2 data type
            b1: [n, k, l] - Input matrix in torch.float4_e2m1fn_x2 data type
            b2: [n, k, l] - Input matrix in torch.float4_e2m1fn_x2 data type
            scale_a: [m, ceil_div(k, sf_vec_size), l] - Input scale factors in torch.float8_e4m3fn data type
            scale_b1: [n, ceil_div(k, sf_vec_size), l] - Input scale factors in torch.float8_e4m3fn data type
            scale_b2: [n, ceil_div(k, sf_vec_size), l] - Input scale factors in torch.float8_e4m3fn data type
            scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8_e4m3fn data type
            scale_b1_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8_e4m3fn data type
            scale_b2_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8_e4m3fn data type
            c: [m, n, l] - Output matrix in torch.float16 data type
    """
    torch.manual_seed(seed)

    # Generate int8 tensors, then reinterpret them as the float4_e2m1fn_x2 data type
    a_ref = torch.randint(
        -6, 6, (l, m, k // 2), dtype=torch.int8, device="cuda"
    ).permute(1, 2, 0)
    b1_ref = torch.randint(
        -6, 6, (l, n, k // 2), dtype=torch.int8, device="cuda"
    ).permute(1, 2, 0)
    b2_ref = torch.randint(
        -6, 6, (l, n, k // 2), dtype=torch.int8, device="cuda"
    ).permute(1, 2, 0)
    a_ref = a_ref.view(torch.float4_e2m1fn_x2)
    b1_ref = b1_ref.view(torch.float4_e2m1fn_x2)
    b2_ref = b2_ref.view(torch.float4_e2m1fn_x2)

    # Create float16 output tensor
    c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute(
        1, 2, 0
    )

    # Helper function to prepare the scale factor tensors for both the reference
    # kernel and the customized kernel. The customized data layout is documented at:
    # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout
    def create_scale_factor_tensors(l, mn, sf_k):
        # Create the reference scale factor tensor (mn, sf_k, l) on CPU.
        ref_shape = (l, mn, sf_k)
        ref_permute_order = (1, 2, 0)
        # Init with an int8 tensor, then convert to float8_e4m3fn
        ref_f8_random_int = torch.randint(-3, 3, ref_shape, dtype=torch.int8, device='cuda')
        ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn)
        # Permute to match ref_permute_order
        ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order)

        atom_m = (32, 4)
        atom_k = 4
        mma_shape = (
            l,  # batch size
            ceil_div(mn, atom_m[0] * atom_m[1]),
            ceil_div(sf_k, atom_k),
            atom_m[0],
            atom_m[1],
            atom_k,
        )

        # Reorder the scale factor tensor to the (32, 4, rest_m, 4, rest_k, l) layout,
        # which is needed by the customized CuTe kernel
        mma_permute_order = (3, 4, 1, 5, 2, 0)
        # Generate a random int8 tensor, then convert to float8_e4m3fn
        rand_int_tensor = torch.randint(-3, 3, mma_shape, dtype=torch.int8, device='cuda')
        reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn)
        # Permute according to mma_permute_order
        reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order)

        # GPU-side vectorized reordering (replaces slow CPU nested loops)
        # Create index grids for all dimensions
        i_idx = torch.arange(mn, device='cuda')
        j_idx = torch.arange(sf_k, device='cuda')
        b_idx = torch.arange(l, device='cuda')

        # Create a meshgrid for all combinations of (i, j, b)
        i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij')

        # Calculate target indices in a vectorized manner
        mm = i_grid // (atom_m[0] * atom_m[1])
        mm32 = i_grid % atom_m[0]
        mm4 = (i_grid % 128) // atom_m[0]
        kk = j_grid // atom_k
        kk4 = j_grid % atom_k

        # Perform the reordering with advanced indexing (all on GPU)
        reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid]

        return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor

    sf_k = ceil_div(k, sf_vec_size)
    sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k)
    sfb1_ref_cpu, sfb1_ref_permuted = create_scale_factor_tensors(l, n, sf_k)
    sfb2_ref_cpu, sfb2_ref_permuted = create_scale_factor_tensors(l, n, sf_k)

    return (a_ref, b1_ref, b2_ref, sfa_ref_cpu.to("cuda"), sfb1_ref_cpu.to("cuda"), sfb2_ref_cpu.to("cuda"), sfa_ref_permuted, sfb1_ref_permuted, sfb2_ref_permuted, c_ref)


check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03)
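

# Illustrative usage sketch (editorial addition, not part of the committed file):
# the chosen sizes and the hardware requirement (a GPU and PyTorch build with NVFP4
# support for torch._scaled_mm) are assumptions, and the function name below is
# hypothetical. check_implementation, built via make_match_reference, is presumably
# what the evaluation harness uses to compare a submission against ref_kernel.
def _example_run():
    # m and n should be multiples of 128 and k a multiple of sf_vec_size * 4 so the
    # blocked scale-factor layout above applies without padding.
    data = generate_input(m=128, n=128, k=256, l=1, seed=0)
    out = ref_kernel(data)  # (m, n, l) float16 tensor
    return out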
