Skip to content

Commit e729a5c

Browse files
authored
Separate denoised and noise estimation in Euler CFG++ (Comfy-Org#9008)
This will change their behavior with the sampling CONST type. It also combines euler_cfg_pp and euler_ancestral_cfg_pp into one main function.
1 parent e78d230 commit e729a5c

File tree

1 file changed

+32
-32
lines changed

1 file changed

+32
-32
lines changed

comfy/k_diffusion/sampling.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,39 +1210,21 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
12101210
return x_next
12111211

12121212

1213-
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Euler method steps using the CFG++ update rule.

    A post-CFG hook captures the unconditional denoised prediction each
    step; the step direction `d` is computed toward that unconditional
    estimate, while the destination is anchored at the CFG-combined
    `denoised` output.
    """
    if extra_args is None:
        extra_args = {}

    # Single-element list acts as a mutable cell so the closure below can
    # hand the unconditional prediction back to the loop.
    uncond_holder = [0]

    def post_cfg_function(args):
        uncond_holder[0] = args["uncond_denoised"]
        return args["denoised"]

    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    s_in = x.new_ones([x.shape[0]])
    for step in trange(len(sigmas) - 1, disable=disable):
        sigma_hat = sigmas[step]
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, uncond_holder[0])
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigmas[step], 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Euler method
        x = denoised + d * sigmas[step + 1]
    return x
12361213
@torch.no_grad()
12371214
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
1238-
"""Ancestral sampling with Euler method steps."""
1215+
"""Ancestral sampling with Euler method steps (CFG++)."""
12391216
extra_args = {} if extra_args is None else extra_args
12401217
seed = extra_args.get("seed", None)
12411218
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
12421219

1243-
temp = [0]
1220+
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
1221+
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
1222+
1223+
uncond_denoised = None
1224+
12441225
def post_cfg_function(args):
1245-
temp[0] = args["uncond_denoised"]
1226+
nonlocal uncond_denoised
1227+
uncond_denoised = args["uncond_denoised"]
12461228
return args["denoised"]
12471229

12481230
model_options = extra_args.get("model_options", {}).copy()
@@ -1251,15 +1233,33 @@ def post_cfg_function(args):
12511233
s_in = x.new_ones([x.shape[0]])
12521234
for i in trange(len(sigmas) - 1, disable=disable):
12531235
denoised = model(x, sigmas[i] * s_in, **extra_args)
1254-
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
12551236
if callback is not None:
12561237
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1257-
d = to_d(x, sigmas[i], temp[0])
1258-
# Euler method
1259-
x = denoised + d * sigma_down
1260-
if sigmas[i + 1] > 0:
1261-
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
1238+
if sigmas[i + 1] == 0:
1239+
# Denoising step
1240+
x = denoised
1241+
else:
1242+
alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
1243+
alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
1244+
d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
1245+
1246+
# DDIM stochastic sampling
1247+
sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
1248+
sigma_down = alpha_t * sigma_down
1249+
1250+
# Euler method
1251+
x = alpha_t * denoised + sigma_down * d
1252+
if eta > 0 and s_noise > 0:
1253+
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
12621254
return x
1255+
1256+
1257+
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Euler method steps (CFG++).

    Deterministic variant of the ancestral CFG++ sampler: delegates with
    eta=0.0 and s_noise=0.0 so no stochastic noise is ever injected.
    """
    return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
1261+
1262+
12631263
@torch.no_grad()
12641264
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
12651265
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""

0 commit comments

Comments (0)