diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 2c0d1832..45667998 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1267,7 +1267,7 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
     return x
 
 @torch.no_grad()
-def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfg_pp=False):
+def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False):
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1289,53 +1289,60 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
         extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
 
     for i in trange(len(sigmas) - 1, disable=disable):
-        if s_churn > 0:
-            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
-            sigma_hat = sigmas[i] * (gamma + 1)
-        else:
-            gamma = 0
-            sigma_hat = sigmas[i]
-
-        if gamma > 0:
-            eps = torch.randn_like(x) * s_noise
-            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
-        denoised = model(x, sigma_hat * s_in, **extra_args)
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
         if callback is not None:
-            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
-        if sigmas[i + 1] == 0 or old_denoised is None:
+            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
+        if sigma_down == 0 or old_denoised is None:
             # Euler method
             if cfg_pp:
-                d = to_d(x, sigma_hat, uncond_denoised)
-                x = denoised + d * sigmas[i + 1]
+                d = to_d(x, sigmas[i], uncond_denoised)
+                x = denoised + d * sigma_down
             else:
-                d = to_d(x, sigma_hat, denoised)
-                dt = sigmas[i + 1] - sigma_hat
+                d = to_d(x, sigmas[i], denoised)
+                dt = sigma_down - sigmas[i]
                 x = x + d * dt
         else:
             # Second order multistep method in https://arxiv.org/pdf/2308.02157
-            t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
+            t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
             h = t_next - t
             c2 = (t_prev - t) / h
 
             phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
-            b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
-            b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
+            b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
+            b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)
 
             if cfg_pp:
                 x = x + (denoised - uncond_denoised)
+                x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
+            else:
+                x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)
 
-            x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
+        # Noise addition
+        if sigmas[i + 1] > 0:
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
 
-        old_denoised = denoised
+        if cfg_pp:
+            old_denoised = uncond_denoised
+        else:
+            old_denoised = denoised
     return x
 
 @torch.no_grad()
-def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
-    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=False)
+def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False)
 
 @torch.no_grad()
-def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
-    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)
+def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True)
+
+@torch.no_grad()
+def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False)
+
+@torch.no_grad()
+def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
 
 @torch.no_grad()
 def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 3b66091e..a1b4787e 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -686,7 +686,8 @@ class Sampler:
 KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
-                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "gradient_estimation"]
+                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
+                  "gradient_estimation"]
 
 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
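Note on the new `eta` parameter: the patch replaces the old churn-based noise injection with `get_ancestral_step`, which splits each transition `sigmas[i] -> sigmas[i + 1]` into a deterministic target `sigma_down` and a noise re-injection amount `sigma_up`. Below is a minimal standalone sketch of that split, using plain floats instead of tensors (the function name `ancestral_split_sketch` is illustrative, not part of the patch), to show why the `eta=0.` wrappers preserve the previous deterministic behavior:

```python
def ancestral_split_sketch(sigma_from: float, sigma_to: float, eta: float = 1.0):
    """Sketch of k-diffusion's get_ancestral_step for scalar sigmas.

    Returns (sigma_down, sigma_up) with sigma_down**2 + sigma_up**2 == sigma_to**2,
    so stepping down to sigma_down and re-adding sigma_up of fresh noise lands the
    sample back at noise level sigma_to.
    """
    if eta == 0.0:
        # Fully deterministic: step straight to sigma_to, add no noise.
        return sigma_to, 0.0
    sigma_up = min(sigma_to, eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up

# Example step from sigma 2.0 to 1.0:
#   eta=0.0 -> (1.0, 0.0)      deterministic, matches the pre-patch samplers
#   eta=1.0 -> (0.5, ~0.866)   overshoot down to 0.5, then add noise back up to 1.0
```

With `eta=0.`, `sigma_down == sigmas[i + 1]` and `sigma_up == 0`, so the "Noise addition" branch is a no-op and `sample_res_multistep` / `sample_res_multistep_cfg_pp` keep their deterministic trajectories, while the new `_ancestral` entry points expose `eta` for stochastic sampling.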