diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index b1fbf3e2..f42f3015 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -136,9 +136,16 @@ class ModelSamplingContinuousEDM(torch.nn.Module): class StableCascadeSampling(ModelSamplingDiscrete): def __init__(self, model_config=None): super().__init__() + + if model_config is not None: + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + self.num_timesteps = 1000 + self.shift = sampling_settings.get("shift", 1.0) cosine_s=8e-3 - self.cosine_s = torch.tensor([cosine_s]) + self.cosine_s = torch.tensor(cosine_s) sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 for x in range(self.num_timesteps): @@ -148,11 +155,23 @@ class StableCascadeSampling(ModelSamplingDiscrete): self.set_sigmas(sigmas) def sigma(self, timestep): - alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod).clamp(0.0001, 0.9999) + alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod) + + if self.shift != 1.0: + var = alpha_cumprod + logSNR = (var/(1-var)).log() + logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift)) + alpha_cumprod = logSNR.sigmoid() + + alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999) return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5 def timestep(self, sigma): - return super().timestep(sigma) / 1000.0 + var = 1 / ((sigma * sigma) + 1) + var = var.clamp(0, 1.0) + s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device) + t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return t def percent_to_sigma(self, percent): if percent <= 0.0: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 3a317edc..1a673646 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -316,6 +316,10 @@ class Stable_Cascade_C(supported_models_base.BASE): latent_format = latent_formats.SC_Prior supported_inference_dtypes = [torch.bfloat16, torch.float32] + sampling_settings = { + "shift": 2.0, + } + def process_unet_state_dict(self, state_dict): key_list = list(state_dict.keys()) for y in ["weight", "bias"]: @@ -348,6 +352,10 @@ class Stable_Cascade_B(Stable_Cascade_C): latent_format = latent_formats.SC_B supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + sampling_settings = { + "shift": 1.0, + } + def get_model(self, state_dict, prefix="", device=None): out = model_base.StableCascade_B(self, device=device) return out