import torch
from task import input_t, output_t
from utils import make_match_reference

# Scaling factor vector size: one FP8 scale factor per 16 consecutive FP4 values
sf_vec_size = 16

# Helper function for ceiling division, e.g. ceil_div(10, 16) == 1
def ceil_div(a, b):
    return (a + b - 1) // b

# Helper function to convert a scale factor tensor to the blocked layout
def to_blocked(input_matrix):
    rows, cols = input_matrix.shape

    # rows and cols are assumed to be multiples of 128 and 4 respectively,
    # so the "padded" tensor below needs no actual padding
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    padded = input_matrix
    # Split into (128, 4) tiles, then rearrange each tile into 32x16 chunks
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()
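
# Usage sketch for to_blocked (hypothetical sizes, assuming the divisibility
# notes above): a (256, 8) scale factor matrix is split into 2 x 2 tiles of
# 128 x 4 and flattened to the same element count:
#
#   sf = torch.zeros(256, 8, dtype=torch.float8_e4m3fn, device="cuda")
#   to_blocked(sf).shape  # torch.Size([2048])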


def ref_kernel(
    data: input_t,
) -> output_t:
    """
    PyTorch reference implementation of NVFP4 block-scaled dual GEMM with SiLU
    activation, C = silu(A @ B1) * (A @ B2), where silu(x) = x * sigmoid(x).
    """
    a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, _, _, _, c_ref = data

    # Get dimensions from the MxNxL layout
    m, n, l = c_ref.shape

    # Allocate the two intermediates, then call torch._scaled_mm per batch
    # to compute each GEMM result
    ref1 = torch.empty(
        (l, m, n),
        dtype=torch.float32,
        device="cuda",
    ).permute(1, 2, 0)
    ref2 = torch.empty(
        (l, m, n),
        dtype=torch.float32,
        device="cuda",
    ).permute(1, 2, 0)
    for l_idx in range(l):
        # Convert the scale factor tensors to the blocked format
        scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx])
        scale_b1 = to_blocked(sfb1_ref_cpu[:, :, l_idx])
        scale_b2 = to_blocked(sfb2_ref_cpu[:, :, l_idx])
        # (m, k) @ (n, k).T -> (m, n)
        res1 = torch._scaled_mm(
            a_ref[:, :, l_idx],
            b1_ref[:, :, l_idx].transpose(0, 1),
            scale_a.cuda(),
            scale_b1.cuda(),
            bias=None,
            out_dtype=torch.float32,
        )
        ref1[:, :, l_idx] = res1

        res2 = torch._scaled_mm(
            a_ref[:, :, l_idx],
            b2_ref[:, :, l_idx].transpose(0, 1),
            scale_a.cuda(),
            scale_b2.cuda(),
            bias=None,
            out_dtype=torch.float32,
        )
        ref2[:, :, l_idx] = res2
    # Apply SiLU to the first GEMM result and multiply by the second GEMM result
    c_ref = (torch.nn.functional.silu(ref1) * ref2).to(torch.float16)
    return c_ref
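
# Scalar sketch of the math above (for reference only; torch._scaled_mm folds
# the FP8 block scales into the dot products): for each batch index b,
#   c[i, j, b] = silu(dot(a[i, :, b], b1[j, :, b])) * dot(a[i, :, b], b2[j, :, b])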


def generate_input(
    m: int,
    n: int,
    k: int,
    l: int,
    seed: int,
):
    """
    Generate input tensors for NVFP4 block-scaled dual GEMM with SiLU activation,
    C = silu(A @ B1) * (A @ B2).

    Args:
        m: Number of rows in matrix A
        n: Number of columns in matrices B1 and B2
        k: Number of columns in A and rows of B1 and B2
        l: Batch size
        seed: Random seed for reproducibility

    Returns:
        Tuple of (a, b1, b2, scale_a, scale_b1, scale_b2, scale_a_permuted,
        scale_b1_permuted, scale_b2_permuted, c) where, with
        sf_k = ceil_div(k, sf_vec_size):
            a: [m, k // 2, l] - Input matrix in torch.float4_e2m1fn_x2 data type
               (two FP4 values packed per byte)
            b1: [n, k // 2, l] - Input matrix in torch.float4_e2m1fn_x2 data type
            b2: [n, k // 2, l] - Input matrix in torch.float4_e2m1fn_x2 data type
            scale_a: [m, sf_k, l] - Input scale factors in torch.float8_e4m3fn data type
            scale_b1: [n, sf_k, l] - Input scale factors in torch.float8_e4m3fn data type
            scale_b2: [n, sf_k, l] - Input scale factors in torch.float8_e4m3fn data type
            scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8_e4m3fn data type
            scale_b1_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8_e4m3fn data type
            scale_b2_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8_e4m3fn data type
            c: [m, n, l] - Output matrix in torch.float16 data type
    """
    torch.manual_seed(seed)

    # Generate int8 tensors, then reinterpret as the float4_e2m1fn_x2 data type
    # (two FP4 values packed per byte, hence k // 2)
    a_ref = torch.randint(
        -6, 6, (l, m, k // 2), dtype=torch.int8, device="cuda"
    ).permute(1, 2, 0)
    b1_ref = torch.randint(
        -6, 6, (l, n, k // 2), dtype=torch.int8, device="cuda"
    ).permute(1, 2, 0)
    b2_ref = torch.randint(
        -6, 6, (l, n, k // 2), dtype=torch.int8, device="cuda"
    ).permute(1, 2, 0)
    a_ref = a_ref.view(torch.float4_e2m1fn_x2)
    b1_ref = b1_ref.view(torch.float4_e2m1fn_x2)
    b2_ref = b2_ref.view(torch.float4_e2m1fn_x2)

    # Create the float16 output tensor
    c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute(
        1, 2, 0
    )
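    # Note: the tensors above are allocated batch-first (l, ...) and returned as
    # batch-last views, so each [:, :, l_idx] slice used in ref_kernel is a
    # contiguous 2-D matrix.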

    # Helper function to prepare the scale factor tensors for both the reference
    # kernel and the custom kernel. The customized data layout is documented at:
    # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout
    def create_scale_factor_tensors(l, mn, sf_k):
        # Create the reference scale factor tensor in (mn, sf_k, l) layout
        # (moved to CPU on return).
        ref_shape = (l, mn, sf_k)
        ref_permute_order = (1, 2, 0)
        # Init with a random int8 tensor, then convert to float8_e4m3fn
        ref_f8_random_int = torch.randint(-3, 3, ref_shape, dtype=torch.int8, device='cuda')
        ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn)
        # Permute to match ref_permute_order
        ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order)

        atom_m = (32, 4)
        atom_k = 4
        mma_shape = (
            l,  # batch size
            ceil_div(mn, atom_m[0] * atom_m[1]),
            ceil_div(sf_k, atom_k),
            atom_m[0],
            atom_m[1],
            atom_k,
        )
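        # Worked example (hypothetical sizes): for mn = 256 and sf_k = 16 this is
        # (l, 2, 4, 32, 4, 4): two 128-row blocks by four 4-column blocks, each
        # holding a 32 x 4 x 4 atom.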

        # Reorder the scale factor tensor to the (32, 4, rest_m, 4, rest_k, l)
        # layout required by the CuTe customized kernel
        mma_permute_order = (3, 4, 1, 5, 2, 0)
        # Generate a random int8 tensor, then convert to float8_e4m3fn
        rand_int_tensor = torch.randint(-3, 3, mma_shape, dtype=torch.int8, device='cuda')
        reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn)
        # Permute according to mma_permute_order
        reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order)

        # GPU-side vectorized reordering (replaces slow CPU nested loops)
        # Create index grids for all dimensions
        i_idx = torch.arange(mn, device='cuda')
        j_idx = torch.arange(sf_k, device='cuda')
        b_idx = torch.arange(l, device='cuda')

        # Create a meshgrid for all combinations of (i, j, b)
        i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij')

        # Calculate the target indices in a vectorized manner
        mm = i_grid // (atom_m[0] * atom_m[1])
        mm32 = i_grid % atom_m[0]
        mm4 = (i_grid % 128) // atom_m[0]
        kk = j_grid // atom_k
        kk4 = j_grid % atom_k

        # Perform the reordering with advanced indexing (all on GPU)
        reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid]
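        # For reference, the equivalent (much slower) scalar loop this replaces:
        #   for b in range(l):
        #       for i in range(mn):
        #           for j in range(sf_k):
        #               reordered_f8_torch_tensor[i % 32, (i % 128) // 32, i // 128,
        #                                         j % 4, j // 4, b] = \
        #                   ref_f8_torch_tensor_permuted[i, j, b]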

        return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor

    sf_k = ceil_div(k, sf_vec_size)
    sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k)
    sfb1_ref_cpu, sfb1_ref_permuted = create_scale_factor_tensors(l, n, sf_k)
    sfb2_ref_cpu, sfb2_ref_permuted = create_scale_factor_tensors(l, n, sf_k)

    return (
        a_ref, b1_ref, b2_ref,
        sfa_ref_cpu.to("cuda"), sfb1_ref_cpu.to("cuda"), sfb2_ref_cpu.to("cuda"),
        sfa_ref_permuted, sfb1_ref_permuted, sfb2_ref_permuted,
        c_ref,
    )


check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03)
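

# Minimal smoke test (a sketch; assumes a GPU whose torch._scaled_mm supports
# FP4 inputs, plus problem sizes meeting to_blocked's divisibility notes):
if __name__ == "__main__":
    data = generate_input(m=256, n=256, k=256, l=2, seed=0)
    out = ref_kernel(data)
    print(out.shape, out.dtype)  # torch.Size([256, 256, 2]) torch.float16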