@@ -1210,39 +1210,21 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
12101210 return x_next
12111211
12121212
1213- @torch .no_grad ()
1214- def sample_euler_cfg_pp (model , x , sigmas , extra_args = None , callback = None , disable = None ):
1215- extra_args = {} if extra_args is None else extra_args
1216-
1217- temp = [0 ]
1218- def post_cfg_function (args ):
1219- temp [0 ] = args ["uncond_denoised" ]
1220- return args ["denoised" ]
1221-
1222- model_options = extra_args .get ("model_options" , {}).copy ()
1223- extra_args ["model_options" ] = comfy .model_patcher .set_model_options_post_cfg_function (model_options , post_cfg_function , disable_cfg1_optimization = True )
1224-
1225- s_in = x .new_ones ([x .shape [0 ]])
1226- for i in trange (len (sigmas ) - 1 , disable = disable ):
1227- sigma_hat = sigmas [i ]
1228- denoised = model (x , sigma_hat * s_in , ** extra_args )
1229- d = to_d (x , sigma_hat , temp [0 ])
1230- if callback is not None :
1231- callback ({'x' : x , 'i' : i , 'sigma' : sigmas [i ], 'sigma_hat' : sigma_hat , 'denoised' : denoised })
1232- # Euler method
1233- x = denoised + d * sigmas [i + 1 ]
1234- return x
1235-
12361213@torch .no_grad ()
12371214def sample_euler_ancestral_cfg_pp (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None ):
1238- """Ancestral sampling with Euler method steps."""
1215+ """Ancestral sampling with Euler method steps (CFG++) ."""
12391216 extra_args = {} if extra_args is None else extra_args
12401217 seed = extra_args .get ("seed" , None )
12411218 noise_sampler = default_noise_sampler (x , seed = seed ) if noise_sampler is None else noise_sampler
12421219
1243- temp = [0 ]
1220+ model_sampling = model .inner_model .model_patcher .get_model_object ("model_sampling" )
1221+ lambda_fn = partial (sigma_to_half_log_snr , model_sampling = model_sampling )
1222+
1223+ uncond_denoised = None
1224+
12441225 def post_cfg_function (args ):
1245- temp [0 ] = args ["uncond_denoised" ]
1226+ nonlocal uncond_denoised
1227+ uncond_denoised = args ["uncond_denoised" ]
12461228 return args ["denoised" ]
12471229
12481230 model_options = extra_args .get ("model_options" , {}).copy ()
@@ -1251,15 +1233,33 @@ def post_cfg_function(args):
12511233 s_in = x .new_ones ([x .shape [0 ]])
12521234 for i in trange (len (sigmas ) - 1 , disable = disable ):
12531235 denoised = model (x , sigmas [i ] * s_in , ** extra_args )
1254- sigma_down , sigma_up = get_ancestral_step (sigmas [i ], sigmas [i + 1 ], eta = eta )
12551236 if callback is not None :
12561237 callback ({'x' : x , 'i' : i , 'sigma' : sigmas [i ], 'sigma_hat' : sigmas [i ], 'denoised' : denoised })
1257- d = to_d (x , sigmas [i ], temp [0 ])
1258- # Euler method
1259- x = denoised + d * sigma_down
1260- if sigmas [i + 1 ] > 0 :
1261- x = x + noise_sampler (sigmas [i ], sigmas [i + 1 ]) * s_noise * sigma_up
1238+ if sigmas [i + 1 ] == 0 :
1239+ # Denoising step
1240+ x = denoised
1241+ else :
1242+ alpha_s = sigmas [i ] * lambda_fn (sigmas [i ]).exp ()
1243+ alpha_t = sigmas [i + 1 ] * lambda_fn (sigmas [i + 1 ]).exp ()
1244+ d = to_d (x , sigmas [i ], alpha_s * uncond_denoised ) # to noise
1245+
1246+ # DDIM stochastic sampling
1247+ sigma_down , sigma_up = get_ancestral_step (sigmas [i ] / alpha_s , sigmas [i + 1 ] / alpha_t , eta = eta )
1248+ sigma_down = alpha_t * sigma_down
1249+
1250+ # Euler method
1251+ x = alpha_t * denoised + sigma_down * d
1252+ if eta > 0 and s_noise > 0 :
1253+ x = x + alpha_t * noise_sampler (sigmas [i ], sigmas [i + 1 ]) * s_noise * sigma_up
12621254 return x
1255+
1256+
1257+ @torch .no_grad ()
1258+ def sample_euler_cfg_pp (model , x , sigmas , extra_args = None , callback = None , disable = None ):
1259+ """Euler method steps (CFG++)."""
1260+ return sample_euler_ancestral_cfg_pp (model , x , sigmas , extra_args = extra_args , callback = callback , disable = disable , eta = 0.0 , s_noise = 0.0 , noise_sampler = None )
1261+
1262+
12631263@torch .no_grad ()
12641264def sample_dpmpp_2s_ancestral_cfg_pp (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None ):
12651265 """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
0 commit comments