Skip to content

Commit 5ebb0c2

Browse files
FP8 bwd training (Comfy-Org#13121)
1 parent a0a64c6 commit 5ebb0c2

File tree

3 files changed

+59
-16
lines changed

3 files changed

+59
-16
lines changed

comfy/model_management.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class CPUState(Enum):
5555

5656
# Training Related State
5757
in_training = False
58+
training_fp8_bwd = False
5859

5960

6061
def get_supported_float8_types():

comfy/ops.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -777,16 +777,24 @@ def forward(self, *args, **kwargs):
777777

778778

779779
class QuantLinearFunc(torch.autograd.Function):
780-
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
781-
Handles any input rank by flattening to 2D for matmul and restoring shape after.
780+
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
781+
782+
When training_fp8_bwd is enabled:
783+
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
784+
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
785+
- Cached input is FP8 (half the memory of bf16)
786+
787+
When training_fp8_bwd is disabled:
788+
- Forward: quantize input per layout, use quantized matmul
789+
- Backward: dequantize weight to compute_dtype, use standard matmul
782790
"""
783791

784792
@staticmethod
785793
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
786794
input_shape = input_float.shape
787795
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
788796

789-
# Quantize input (same as inference path)
797+
# Quantize input for forward (same layout as weight)
790798
if layout_type is not None:
791799
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
792800
else:
@@ -797,43 +805,68 @@ def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dt
797805

798806
output = torch.nn.functional.linear(q_input, w, b)
799807

800-
# Restore original input shape
808+
# Unflatten output to match original input shape
801809
if len(input_shape) > 2:
802810
output = output.unflatten(0, input_shape[:-1])
803811

804-
ctx.save_for_backward(input_float, weight)
812+
# Save for backward
805813
ctx.input_shape = input_shape
806814
ctx.has_bias = bias is not None
807815
ctx.compute_dtype = compute_dtype
808816
ctx.weight_requires_grad = weight.requires_grad
817+
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
818+
819+
if ctx.fp8_bwd:
820+
# Cache FP8 quantized input — half the memory of bf16
821+
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
822+
ctx.q_input = q_input # already FP8, reuse
823+
else:
824+
# NVFP4 or other layout — quantize input to FP8 for backward
825+
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
826+
ctx.save_for_backward(weight)
827+
else:
828+
ctx.q_input = None
829+
ctx.save_for_backward(input_float, weight)
809830

810831
return output
811832

812833
@staticmethod
813834
@torch.autograd.function.once_differentiable
814835
def backward(ctx, grad_output):
815-
input_float, weight = ctx.saved_tensors
816836
compute_dtype = ctx.compute_dtype
817837
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
818838

819-
# Dequantize weight to compute dtype for backward matmul
820-
if isinstance(weight, QuantizedTensor):
821-
weight_f = weight.dequantize().to(compute_dtype)
839+
# Value casting — only difference between fp8 and non-fp8 paths
840+
if ctx.fp8_bwd:
841+
weight, = ctx.saved_tensors
842+
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
843+
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
844+
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
845+
weight_mm = weight
846+
elif isinstance(weight, QuantizedTensor):
847+
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
848+
else:
849+
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
850+
input_mm = ctx.q_input
822851
else:
823-
weight_f = weight.to(compute_dtype)
852+
input_float, weight = ctx.saved_tensors
853+
# Standard tensors → torch.mm does regular matmul
854+
grad_mm = grad_2d
855+
if isinstance(weight, QuantizedTensor):
856+
weight_mm = weight.dequantize().to(compute_dtype)
857+
else:
858+
weight_mm = weight.to(compute_dtype)
859+
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
824860

825-
# grad_input = grad_output @ weight
826-
grad_input = torch.mm(grad_2d, weight_f)
861+
# Computation — same for both paths, dispatch handles the rest
862+
grad_input = torch.mm(grad_mm, weight_mm)
827863
if len(ctx.input_shape) > 2:
828864
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
829865

830-
# grad_weight (only if weight requires grad, typically frozen for quantized training)
831866
grad_weight = None
832867
if ctx.weight_requires_grad:
833-
input_f = input_float.flatten(0, -2).to(compute_dtype)
834-
grad_weight = torch.mm(grad_2d.t(), input_f)
868+
grad_weight = torch.mm(grad_mm.t(), input_mm)
835869

836-
# grad_bias
837870
grad_bias = None
838871
if ctx.has_bias:
839872
grad_bias = grad_2d.sum(dim=0)

comfy_extras/nodes_train.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,11 @@ def define_schema(cls):
10301030
default="bf16",
10311031
tooltip="The dtype to use for lora.",
10321032
),
1033+
io.Boolean.Input(
1034+
"quantized_backward",
1035+
default=False,
1036+
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
1037+
),
10331038
io.Combo.Input(
10341039
"algorithm",
10351040
options=list(adapter_maps.keys()),
@@ -1097,6 +1102,7 @@ def execute(
10971102
seed,
10981103
training_dtype,
10991104
lora_dtype,
1105+
quantized_backward,
11001106
algorithm,
11011107
gradient_checkpointing,
11021108
checkpoint_depth,
@@ -1117,6 +1123,7 @@ def execute(
11171123
seed = seed[0]
11181124
training_dtype = training_dtype[0]
11191125
lora_dtype = lora_dtype[0]
1126+
quantized_backward = quantized_backward[0]
11201127
algorithm = algorithm[0]
11211128
gradient_checkpointing = gradient_checkpointing[0]
11221129
offloading = offloading[0]
@@ -1125,6 +1132,8 @@ def execute(
11251132
bucket_mode = bucket_mode[0]
11261133
bypass_mode = bypass_mode[0]
11271134

1135+
comfy.model_management.training_fp8_bwd = quantized_backward
1136+
11281137
# Process latents based on mode
11291138
if bucket_mode:
11301139
latents = _process_latents_bucket_mode(latents)

0 commit comments

Comments
 (0)