mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
heunpp2 sampler.
This commit is contained in:
parent
728613bb3e
commit
57eea0efbb
@ -750,3 +750,61 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
|
|||||||
if sigmas[i + 1] > 0:
|
if sigmas[i + 1] > 0:
|
||||||
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
|
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||||
|
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
s_end = sigmas[-1]
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||||
|
eps = torch.randn_like(x) * s_noise
|
||||||
|
sigma_hat = sigmas[i] * (gamma + 1)
|
||||||
|
if gamma > 0:
|
||||||
|
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||||
|
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||||
|
d = to_d(x, sigma_hat, denoised)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||||
|
dt = sigmas[i + 1] - sigma_hat
|
||||||
|
if sigmas[i + 1] == s_end:
|
||||||
|
# Euler method
|
||||||
|
x = x + d * dt
|
||||||
|
elif sigmas[i + 2] == s_end:
|
||||||
|
|
||||||
|
# Heun's method
|
||||||
|
x_2 = x + d * dt
|
||||||
|
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
|
||||||
|
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||||
|
|
||||||
|
w = 2 * sigmas[0]
|
||||||
|
w2 = sigmas[i+1]/w
|
||||||
|
w1 = 1 - w2
|
||||||
|
|
||||||
|
d_prime = d * w1 + d_2 * w2
|
||||||
|
|
||||||
|
|
||||||
|
x = x + d_prime * dt
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Heun++
|
||||||
|
x_2 = x + d * dt
|
||||||
|
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
|
||||||
|
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||||
|
dt_2 = sigmas[i + 2] - sigmas[i + 1]
|
||||||
|
|
||||||
|
x_3 = x_2 + d_2 * dt_2
|
||||||
|
denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
|
||||||
|
d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
|
||||||
|
|
||||||
|
w = 3 * sigmas[0]
|
||||||
|
w2 = sigmas[i + 1] / w
|
||||||
|
w3 = sigmas[i + 2] / w
|
||||||
|
w1 = 1 - w2 - w3
|
||||||
|
|
||||||
|
d_prime = w1 * d + w2 * d_2 + w3 * d_3
|
||||||
|
x = x + d_prime * dt
|
||||||
|
return x
|
||||||
|
@ -518,7 +518,7 @@ class UNIPCBH2(Sampler):
|
|||||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||||
|
|
||||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user