mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Refactor LCM to support more model types.
This commit is contained in:
parent
9cad2f06ff
commit
014c8bf2f2
@ -17,41 +17,19 @@ class LCM(comfy.model_sampling.EPS):
|
|||||||
|
|
||||||
return c_out * x0 + c_skip * model_input
|
return c_out * x0 + c_skip * model_input
|
||||||
|
|
||||||
class ModelSamplingDiscreteDistilled(torch.nn.Module):
|
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
|
||||||
original_timesteps = 50
|
original_timesteps = 50
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, model_config=None):
|
||||||
super().__init__()
|
super().__init__(model_config)
|
||||||
self.sigma_data = 1.0
|
|
||||||
timesteps = 1000
|
|
||||||
beta_start = 0.00085
|
|
||||||
beta_end = 0.012
|
|
||||||
|
|
||||||
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
|
self.skip_steps = self.num_timesteps // self.original_timesteps
|
||||||
alphas = 1.0 - betas
|
|
||||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
|
||||||
|
|
||||||
self.skip_steps = timesteps // self.original_timesteps
|
sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
|
|
||||||
for x in range(self.original_timesteps):
|
for x in range(self.original_timesteps):
|
||||||
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
|
||||||
|
|
||||||
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
|
self.set_sigmas(sigmas_valid)
|
||||||
self.set_sigmas(sigmas)
|
|
||||||
|
|
||||||
def set_sigmas(self, sigmas):
|
|
||||||
self.register_buffer('sigmas', sigmas)
|
|
||||||
self.register_buffer('log_sigmas', sigmas.log())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sigma_min(self):
|
|
||||||
return self.sigmas[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sigma_max(self):
|
|
||||||
return self.sigmas[-1]
|
|
||||||
|
|
||||||
def timestep(self, sigma):
|
def timestep(self, sigma):
|
||||||
log_sigma = sigma.log()
|
log_sigma = sigma.log()
|
||||||
@ -66,14 +44,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
|
|||||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||||
return log_sigma.exp().to(timestep.device)
|
return log_sigma.exp().to(timestep.device)
|
||||||
|
|
||||||
def percent_to_sigma(self, percent):
|
|
||||||
if percent <= 0.0:
|
|
||||||
return 999999999.9
|
|
||||||
if percent >= 1.0:
|
|
||||||
return 0.0
|
|
||||||
percent = 1.0 - percent
|
|
||||||
return self.sigma(torch.tensor(percent * 999.0)).item()
|
|
||||||
|
|
||||||
|
|
||||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||||
@ -154,7 +124,7 @@ class ModelSamplingContinuousEDM:
|
|||||||
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
|
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
model_sampling = ModelSamplingAdvanced()
|
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||||
model_sampling.set_sigma_range(sigma_min, sigma_max)
|
model_sampling.set_sigma_range(sigma_min, sigma_max)
|
||||||
m.add_object_patch("model_sampling", model_sampling)
|
m.add_object_patch("model_sampling", model_sampling)
|
||||||
return (m, )
|
return (m, )
|
||||||
|
Loading…
Reference in New Issue
Block a user