ModelSamplingDiscreteLCM -> ModelSamplingDiscreteDistilled

This commit is contained in:
comfyanonymous 2023-11-27 16:32:03 -05:00
parent 13fdee6abf
commit 488de0b4df

View File

@ -17,7 +17,9 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteLCM(torch.nn.Module):
class ModelSamplingDiscreteDistilled(torch.nn.Module):
original_timesteps = 50
def __init__(self):
super().__init__()
self.sigma_data = 1.0
@ -29,13 +31,12 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
original_timesteps = 50
self.skip_steps = timesteps // original_timesteps
self.skip_steps = timesteps // self.original_timesteps
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
for x in range(original_timesteps):
alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(self.original_timesteps):
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)
@ -116,7 +117,7 @@ class ModelSamplingDiscrete:
sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "lcm":
sampling_type = LCM
sampling_base = ModelSamplingDiscreteLCM
sampling_base = ModelSamplingDiscreteDistilled
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass