Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-01-11 02:15:17 +00:00

Refactor so that zsnr can be set in the sampling_settings.

Commit 2a18e98ccf, parent 8a5281006f.
comfy/model_sampling.py

@@ -2,6 +2,25 @@ import torch
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 import math

+def rescale_zero_terminal_snr_sigmas(sigmas):
+    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+    alphas_bar[-1] = 4.8973451890853435e-08
+    return ((1 - alphas_bar) / alphas_bar) ** 0.5
+
 class EPS:
     def calculate_input(self, sigma, noise):
         sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
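For orientation, the helper added above converts sigmas to alphas_cumprod, shifts alphas_bar_sqrt so the last timestep reaches zero, rescales it so the first timestep keeps its original value, clamps the terminal alphas_bar to a tiny positive constant to avoid dividing by zero, and converts back to sigmas. A minimal sketch of the effect, assuming a ComfyUI checkout with this change is on the import path; the linspace schedule below is made up purely for illustration:

    import torch
    from comfy.model_sampling import rescale_zero_terminal_snr_sigmas

    # Made-up monotone sigma schedule: index 0 is the least noisy timestep, index -1 the most noisy.
    toy_sigmas = torch.linspace(0.03, 14.6, 1000)
    rescaled = rescale_zero_terminal_snr_sigmas(toy_sigmas)

    print(float(rescaled[0]))   # ~0.03: the first sigma is preserved
    print(float(rescaled[-1]))  # ~4519: the terminal sigma blows up, so the terminal SNR goes to ~0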
@@ -48,7 +67,7 @@ class CONST:
         return latent / (1.0 - sigma)

 class ModelSamplingDiscrete(torch.nn.Module):
-    def __init__(self, model_config=None):
+    def __init__(self, model_config=None, zsnr=None):
         super().__init__()

         if model_config is not None:
@@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
         linear_end = sampling_settings.get("linear_end", 0.012)
         timesteps = sampling_settings.get("timesteps", 1000)

-        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
+        if zsnr is None:
+            zsnr = sampling_settings.get("zsnr", False)
+
+        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
         self.sigma_data = 1.0

     def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
-                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
         if given_betas is not None:
             betas = given_betas
         else:
@@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
         # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

         sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+        if zsnr:
+            sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
+
         self.set_sigmas(sigmas)

     def set_sigmas(self, sigmas):
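With this refactor a model config can request zero terminal SNR simply by putting "zsnr": True into its sampling_settings, while an explicit zsnr argument still overrides the dict. A minimal sketch, assuming a ComfyUI checkout with this change is importable; DummyConfig is a hypothetical stand-in for a real model config and only provides the sampling_settings attribute this path reads:

    from comfy.model_sampling import ModelSamplingDiscrete

    class DummyConfig:  # hypothetical stand-in for a real model config
        sampling_settings = {"beta_schedule": "linear", "zsnr": True}

    ms = ModelSamplingDiscrete(DummyConfig())            # zsnr picked up from sampling_settings
    print(float(ms.sigmas[-1]))                          # ~4519: sigmas were rescaled at construction

    ms_plain = ModelSamplingDiscrete(DummyConfig(), zsnr=False)  # explicit argument wins over the dict
    print(float(ms_plain.sigmas[-1]))                    # ~14.6: untouched SD-style schedule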
comfy_extras/nodes_model_advanced.py

@@ -51,25 +51,6 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete)
         return log_sigma.exp().to(timestep.device)


-def rescale_zero_terminal_snr_sigmas(sigmas):
-    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
-    alphas_bar_sqrt = alphas_cumprod.sqrt()
-
-    # Store old values.
-    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
-    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
-
-    # Shift so the last timestep is zero.
-    alphas_bar_sqrt -= (alphas_bar_sqrt_T)
-
-    # Scale so the first timestep is back to the old value.
-    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
-
-    # Convert alphas_bar_sqrt to betas
-    alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
-    alphas_bar[-1] = 4.8973451890853435e-08
-    return ((1 - alphas_bar) / alphas_bar) ** 0.5
-
 class ModelSamplingDiscrete:
     @classmethod
     def INPUT_TYPES(s):
@@ -100,9 +81,7 @@ class ModelSamplingDiscrete:
         class ModelSamplingAdvanced(sampling_base, sampling_type):
             pass

-        model_sampling = ModelSamplingAdvanced(model.model.model_config)
-        if zsnr:
-            model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
+        model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)

         m.add_object_patch("model_sampling", model_sampling)
         return (m, )
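With the constructor handling zsnr, the node no longer rescales sigmas after the fact; it only forwards the toggle. The old post-hoc path and the new constructor path should produce the same schedule. A small sketch of that equivalence, assuming a ComfyUI checkout with this change is importable; DummyConfig and the Advanced class are illustrative stand-ins for model.model.model_config and the ModelSamplingAdvanced class the node builds:

    import comfy.model_sampling

    class DummyConfig:  # hypothetical stand-in for model.model.model_config
        sampling_settings = {}

    class Advanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.model_sampling.EPS):
        pass

    # New behavior: zsnr is handled inside __init__ / _register_schedule.
    new_style = Advanced(DummyConfig(), zsnr=True)

    # Old node behavior: build normally, then rescale and re-set the sigmas.
    old_style = Advanced(DummyConfig())
    old_style.set_sigmas(comfy.model_sampling.rescale_zero_terminal_snr_sigmas(old_style.sigmas))

    # The two schedules should agree up to floating point rounding.
    rel_diff = ((new_style.sigmas - old_style.sigmas).abs() / old_style.sigmas).max()
    print(float(rel_diff))  # tiny relative difference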