Multistep DPM++ SDE samplers for RF (#8541)

Include alpha in sampling and minor refactoring
This commit is contained in:
chaObserv 2025-06-17 02:47:10 +08:00 committed by GitHub
parent e1c6dc720e
commit 8e81c507d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -781,6 +781,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
old_denoised = denoised old_denoised = denoised
return x return x
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE.""" """DPM-Solver++(2M) SDE."""
@ -796,9 +797,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
old_denoised = None old_denoised = None
h_last = None h, h_last = None, None
h = None
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args) denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -809,26 +813,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
x = denoised x = denoised
else: else:
# DPM-Solver++(2M) SDE # DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log() lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = s - t h = lambda_t - lambda_s
eta_h = eta * h h_eta = h * (eta + 1)
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if old_denoised is not None: if old_denoised is not None:
r = h_last / h r = h_last / h
if solver_type == 'heun': if solver_type == 'heun':
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint': elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
if eta: if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
old_denoised = denoised old_denoised = denoised
h_last = h h_last = h
return x return x
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE.""" """DPM-Solver++(3M) SDE."""
@ -842,6 +849,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
denoised_1, denoised_2 = None, None denoised_1, denoised_2 = None, None
h, h_1, h_2 = None, None, None h, h_1, h_2 = None, None, None
@ -853,13 +864,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
# Denoising step # Denoising step
x = denoised x = denoised
else: else:
t, s = -sigmas[i].log(), -sigmas[i + 1].log() lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = s - t h = lambda_t - lambda_s
h_eta = h * (eta + 1) h_eta = h * (eta + 1)
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if h_2 is not None: if h_2 is not None:
# DPM-Solver++(3M) SDE
r0 = h_1 / h r0 = h_1 / h
r1 = h_2 / h r1 = h_2 / h
d1_0 = (denoised - denoised_1) / r0 d1_0 = (denoised - denoised_1) / r0
@ -868,20 +882,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
d2 = (d1_0 - d1_1) / (r0 + r1) d2 = (d1_0 - d1_1) / (r0 + r1)
phi_2 = h_eta.neg().expm1() / h_eta + 1 phi_2 = h_eta.neg().expm1() / h_eta + 1
phi_3 = phi_2 / h_eta - 0.5 phi_3 = phi_2 / h_eta - 0.5
x = x + phi_2 * d1 - phi_3 * d2 x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
elif h_1 is not None: elif h_1 is not None:
# DPM-Solver++(2M) SDE
r = h_1 / h r = h_1 / h
d = (denoised - denoised_1) / r d = (denoised - denoised_1) / r
phi_2 = h_eta.neg().expm1() / h_eta + 1 phi_2 = h_eta.neg().expm1() / h_eta + 1
x = x + phi_2 * d x = x + (alpha_t * phi_2) * d
if eta: if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
denoised_1, denoised_2 = denoised, denoised_1 denoised_1, denoised_2 = denoised, denoised_1
h_1, h_2 = h, h_1 h_1, h_2 = h, h_1
return x return x
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1: if len(sigmas) <= 1:
@ -891,6 +907,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1: if len(sigmas) <= 1:
@ -900,6 +917,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1: if len(sigmas) <= 1: