mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-02 21:37:15 +08:00
Multistep DPM++ SDE samplers for RF (#8541)
Include alpha in sampling and minor refactoring
This commit is contained in:
parent
e1c6dc720e
commit
8e81c507d2
@ -781,6 +781,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
old_denoised = denoised
|
old_denoised = denoised
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||||
"""DPM-Solver++(2M) SDE."""
|
"""DPM-Solver++(2M) SDE."""
|
||||||
@ -796,9 +797,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||||
|
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||||
|
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||||
|
|
||||||
old_denoised = None
|
old_denoised = None
|
||||||
h_last = None
|
h, h_last = None, None
|
||||||
h = None
|
|
||||||
|
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
@ -809,26 +813,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
x = denoised
|
x = denoised
|
||||||
else:
|
else:
|
||||||
# DPM-Solver++(2M) SDE
|
# DPM-Solver++(2M) SDE
|
||||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||||
h = s - t
|
h = lambda_t - lambda_s
|
||||||
eta_h = eta * h
|
h_eta = h * (eta + 1)
|
||||||
|
|
||||||
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||||
|
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
|
||||||
|
|
||||||
if old_denoised is not None:
|
if old_denoised is not None:
|
||||||
r = h_last / h
|
r = h_last / h
|
||||||
if solver_type == 'heun':
|
if solver_type == 'heun':
|
||||||
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
|
||||||
elif solver_type == 'midpoint':
|
elif solver_type == 'midpoint':
|
||||||
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||||
|
|
||||||
if eta:
|
if eta > 0 and s_noise > 0:
|
||||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
||||||
|
|
||||||
old_denoised = denoised
|
old_denoised = denoised
|
||||||
h_last = h
|
h_last = h
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
"""DPM-Solver++(3M) SDE."""
|
"""DPM-Solver++(3M) SDE."""
|
||||||
@ -842,6 +849,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||||
|
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||||
|
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||||
|
|
||||||
denoised_1, denoised_2 = None, None
|
denoised_1, denoised_2 = None, None
|
||||||
h, h_1, h_2 = None, None, None
|
h, h_1, h_2 = None, None, None
|
||||||
|
|
||||||
@ -853,13 +864,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
# Denoising step
|
# Denoising step
|
||||||
x = denoised
|
x = denoised
|
||||||
else:
|
else:
|
||||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||||
h = s - t
|
h = lambda_t - lambda_s
|
||||||
h_eta = h * (eta + 1)
|
h_eta = h * (eta + 1)
|
||||||
|
|
||||||
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
|
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||||
|
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
|
||||||
|
|
||||||
if h_2 is not None:
|
if h_2 is not None:
|
||||||
|
# DPM-Solver++(3M) SDE
|
||||||
r0 = h_1 / h
|
r0 = h_1 / h
|
||||||
r1 = h_2 / h
|
r1 = h_2 / h
|
||||||
d1_0 = (denoised - denoised_1) / r0
|
d1_0 = (denoised - denoised_1) / r0
|
||||||
@ -868,20 +882,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
d2 = (d1_0 - d1_1) / (r0 + r1)
|
d2 = (d1_0 - d1_1) / (r0 + r1)
|
||||||
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||||
phi_3 = phi_2 / h_eta - 0.5
|
phi_3 = phi_2 / h_eta - 0.5
|
||||||
x = x + phi_2 * d1 - phi_3 * d2
|
x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
|
||||||
elif h_1 is not None:
|
elif h_1 is not None:
|
||||||
|
# DPM-Solver++(2M) SDE
|
||||||
r = h_1 / h
|
r = h_1 / h
|
||||||
d = (denoised - denoised_1) / r
|
d = (denoised - denoised_1) / r
|
||||||
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||||
x = x + phi_2 * d
|
x = x + (alpha_t * phi_2) * d
|
||||||
|
|
||||||
if eta:
|
if eta > 0 and s_noise > 0:
|
||||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
||||||
|
|
||||||
denoised_1, denoised_2 = denoised, denoised_1
|
denoised_1, denoised_2 = denoised, denoised_1
|
||||||
h_1, h_2 = h, h_1
|
h_1, h_2 = h, h_1
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
if len(sigmas) <= 1:
|
if len(sigmas) <= 1:
|
||||||
@ -891,6 +907,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
|||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||||
if len(sigmas) <= 1:
|
if len(sigmas) <= 1:
|
||||||
@ -900,6 +917,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
|||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||||
if len(sigmas) <= 1:
|
if len(sigmas) <= 1:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user