Implement shift schedule for cascade stage C.

This commit is contained in:
comfyanonymous 2024-02-17 11:38:47 -05:00
parent 929e266f3e
commit 5b40e7a5ed
2 changed files with 30 additions and 3 deletions

View File

@ -136,9 +136,16 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
class StableCascadeSampling(ModelSamplingDiscrete):
def __init__(self, model_config=None):
super().__init__()
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
self.num_timesteps = 1000
self.shift = sampling_settings.get("shift", 1.0)
cosine_s=8e-3
self.cosine_s = torch.tensor([cosine_s])
self.cosine_s = torch.tensor(cosine_s)
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
for x in range(self.num_timesteps):
@ -148,11 +155,23 @@ class StableCascadeSampling(ModelSamplingDiscrete):
self.set_sigmas(sigmas)
def sigma(self, timestep):
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod).clamp(0.0001, 0.9999)
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod)
if self.shift != 1.0:
var = alpha_cumprod
logSNR = (var/(1-var)).log()
logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift))
alpha_cumprod = logSNR.sigmoid()
alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
def timestep(self, sigma):
return super().timestep(sigma) / 1000.0
var = 1 / ((sigma * sigma) + 1)
var = var.clamp(0, 1.0)
s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
return t
def percent_to_sigma(self, percent):
if percent <= 0.0:

View File

@ -316,6 +316,10 @@ class Stable_Cascade_C(supported_models_base.BASE):
latent_format = latent_formats.SC_Prior
supported_inference_dtypes = [torch.bfloat16, torch.float32]
sampling_settings = {
"shift": 2.0,
}
def process_unet_state_dict(self, state_dict):
key_list = list(state_dict.keys())
for y in ["weight", "bias"]:
@ -348,6 +352,10 @@ class Stable_Cascade_B(Stable_Cascade_C):
latent_format = latent_formats.SC_B
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
sampling_settings = {
"shift": 1.0,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_B(self, device=device)
return out