mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-02-28 22:51:45 +00:00
res_multistep: Fix cfgpp and add ancestral samplers (#6731)
This commit is contained in:
parent
3d06e1c555
commit
caeb27c3a5
@ -1267,7 +1267,7 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfg_pp=False):
|
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False):
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
@ -1289,53 +1289,60 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
|
|||||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
if s_churn > 0:
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
|
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||||
sigma_hat = sigmas[i] * (gamma + 1)
|
|
||||||
else:
|
|
||||||
gamma = 0
|
|
||||||
sigma_hat = sigmas[i]
|
|
||||||
|
|
||||||
if gamma > 0:
|
|
||||||
eps = torch.randn_like(x) * s_noise
|
|
||||||
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
|
|
||||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
|
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||||
if sigmas[i + 1] == 0 or old_denoised is None:
|
if sigma_down == 0 or old_denoised is None:
|
||||||
# Euler method
|
# Euler method
|
||||||
if cfg_pp:
|
if cfg_pp:
|
||||||
d = to_d(x, sigma_hat, uncond_denoised)
|
d = to_d(x, sigmas[i], uncond_denoised)
|
||||||
x = denoised + d * sigmas[i + 1]
|
x = denoised + d * sigma_down
|
||||||
else:
|
else:
|
||||||
d = to_d(x, sigma_hat, denoised)
|
d = to_d(x, sigmas[i], denoised)
|
||||||
dt = sigmas[i + 1] - sigma_hat
|
dt = sigma_down - sigmas[i]
|
||||||
x = x + d * dt
|
x = x + d * dt
|
||||||
else:
|
else:
|
||||||
# Second order multistep method in https://arxiv.org/pdf/2308.02157
|
# Second order multistep method in https://arxiv.org/pdf/2308.02157
|
||||||
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
|
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
|
||||||
h = t_next - t
|
h = t_next - t
|
||||||
c2 = (t_prev - t) / h
|
c2 = (t_prev - t) / h
|
||||||
|
|
||||||
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
|
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
|
||||||
b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
|
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
|
||||||
b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
|
b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)
|
||||||
|
|
||||||
if cfg_pp:
|
if cfg_pp:
|
||||||
x = x + (denoised - uncond_denoised)
|
x = x + (denoised - uncond_denoised)
|
||||||
|
x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
|
||||||
|
else:
|
||||||
|
x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)
|
||||||
|
|
||||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
|
# Noise addition
|
||||||
|
if sigmas[i + 1] > 0:
|
||||||
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
|
|
||||||
|
if cfg_pp:
|
||||||
|
old_denoised = uncond_denoised
|
||||||
|
else:
|
||||||
old_denoised = denoised
|
old_denoised = denoised
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
|
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
|
||||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=False)
|
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
|
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
|
||||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)
|
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||||
|
@ -686,7 +686,8 @@ class Sampler:
|
|||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "gradient_estimation"]
|
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||||
|
"gradient_estimation"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
|
Loading…
Reference in New Issue
Block a user