Add sgm_uniform scheduler that acts like the default one in sgm.

This commit is contained in:
comfyanonymous 2023-08-14 00:13:35 -04:00
parent 58c7da3665
commit e244b2df83

View File

@ -347,6 +347,17 @@ def ddim_scheduler(model, steps):
sigs += [0.0] sigs += [0.0]
return torch.FloatTensor(sigs) return torch.FloatTensor(sigs)
def sgm_scheduler(model, steps):
sigs = []
timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
for x in range(len(timesteps)):
ts = timesteps[x]
if ts > 999:
ts = 999
sigs.append(model.t_to_sigma(torch.tensor(ts)))
sigs += [0.0]
return torch.FloatTensor(sigs)
def blank_inpaint_image_like(latent_image): def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image) blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space # these are the values for "zero" in pixel space translated to latent space
@ -525,7 +536,7 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
class KSampler: class KSampler:
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"] SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
@ -570,6 +581,8 @@ class KSampler:
sigmas = simple_scheduler(self.model_wrap, steps) sigmas = simple_scheduler(self.model_wrap, steps)
elif self.scheduler == "ddim_uniform": elif self.scheduler == "ddim_uniform":
sigmas = ddim_scheduler(self.model_wrap, steps) sigmas = ddim_scheduler(self.model_wrap, steps)
elif self.scheduler == "sgm_uniform":
sigmas = sgm_scheduler(self.model_wrap, steps)
else: else:
print("error invalid scheduler", self.scheduler) print("error invalid scheduler", self.scheduler)