Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions app/user_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import glob
import shutil
import logging
import tempfile
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
Expand Down Expand Up @@ -377,8 +378,15 @@ async def post_userdata(request):
try:
body = await request.read()

with open(path, "wb") as f:
f.write(body)
dir_name = os.path.dirname(path)
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
try:
with os.fdopen(fd, "wb") as f:
f.write(body)
os.replace(tmp_path, path)
except:
os.unlink(tmp_path)
raise
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
return web.Response(
Expand Down
2 changes: 2 additions & 0 deletions comfy/ldm/hunyuan3dv2_1/hunyuandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def forward(self, x, y):
k.reshape(b, s2, self.num_heads * self.head_dim),
v,
heads=self.num_heads,
low_precision_attention=False,
)

out = self.out_proj(x)
Expand Down Expand Up @@ -412,6 +413,7 @@ def forward(self, x):
key.reshape(B, N, self.num_heads * self.head_dim),
value,
heads=self.num_heads,
low_precision_attention=False,
)

x = self.out_proj(x)
Expand Down
101 changes: 98 additions & 3 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,71 @@ def forward(self, *args, **kwargs):
)


class QuantLinearFunc(torch.autograd.Function):
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
Handles any input rank by flattening to 2D for matmul and restoring shape after.
"""

@staticmethod
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
input_shape = input_float.shape
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D

# Quantize input (same as inference path)
if layout_type is not None:
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
else:
q_input = inp

w = weight.detach() if weight.requires_grad else weight
b = bias.detach() if bias is not None and bias.requires_grad else bias

output = torch.nn.functional.linear(q_input, w, b)

# Restore original input shape
if len(input_shape) > 2:
output = output.unflatten(0, input_shape[:-1])

ctx.save_for_backward(input_float, weight)
ctx.input_shape = input_shape
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad

return output

@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
input_float, weight = ctx.saved_tensors
compute_dtype = ctx.compute_dtype
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)

# Dequantize weight to compute dtype for backward matmul
if isinstance(weight, QuantizedTensor):
weight_f = weight.dequantize().to(compute_dtype)
else:
weight_f = weight.to(compute_dtype)

# grad_input = grad_output @ weight
grad_input = torch.mm(grad_2d, weight_f)
if len(ctx.input_shape) > 2:
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])

# grad_weight (only if weight requires grad, typically frozen for quantized training)
grad_weight = None
if ctx.weight_requires_grad:
input_f = input_float.flatten(0, -2).to(compute_dtype)
grad_weight = torch.mm(grad_2d.t(), input_f)

# grad_bias
grad_bias = None
if ctx.has_bias:
grad_bias = grad_2d.sum(dim=0)

return grad_input, grad_weight, grad_bias, None, None, None


def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
Expand Down Expand Up @@ -970,10 +1035,37 @@ def forward(self, input, *args, **kwargs):
#If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype

if (getattr(self, 'layout_type', None) is not None and
_use_quantized = (
getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0):
len(self.weight_function) == 0 and len(self.bias_function) == 0
)

# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):

weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=True
)

scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)

output = QuantLinearFunc.apply(
input, weight, bias, self.layout_type, scale, compute_dtype
)

uncast_bias_weight(self, weight, bias, offload_stream)
return output

# Inference path (unchanged)
if _use_quantized:

# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
Expand Down Expand Up @@ -1021,7 +1113,10 @@ def _apply(self, fn, recurse=True): # This is to get torch.compile + moving wei
for key, param in self._parameters.items():
if param is None:
continue
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
p = fn(param)
if p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
Expand Down
4 changes: 4 additions & 0 deletions comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,10 @@ def set_attr(obj, attr, value):
return prev

def set_attr_param(obj, attr, value):
# Clone inference tensors (created under torch.inference_mode) since
# their version counter is frozen and nn.Parameter() cannot wrap them.
if (not torch.is_inference_mode_enabled()) and value.is_inference():
value = value.clone()
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))

def set_attr_buffer(obj, attr, value):
Expand Down
68 changes: 48 additions & 20 deletions comfy_extras/nodes_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import comfy.sd
import comfy.utils
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
Expand Down Expand Up @@ -138,6 +139,7 @@ def __init__(
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
use_grad_scaler=False,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
Expand All @@ -152,6 +154,8 @@ def __init__(
self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi)
)
# GradScaler for fp16 training
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
Expand Down Expand Up @@ -204,10 +208,13 @@ def fwd_bwd(
batch_sigmas.requires_grad_(True),
**batch_extra_args,
)
loss = self.loss_fn(x0_pred, x0)
loss = self.loss_fn(x0_pred.float(), x0.float())
if bwd:
bwd_loss = loss / self.grad_acc
bwd_loss.backward()
if self.grad_scaler is not None:
self.grad_scaler.scale(bwd_loss).backward()
else:
bwd_loss.backward()
return loss

def _generate_batch_sigmas(self, model_wrap, batch_size, device):
Expand Down Expand Up @@ -307,7 +314,10 @@ def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, late
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.grad_scaler is not None:
self.grad_scaler.scale(total_loss).backward()
else:
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
Expand Down Expand Up @@ -348,12 +358,18 @@ def sample(
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)

if (i + 1) % self.grad_acc == 0:
if self.grad_scaler is not None:
self.grad_scaler.unscale_(self.optimizer)
for param_groups in self.optimizer.param_groups:
for param in param_groups["params"]:
if param.grad is None:
continue
param.grad.data = param.grad.data.to(param.data.dtype)
self.optimizer.step()
if self.grad_scaler is not None:
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
torch.cuda.empty_cache()
Expand Down Expand Up @@ -1004,9 +1020,9 @@ def define_schema(cls):
),
io.Combo.Input(
"training_dtype",
options=["bf16", "fp32"],
options=["bf16", "fp32", "none"],
default="bf16",
tooltip="The dtype to use for training.",
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
),
io.Combo.Input(
"lora_dtype",
Expand Down Expand Up @@ -1035,7 +1051,7 @@ def define_schema(cls):
io.Boolean.Input(
"offloading",
default=False,
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
tooltip="Offload model weights to CPU during training to save GPU memory.",
),
io.Combo.Input(
"existing_lora",
Expand Down Expand Up @@ -1120,22 +1136,32 @@ def execute(

# Setup model and dtype
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
use_grad_scaler = False
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)

if mp.is_dynamic():
if not bypass_mode:
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
bypass_mode = True
offloading = True
elif offloading:
if not bypass_mode:
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")

# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
latents, latents_dtype, bucket_mode
)

# Validate and expand conditioning
Expand Down Expand Up @@ -1201,6 +1227,7 @@ def loss_callback(loss):
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
use_grad_scaler=use_grad_scaler,
)
else:
train_sampler = TrainSampler(
Expand All @@ -1213,6 +1240,7 @@ def loss_callback(loss):
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
use_grad_scaler=use_grad_scaler,
)

# Setup guider
Expand Down Expand Up @@ -1337,7 +1365,7 @@ def define_schema(cls):
io.Int.Input(
"steps",
optional=True,
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
),
],
outputs=[],
Expand Down
Loading