@@ -137,10 +137,47 @@ def ceil_div(a, b):
     return rearranged.reshape(padded_rows, padded_cols)
 
 
-def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
+def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
     F4_E2M1_MAX = 6.0
     F8_E4M3_MAX = 448.0
 
+    orig_shape = x.shape
+
+    block_size = 16
+
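+    # Split each row into 16-element blocks and give each block an FP8 (E4M3)
+    # scale: block amax / FP4 max, divided by the per-tensor scale, then clamped.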
+    x = x.reshape(orig_shape[0], -1, block_size)
+    scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
+    x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
+
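+    # Restore the original shape and zero out NaNs (all-zero blocks divide by a zero scale).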
+    x = x.view(orig_shape).nan_to_num()
+    data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
+    return data_lp, scaled_block_scales_fp8
+
+
+def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
+    def roundup(x: int, multiple: int) -> int:
+        """Round up x to the nearest multiple."""
+        return ((x + multiple - 1) // multiple) * multiple
+
+    generator = torch.Generator(device=x.device)
+    generator.manual_seed(seed)
+
+    # Optionally pad rows and columns up to multiples of 16.
+    if pad_16x:
+        rows, cols = x.shape
+        padded_rows = roundup(rows, 16)
+        padded_cols = roundup(cols, 16)
+        if padded_rows != rows or padded_cols != cols:
+            x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
+
+    x, block_scales = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator)
+    return x, to_blocked(block_scales, flatten=False)
+
+
+def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
     def roundup(x: int, multiple: int) -> int:
         """Round up x to the nearest multiple."""
         return ((x + multiple - 1) // multiple) * multiple
@@ -158,16 +195,24 @@ def roundup(x: int, multiple: int) -> int:
     # what we want to produce. If we pad here, we want the padded output.
     orig_shape = x.shape
 
-    block_size = 16
+    orig_shape = list(orig_shape)
 
-    x = x.reshape(orig_shape[0], -1, block_size)
-    scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
-    x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
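+    # Pre-allocate outputs: packed FP4 data (two 4-bit values per byte) and one
+    # FP8 scale per 16-element block.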
+    output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device)
+    output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device)
 
     generator = torch.Generator(device=x.device)
     generator.manual_seed(seed)
 
-    x = x.view(orig_shape).nan_to_num()
-    data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
-    blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
-    return data_lp, blocked_scales
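+    # Slice along dim 0 so each slice holds roughly block_size elements.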
+    num_slices = max(1, x.numel() / block_size)
+    slice_size = max(1, round(x.shape[0] / num_slices))
+
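+    # Quantize slice by slice to keep peak memory bounded; one shared generator keeps the rounding reproducible.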
+    for i in range(0, x.shape[0], slice_size):
+        fp4, block = stochastic_round_quantize_nvfp4_block(x[i : i + slice_size], per_tensor_scale, generator=generator)
+        output_fp4[i : i + slice_size].copy_(fp4)
+        output_block[i : i + slice_size].copy_(block)
+
+    return output_fp4, to_blocked(output_block, flatten=False)