From 8e81c507d2124b6c04993e3dbe3df9f40573f814 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Tue, 17 Jun 2025 02:47:10 +0800 Subject: [PATCH] Multistep DPM++ SDE samplers for RF (#8541) Include alpha in sampling and minor refactoring --- comfy/k_diffusion/sampling.py | 50 ++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index a8fd9849..8030048f 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -781,6 +781,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No old_denoised = denoised return x + @torch.no_grad() def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): """DPM-Solver++(2M) SDE.""" @@ -796,9 +797,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + old_denoised = None - h_last = None - h = None + h, h_last = None, None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -809,26 +813,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl x = denoised else: # DPM-Solver++(2M) SDE - t, s = -sigmas[i].log(), -sigmas[i + 1].log() - h = s - t - eta_h = eta * h + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) - x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + alpha_t = sigmas[i + 1] * lambda_t.exp() + + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if old_denoised is not None: r = h_last / h if solver_type == 'heun': - x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised) elif solver_type == 'midpoint': - x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised) - if eta: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + if eta > 0 and s_noise > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise old_denoised = denoised h_last = h return x + @torch.no_grad() def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" @@ -842,6 +849,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + denoised_1, denoised_2 = None, None h, h_1, h_2 = None, None, None @@ -853,13 +864,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl # Denoising step x = denoised else: - t, s = -sigmas[i].log(), -sigmas[i + 1].log() - h = s - t + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s h_eta = h * (eta + 1) - x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised + alpha_t = sigmas[i + 1] * lambda_t.exp() + + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if h_2 is not None: + # DPM-Solver++(3M) SDE r0 = h_1 / h r1 = h_2 / h d1_0 = (denoised - denoised_1) / r0 @@ -868,20 +882,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl d2 = (d1_0 - d1_1) / (r0 + r1) phi_2 = h_eta.neg().expm1() / h_eta + 1 phi_3 = phi_2 / h_eta - 0.5 - x = x + phi_2 * d1 - phi_3 * d2 + x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2 elif h_1 is not None: + # DPM-Solver++(2M) SDE r = h_1 / h d = (denoised - denoised_1) / r phi_2 = h_eta.neg().expm1() / h_eta + 1 - x = x + phi_2 * d + x = x + (alpha_t * phi_2) * d - if eta: + if eta > 0 and s_noise > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise denoised_1, denoised_2 = denoised, denoised_1 h_1, h_2 = h, h_1 return x + @torch.no_grad() def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): if len(sigmas) <= 1: @@ -891,6 +907,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) + @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: @@ -900,6 +917,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + @torch.no_grad() def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): if len(sigmas) <= 1: