Commit b53b10e
Fix Train LoRA crash when training_dtype is "none" with bfloat16 LoRA weights (Comfy-Org#13145)
When training_dtype is set to "none" and the model's native dtype is float16, GradScaler was enabled unconditionally. GradScaler, however, supports only float16/float32 gradients, so it raised a NotImplementedError when lora_dtype was "bf16" (the default).

The fix enables GradScaler only when the LoRA parameters are not in bfloat16: bfloat16 has the same exponent range as float32, so its gradients do not need scaling to avoid underflow.

Fixes Comfy-Org#13124
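The gating the patch introduces can be sketched as a small standalone predicate (a hypothetical helper, not part of the codebase; dtype names are plain strings so the sketch needs no torch):

```python
# Sketch of the fixed GradScaler gating for the training_dtype == "none" path
# (hypothetical helper, not in the repo). GradScaler rescales fp16 gradients
# to avoid underflow, but it does not support bf16 gradients; bf16 shares
# float32's exponent range, so its gradients don't underflow the way fp16's do.
def should_use_grad_scaler(model_dtype: str, lora_dtype: str) -> bool:
    return model_dtype == "float16" and lora_dtype != "bfloat16"

# The crashing configuration from the issue: fp16 model, bf16 LoRA (the default).
print(should_use_grad_scaler("float16", "bfloat16"))  # False: scaler skipped
print(should_use_grad_scaler("float16", "float16"))   # True: scaler enabled
```

Before the fix, the equivalent predicate was `model_dtype == "float16"` alone, which enabled the scaler even for bf16 LoRA gradients and triggered the crash.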
1 parent 7d5534d commit b53b10e

File tree

1 file changed: 5 additions, 2 deletions

comfy_extras/nodes_train.py

Lines changed: 5 additions & 2 deletions
@@ -1146,6 +1146,7 @@ def execute(
         # Setup model and dtype
         mp = model.clone()
         use_grad_scaler = False
+        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
         if training_dtype != "none":
             dtype = node_helpers.string_to_torch_dtype(training_dtype)
             mp.set_model_compute_dtype(dtype)
@@ -1154,7 +1155,10 @@ def execute(
             model_dtype = mp.model.get_dtype()
             if model_dtype == torch.float16:
                 dtype = torch.float16
-                use_grad_scaler = True
+                # GradScaler only supports float16 gradients, not bfloat16.
+                # Only enable it when lora params will also be in float16.
+                if lora_dtype != torch.bfloat16:
+                    use_grad_scaler = True
                 # Warn about fp16 accumulation instability during training
                 if PerformanceFeature.Fp16Accumulation in args.fast:
                     logging.warning(
@@ -1165,7 +1169,6 @@ def execute(
             else:
                 # For fp8, bf16, or other dtypes, use bf16 autocast
                 dtype = torch.bfloat16
-        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
 
         # Prepare latents and compute counts
         latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
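The "same exponent range as float32" claim behind the fix can be checked with a little format arithmetic (IEEE-style layouts assumed: float16 has 5 exponent bits, while bfloat16 and float32 both have 8):

```python
# Largest unbiased exponent of a normal number given `bits` exponent bits
# (IEEE-style bias: 2**(bits-1) - 1).
def max_unbiased_exponent(bits: int) -> int:
    return 2 ** (bits - 1) - 1

fp16 = max_unbiased_exponent(5)  # 15  -> max normal ~ 2**16  ~ 6.5e4
bf16 = max_unbiased_exponent(8)  # 127 -> max normal ~ 2**128 ~ 3.4e38
fp32 = max_unbiased_exponent(8)  # same 8-bit exponent as bf16
print(fp16, bf16, fp32)  # prints: 15 127 127
```

This is why GradScaler exists for fp16 (tiny gradients underflow its narrow range) and why bf16 can skip it: bf16 trades mantissa precision for float32's full exponent range.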
