Commit b53b10e
Fix Train LoRA crash when training_dtype is "none" with bfloat16 LoRA weights (Comfy-Org#13145)
When training_dtype is set to "none" and the model's native dtype is
float16, GradScaler was unconditionally enabled. However, GradScaler
does not support bfloat16 gradients (only float16/float32), causing a
NotImplementedError when lora_dtype is "bf16" (the default).
Fix by enabling GradScaler only when the LoRA parameters are not in
bfloat16: bfloat16 has the same exponent range as float32, so its
gradients do not need scaling to avoid underflow.
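The fixed condition can be sketched as a small predicate. This is an illustrative reconstruction, not the actual ComfyUI code: the function name, parameter names, and the string values ("none", "fp16", "bf16") are assumptions based on the commit message.

```python
def grad_scaler_enabled(native_dtype: str, training_dtype: str, lora_dtype: str) -> bool:
    """Hypothetical sketch of the fixed GradScaler condition.

    Before the fix: with training_dtype "none" and a float16 native model,
    GradScaler was enabled unconditionally, crashing with
    NotImplementedError when the LoRA weights were bfloat16.

    After the fix: GradScaler is also skipped for bfloat16 LoRA weights,
    since GradScaler cannot scale bf16 gradients and bf16 (with float32's
    exponent range) does not underflow anyway.
    """
    if training_dtype != "none":
        # An explicit training dtype takes a different path (not modeled here).
        return False
    return native_dtype == "fp16" and lora_dtype != "bf16"
```

With the fix, the bf16-default configuration no longer enables the scaler, while float16-everywhere training still gets underflow protection.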
Fixes Comfy-Org#13124 (parent commit: 7d5534d)
1 file changed: +5 −2 lines changed