Skip to content

Commit c15909b

Browse files
authored
CFG++ for gradient estimation sampler (Comfy-Org#7809)
1 parent 772b4c5 commit c15909b

2 files changed

Lines changed: 30 additions & 6 deletions

File tree

comfy/k_diffusion/sampling.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,28 +1345,52 @@ def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, cal
13451345
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
13461346

13471347
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
    """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK

    Args:
        model: denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
        x: initial noisy latent tensor.
        sigmas: 1-D sequence of noise levels; one step is taken per adjacent pair.
        extra_args: extra keyword arguments forwarded to the model (default: ``{}``).
        callback: optional per-step callable receiving a dict with keys
            ``'x'``, ``'i'``, ``'sigma'``, ``'sigma_hat'``, ``'denoised'``.
        disable: forwarded to ``trange`` to disable the progress bar.
        ge_gamma: gradient-estimation blending coefficient (the paper's gamma).
        cfg_pp: when True, use the CFG++ formulation — the derivative is taken
            from the unconditional prediction (captured via a post-CFG hook) and
            each update re-anchors on the conditional ``denoised`` estimate.

    Returns:
        The sample after stepping through all of ``sigmas``.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    old_d = None  # previous step's derivative, consumed by the gradient-estimation update

    uncond_denoised = None

    def post_cfg_function(args):
        # Capture the unconditional prediction on every model call;
        # pass the CFG-combined result through unchanged.
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    if cfg_pp:
        # Copy model_options so the caller's dict is not mutated when installing the hook.
        model_options = extra_args.get("model_options", {}).copy()
        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if cfg_pp:
            # CFG++: derivative comes from the unconditional prediction.
            d = to_d(x, sigmas[i], uncond_denoised)
        else:
            d = to_d(x, sigmas[i], denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        dt = sigmas[i + 1] - sigmas[i]
        if i == 0:
            # Euler method
            if cfg_pp:
                # CFG++ step: re-anchor on the conditional denoised estimate.
                x = denoised + d * sigmas[i + 1]
            else:
                x = x + d * dt
        else:
            # Gradient estimation
            if cfg_pp:
                # Same gamma-blend of current/previous derivatives, expressed as a
                # correction on top of the CFG++ Euler step above.
                d_bar = (ge_gamma - 1) * (d - old_d)
                x = denoised + d * sigmas[i + 1] + d_bar * dt
            else:
                d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
                x = x + d_bar * dt
        old_d = d
    return x
13691389

1390+
@torch.no_grad()
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
    """CFG++ variant of the gradient-estimation sampler; delegates with cfg_pp=True."""
    return sample_gradient_estimation(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        ge_gamma=ge_gamma,
        cfg_pp=True,
    )
1393+
13701394
@torch.no_grad()
13711395
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
13721396
"""

comfy/samplers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ def max_denoise(self, model_wrap, sigmas):
710710
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
711711
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
712712
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
713-
"gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
713+
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
714714

715715
class KSAMPLER(Sampler):
716716
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

0 commit comments

Comments
 (0)