@@ -777,16 +777,24 @@ def forward(self, *args, **kwargs):
777777
778778
class QuantLinearFunc(torch.autograd.Function):
    """Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.

    When training_fp8_bwd is enabled:
    - Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
    - Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
    - Cached input is FP8 (half the memory of bf16)

    When training_fp8_bwd is disabled:
    - Forward: quantize input per layout, use quantized matmul
    - Backward: dequantize weight to compute_dtype, use standard matmul
    """
783791
784792 @staticmethod
785793 def forward (ctx , input_float , weight , bias , layout_type , input_scale , compute_dtype ):
786794 input_shape = input_float .shape
787795 inp = input_float .detach ().flatten (0 , - 2 ) # zero-cost view to 2D
788796
789- # Quantize input (same as inference path )
797+ # Quantize input for forward (same layout as weight )
790798 if layout_type is not None :
791799 q_input = QuantizedTensor .from_float (inp , layout_type , scale = input_scale )
792800 else :
@@ -797,43 +805,68 @@ def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dt
797805
798806 output = torch .nn .functional .linear (q_input , w , b )
799807
800- # Restore original input shape
808+ # Unflatten output to match original input shape
801809 if len (input_shape ) > 2 :
802810 output = output .unflatten (0 , input_shape [:- 1 ])
803811
804- ctx . save_for_backward ( input_float , weight )
812+ # Save for backward
805813 ctx .input_shape = input_shape
806814 ctx .has_bias = bias is not None
807815 ctx .compute_dtype = compute_dtype
808816 ctx .weight_requires_grad = weight .requires_grad
817+ ctx .fp8_bwd = comfy .model_management .training_fp8_bwd
818+
819+ if ctx .fp8_bwd :
820+ # Cache FP8 quantized input — half the memory of bf16
821+ if isinstance (q_input , QuantizedTensor ) and layout_type .startswith ('TensorCoreFP8' ):
822+ ctx .q_input = q_input # already FP8, reuse
823+ else :
824+ # NVFP4 or other layout — quantize input to FP8 for backward
825+ ctx .q_input = QuantizedTensor .from_float (inp , "TensorCoreFP8E4M3Layout" )
826+ ctx .save_for_backward (weight )
827+ else :
828+ ctx .q_input = None
829+ ctx .save_for_backward (input_float , weight )
809830
810831 return output
811832
    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Backward pass for the quantized linear forward above.

        Computes grad_input, and grad_weight / grad_bias when required.
        Two paths, selected by ctx.fp8_bwd (set in forward from
        comfy.model_management.training_fp8_bwd):
        - fp8 path: grad and weight are wrapped as FP8 QuantizedTensors so the
          torch.mm calls below dispatch to the quantized (scaled) matmul.
        - standard path: weight is dequantized/cast to compute_dtype and the
          matmuls run as ordinary dense ops.

        NOTE(review): reconstructed from a truncated diff — the trailing
        `return grad_input, grad_weight, grad_bias, None, None, None`
        (one gradient slot per forward() argument) lies beyond the visible
        chunk; confirm against the full file.
        """
        compute_dtype = ctx.compute_dtype
        # Flatten grad to 2D, mirroring forward's flatten(0, -2) of the input.
        grad_2d = grad_output.flatten(0, -2).to(compute_dtype)

        # Value casting — only difference between fp8 and non-fp8 paths
        if ctx.fp8_bwd:
            weight, = ctx.saved_tensors
            # Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm.
            # E5M2 for the gradient (wider dynamic range), E4M3 for the weight.
            grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
            if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
                weight_mm = weight  # already FP8, reuse as-is
            elif isinstance(weight, QuantizedTensor):
                # Non-FP8 quantized layout (e.g. NVFP4): round-trip through
                # compute_dtype, then requantize to FP8 for the scaled matmul.
                weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
            else:
                weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
            # FP8 input cached on ctx by forward() (not via save_for_backward).
            input_mm = ctx.q_input
        else:
            input_float, weight = ctx.saved_tensors
            # Standard tensors → torch.mm does regular matmul
            grad_mm = grad_2d
            if isinstance(weight, QuantizedTensor):
                weight_mm = weight.dequantize().to(compute_dtype)
            else:
                weight_mm = weight.to(compute_dtype)
            # Only materialize the flattened input if grad_weight is needed.
            input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None

        # Computation — same for both paths, dispatch handles the rest
        grad_input = torch.mm(grad_mm, weight_mm)
        if len(ctx.input_shape) > 2:
            # Restore leading batch dims to match the original input shape.
            grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])

        grad_weight = None
        if ctx.weight_requires_grad:
            # grad_weight = grad_out^T @ input. In the fp8 path both operands
            # are QuantizedTensors — assumes QuantizedTensor supports .t();
            # TODO confirm the transpose is supported by the dispatch.
            grad_weight = torch.mm(grad_mm.t(), input_mm)

        grad_bias = None
        if ctx.has_bias:
            # Bias grad stays in compute_dtype regardless of the fp8 path.
            grad_bias = grad_2d.sum(dim=0)