diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 6991c983..20261aad 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -17,7 +17,9 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class ModelSamplingDiscreteLCM(torch.nn.Module): +class ModelSamplingDiscreteDistilled(torch.nn.Module): + original_timesteps = 50 + def __init__(self): super().__init__() self.sigma_data = 1.0 @@ -29,13 +31,12 @@ class ModelSamplingDiscreteLCM(torch.nn.Module): alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) - original_timesteps = 50 - self.skip_steps = timesteps // original_timesteps + self.skip_steps = timesteps // self.original_timesteps - alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32) - for x in range(original_timesteps): - alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] + alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + for x in range(self.original_timesteps): + alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 self.set_sigmas(sigmas) @@ -116,7 +117,7 @@ class ModelSamplingDiscrete: sampling_type = comfy.model_sampling.V_PREDICTION elif sampling == "lcm": sampling_type = LCM - sampling_base = ModelSamplingDiscreteLCM + sampling_base = ModelSamplingDiscreteDistilled class ModelSamplingAdvanced(sampling_base, sampling_type): pass