Refactor so that zsnr can be set in the sampling_settings.

This commit is contained in:
comfyanonymous 2024-11-11 04:55:56 -05:00
parent 8a5281006f
commit 2a18e98ccf
2 changed files with 29 additions and 25 deletions

View File

@ -2,6 +2,25 @@ import torch
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math import math
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
class EPS: class EPS:
def calculate_input(self, sigma, noise): def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
@ -48,7 +67,7 @@ class CONST:
return latent / (1.0 - sigma) return latent / (1.0 - sigma)
class ModelSamplingDiscrete(torch.nn.Module): class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None): def __init__(self, model_config=None, zsnr=None):
super().__init__() super().__init__()
if model_config is not None: if model_config is not None:
@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
linear_end = sampling_settings.get("linear_end", 0.012) linear_end = sampling_settings.get("linear_end", 0.012)
timesteps = sampling_settings.get("timesteps", 1000) timesteps = sampling_settings.get("timesteps", 1000)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3) if zsnr is None:
zsnr = sampling_settings.get("zsnr", False)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
self.sigma_data = 1.0 self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
if given_betas is not None: if given_betas is not None:
betas = given_betas betas = given_betas
else: else:
@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
if zsnr:
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
self.set_sigmas(sigmas) self.set_sigmas(sigmas)
def set_sigmas(self, sigmas): def set_sigmas(self, sigmas):

View File

@ -51,25 +51,6 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete)
return log_sigma.exp().to(timestep.device) return log_sigma.exp().to(timestep.device)
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
class ModelSamplingDiscrete: class ModelSamplingDiscrete:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -100,9 +81,7 @@ class ModelSamplingDiscrete:
class ModelSamplingAdvanced(sampling_base, sampling_type): class ModelSamplingAdvanced(sampling_base, sampling_type):
pass pass
model_sampling = ModelSamplingAdvanced(model.model.model_config) model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
return (m, ) return (m, )