@@ -1345,28 +1345,52 @@ def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, cal
13451345 return res_multistep (model , x , sigmas , extra_args = extra_args , callback = callback , disable = disable , s_noise = s_noise , noise_sampler = noise_sampler , eta = eta , cfg_pp = True )
13461346
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
    """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK

    Takes Euler steps whose direction is corrected by extrapolating the two
    most recent denoising directions, weighted by ``ge_gamma``.  When
    ``cfg_pp`` is set, the step direction comes from the unconditional
    prediction captured through a post-CFG hook (CFG++-style update).
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])

    uncond_denoised = None

    def post_cfg_function(args):
        # Capture the unconditional prediction; pass the CFG result through unchanged.
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    if cfg_pp:
        # Install the hook on a copy of model_options so the caller's dict value is not altered.
        model_options = extra_args.get("model_options", {}).copy()
        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    prev_d = None
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # CFG++ takes its direction from the unconditional prediction instead.
        d = to_d(x, sigmas[i], uncond_denoised if cfg_pp else denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        dt = sigmas[i + 1] - sigmas[i]
        if prev_d is None:
            # First step: plain Euler, no history to extrapolate from.
            if cfg_pp:
                x = denoised + d * sigmas[i + 1]
            else:
                x = x + d * dt
        elif cfg_pp:
            # Gradient-estimation correction on top of the CFG++ Euler update.
            d_bar = (ge_gamma - 1) * (d - prev_d)
            x = denoised + d * sigmas[i + 1] + d_bar * dt
        else:
            # Extrapolate current and previous directions per the GE paper.
            d_bar = ge_gamma * d + (1 - ge_gamma) * prev_d
            x = x + d_bar * dt
        prev_d = d
    return x
13691389
@torch.no_grad()
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
    """CFG++ variant of the gradient-estimation sampler (delegates with cfg_pp=True)."""
    return sample_gradient_estimation(
        model, x, sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        ge_gamma=ge_gamma,
        cfg_pp=True,
    )
1393+
13701394@torch .no_grad ()
13711395def sample_er_sde (model , x , sigmas , extra_args = None , callback = None , disable = None , s_noise = 1. , noise_sampler = None , noise_scaler = None , max_stage = 3 ):
13721396 """
0 commit comments