From 53e8d8193c6b669eefc0a9bfa6e525b97837d876 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Sun, 15 Jun 2025 04:58:16 +0800 Subject: [PATCH] Generalize SEEDS samplers (#8529) Restore VP algorithm for RF and refactor noise_coeffs and half-logSNR calculations --- comfy/k_diffusion/sampling.py | 169 ++++++++++++++++++++++------------ 1 file changed, 109 insertions(+), 60 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index fbdf6f55..a8fd9849 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,4 +1,5 @@ import math +from functools import partial from scipy import integrate import torch @@ -142,6 +143,33 @@ class BrownianTreeNoiseSampler: return self.tree(t0, t1) / (t1 - t0).abs().sqrt() +def sigma_to_half_log_snr(sigma, model_sampling): + """Convert sigma to half-logSNR log(alpha_t / sigma_t).""" + if isinstance(model_sampling, comfy.model_sampling.CONST): + # log((1 - t) / t) = log((1 - sigma) / sigma) + return sigma.logit().neg() + return sigma.log().neg() + + +def half_log_snr_to_sigma(half_log_snr, model_sampling): + """Convert half-logSNR log(alpha_t / sigma_t) to sigma.""" + if isinstance(model_sampling, comfy.model_sampling.CONST): + # 1 / (1 + exp(half_log_snr)) + return half_log_snr.neg().sigmoid() + return half_log_snr.neg().exp() + + +def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4): + """Adjust the first sigma to avoid invalid logSNR.""" + if len(sigmas) <= 1: + return sigmas + if isinstance(model_sampling, comfy.model_sampling.CONST): + if sigmas[0] >= 1: + sigmas = sigmas.clone() + sigmas[0] = model_sampling.percent_to_sigma(percent_offset) + return sigmas + + @torch.no_grad() def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" @@ -1449,12 +1477,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None old_denoised = denoised return x + @torch.no_grad() def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): - ''' - SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2 - Arxiv: https://arxiv.org/abs/2305.14267 - ''' + """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. + arXiv: https://arxiv.org/abs/2305.14267 + """ extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler @@ -1462,6 +1490,11 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non inject_noise = eta > 0 and s_noise > 0 + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: @@ -1469,80 +1502,96 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non if sigmas[i + 1] == 0: x = denoised else: - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() - h = t_next - t + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s h_eta = h * (eta + 1) - s = t + r * h + lambda_s_1 = lambda_s + r * h fac = 1 / (2 * r) - sigma_s = s.neg().exp() + sigma_s_1 = sigma_fn(lambda_s_1) + + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1() if inject_noise: + # 0 < r < 1 noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt() - noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() - noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1]) + noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt() + noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1]) # Step 1 - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised - if inject_noise: - x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise - denoised_2 = model(x_2, sigma_s * s_in, **extra_args) - - # Step 2 - denoised_d = (1 - fac) * denoised + fac * denoised_2 - x = (coeff_2 + 1) * x - coeff_2 * denoised_d - if inject_noise: - x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise - return x - -@torch.no_grad() -def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): - ''' - SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3 - Arxiv: https://arxiv.org/abs/2305.14267 - ''' - extra_args = {} if extra_args is None else extra_args - seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_in = x.new_ones([x.shape[0]]) - - inject_noise = eta > 0 and s_noise > 0 - - for i in trange(len(sigmas) - 1, disable=disable): - denoised = model(x, sigmas[i] * s_in, **extra_args) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - if sigmas[i + 1] == 0: - x = denoised - else: - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() - h = t_next - t - h_eta = h * (eta + 1) - s_1 = t + r_1 * h - s_2 = t + r_2 * h - sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp() - - coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() - if inject_noise: - noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() - noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt() - noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() - noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) - - # Step 1 - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised + x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised if inject_noise: x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 - x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d + if inject_noise: + x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + return x + + +@torch.no_grad() +def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): + """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. + arXiv: https://arxiv.org/abs/2305.14267 + """ + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + inject_noise = eta > 0 and s_noise > 0 + + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + x = denoised + else: + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) + lambda_s_1 = lambda_s + r_1 * h + lambda_s_2 = lambda_s + r_2 * h + sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) + + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_s_2 = sigma_s_2 * lambda_s_2.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() + + coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() + if inject_noise: + # 0 < r_1 < r_2 < 1 + noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() + noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt() + noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt() + noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) + + # Step 1 + x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised + if inject_noise: + x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + + # Step 2 + x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) if inject_noise: x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) # Step 3 - x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) if inject_noise: x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise return x