Skip to content

Commit 6165c38

Browse files
Optimize NVFP4 LoRA application. (Comfy-Org#11866)
This changes the results slightly, but it also speeds things up a lot.
1 parent 712cca3 commit 6165c38

3 files changed

Lines changed: 49 additions & 11 deletions

File tree

comfy/float.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,44 @@ def ceil_div(a, b):
137137
return rearranged.reshape(padded_rows, padded_cols)
138138

139139

140-
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
140+
def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
    """Quantize a tensor to NVFP4 with stochastic rounding.

    Produces packed fp4 (e2m1) data plus one float8_e4m3fn scale per
    16-element block, normalized by the supplied per-tensor scale.

    Args:
        x: input tensor; its trailing dimension must be divisible by 16
           so it can be regrouped into 16-wide scaling blocks.
        per_tensor_scale: global scale factor applied on top of the
           per-block scales.
        generator: torch.Generator driving the stochastic-rounding noise.

    Returns:
        Tuple ``(data_lp, block_scales)``: the fp4-encoded payload from
        ``stochastic_float_to_fp4_e2m1`` and the raw (not yet
        blocked-layout) float8_e4m3fn per-block scales.
    """
    fp4_peak = 6.0    # largest magnitude representable in fp4 e2m1
    fp8_peak = 448.0  # largest magnitude representable in float8_e4m3fn
    blk = 16          # NVFP4 scaling-block width

    shape_in = x.shape
    grouped = x.reshape(shape_in[0], -1, blk)

    # One fp8 scale per block: block amax mapped onto the fp4 range,
    # divided by the per-tensor scale, then clamped into fp8 range.
    abs_peak = torch.abs(grouped).amax(dim=-1)
    block_scales = torch.clamp(
        (abs_peak / fp4_peak) / per_tensor_scale.to(x.dtype),
        max=fp8_peak,
    ).to(torch.float8_e4m3fn)

    # Normalize by the combined (tensor x block) scale before rounding.
    # Out-of-place division on purpose: the input must not be mutated.
    combined = (per_tensor_scale.to(x.dtype) * block_scales.to(x.dtype)).unsqueeze(-1)
    grouped = grouped / combined

    # An all-zero block yields a zero scale and hence 0/0 = NaN entries;
    # nan_to_num flushes those back to zero before fp4 conversion.
    normalized = grouped.view(shape_in).nan_to_num()
    data_lp = stochastic_float_to_fp4_e2m1(normalized, generator=generator)
    return data_lp, block_scales
155+
156+
157+
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
    """Stochastically quantize a 2D tensor to NVFP4, optionally padding first.

    Args:
        x: 2D input tensor.
        per_tensor_scale: global scale forwarded to the block quantizer.
        pad_16x: when truthy, zero-pad both dimensions up to the next
            multiple of 16 before quantizing (the padded output is the
            intended result in that case).
        seed: seed for the stochastic-rounding generator (default 0),
            making the rounding reproducible.

    Returns:
        Tuple of the fp4-encoded data and the per-block scales rearranged
        into blocked layout via ``to_blocked``.
    """
    rng = torch.Generator(device=x.device)
    rng.manual_seed(seed)

    if pad_16x:
        rows, cols = x.shape
        # (-n) % 16 is the amount needed to reach the next multiple of 16.
        row_pad = (-rows) % 16
        col_pad = (-cols) % 16
        if row_pad or col_pad:
            # F.pad's tuple pads the last dim first: (left, right, top, bottom).
            x = torch.nn.functional.pad(x, (0, col_pad, 0, row_pad))

    data_lp, block_scales = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, rng)
    return data_lp, to_blocked(block_scales, flatten=False)
175+
176+
177+
def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
144178
def roundup(x: int, multiple: int) -> int:
145179
"""Round up x to the nearest multiple."""
146180
return ((x + multiple - 1) // multiple) * multiple
@@ -158,16 +192,20 @@ def roundup(x: int, multiple: int) -> int:
158192
# what we want to produce. If we pad here, we want the padded output.
159193
orig_shape = x.shape
160194

161-
block_size = 16
195+
orig_shape = list(orig_shape)
162196

163-
x = x.reshape(orig_shape[0], -1, block_size)
164-
scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
165-
x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
197+
output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device)
198+
output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device)
166199

167200
generator = torch.Generator(device=x.device)
168201
generator.manual_seed(seed)
169202

170-
x = x.view(orig_shape).nan_to_num()
171-
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
172-
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
173-
return data_lp, blocked_scales
203+
num_slices = max(1, (x.numel() / block_size))
204+
slice_size = max(1, (round(x.shape[0] / num_slices)))
205+
206+
for i in range(0, x.shape[0], slice_size):
207+
fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator)
208+
output_fp4[i:i + slice_size].copy_(fp4)
209+
output_block[i:i + slice_size].copy_(block)
210+
211+
return output_fp4, to_blocked(output_block, flatten=False)

comfy/quant_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
104104
needs_padding = padded_shape != orig_shape
105105

106106
if stochastic_rounding > 0:
107-
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
107+
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4_by_block(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
108108
else:
109109
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
110110

comfy/supported_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,7 @@ class ZImage(Lumina2):
10421042
"shift": 3.0,
10431043
}
10441044

1045-
memory_usage_factor = 2.0
1045+
memory_usage_factor = 2.8
10461046

10471047
supported_inference_dtypes = [torch.bfloat16, torch.float32]
10481048

0 commit comments

Comments
 (0)