From e857dd48b810a36a14dc1a6fa93ec930f4e75ee6 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Wed, 22 Jan 2025 18:29:40 +0800 Subject: [PATCH] Add gradient estimation sampler (#6554) --- comfy/k_diffusion/sampling.py | 23 +++++++++++++++++++++++ comfy/samplers.py | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 87a522b7..2c0d1832 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1336,3 +1336,26 @@ def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disab @torch.no_grad() def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None): return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True) + +@torch.no_grad() +def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.): + """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + old_d = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + d = to_d(x, sigmas[i], denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + dt = sigmas[i + 1] - sigmas[i] + if i == 0: + # Euler method + x = x + d * dt + else: + # Gradient estimation + d_bar = ge_gamma * d + (1 - ge_gamma) * old_d + x = x + d_bar * dt + old_d = d + return x diff --git a/comfy/samplers.py b/comfy/samplers.py index d281ecc1..3b66091e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -686,7 +686,7 @@ class Sampler: KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", - "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp"] + "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "gradient_estimation"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}):