mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-12 22:02:14 +00:00
Implement shift schedule for cascade stage C.
This commit is contained in:
parent
929e266f3e
commit
5b40e7a5ed
@ -136,9 +136,16 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
|
|||||||
class StableCascadeSampling(ModelSamplingDiscrete):
|
class StableCascadeSampling(ModelSamplingDiscrete):
|
||||||
def __init__(self, model_config=None):
|
def __init__(self, model_config=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
if model_config is not None:
|
||||||
|
sampling_settings = model_config.sampling_settings
|
||||||
|
else:
|
||||||
|
sampling_settings = {}
|
||||||
|
|
||||||
self.num_timesteps = 1000
|
self.num_timesteps = 1000
|
||||||
|
self.shift = sampling_settings.get("shift", 1.0)
|
||||||
cosine_s=8e-3
|
cosine_s=8e-3
|
||||||
self.cosine_s = torch.tensor([cosine_s])
|
self.cosine_s = torch.tensor(cosine_s)
|
||||||
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
|
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
|
||||||
self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
|
self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
|
||||||
for x in range(self.num_timesteps):
|
for x in range(self.num_timesteps):
|
||||||
@ -148,11 +155,23 @@ class StableCascadeSampling(ModelSamplingDiscrete):
|
|||||||
self.set_sigmas(sigmas)
|
self.set_sigmas(sigmas)
|
||||||
|
|
||||||
def sigma(self, timestep):
|
def sigma(self, timestep):
|
||||||
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod).clamp(0.0001, 0.9999)
|
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod)
|
||||||
|
|
||||||
|
if self.shift != 1.0:
|
||||||
|
var = alpha_cumprod
|
||||||
|
logSNR = (var/(1-var)).log()
|
||||||
|
logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift))
|
||||||
|
alpha_cumprod = logSNR.sigmoid()
|
||||||
|
|
||||||
|
alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
|
||||||
return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
|
return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
|
||||||
|
|
||||||
def timestep(self, sigma):
|
def timestep(self, sigma):
|
||||||
return super().timestep(sigma) / 1000.0
|
var = 1 / ((sigma * sigma) + 1)
|
||||||
|
var = var.clamp(0, 1.0)
|
||||||
|
s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
|
||||||
|
t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
|
||||||
|
return t
|
||||||
|
|
||||||
def percent_to_sigma(self, percent):
|
def percent_to_sigma(self, percent):
|
||||||
if percent <= 0.0:
|
if percent <= 0.0:
|
||||||
|
@ -316,6 +316,10 @@ class Stable_Cascade_C(supported_models_base.BASE):
|
|||||||
latent_format = latent_formats.SC_Prior
|
latent_format = latent_formats.SC_Prior
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 2.0,
|
||||||
|
}
|
||||||
|
|
||||||
def process_unet_state_dict(self, state_dict):
|
def process_unet_state_dict(self, state_dict):
|
||||||
key_list = list(state_dict.keys())
|
key_list = list(state_dict.keys())
|
||||||
for y in ["weight", "bias"]:
|
for y in ["weight", "bias"]:
|
||||||
@ -348,6 +352,10 @@ class Stable_Cascade_B(Stable_Cascade_C):
|
|||||||
latent_format = latent_formats.SC_B
|
latent_format = latent_formats.SC_B
|
||||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
out = model_base.StableCascade_B(self, device=device)
|
out = model_base.StableCascade_B(self, device=device)
|
||||||
return out
|
return out
|
||||||
|
Loading…
Reference in New Issue
Block a user