Scheduler code refactor.

This commit is contained in:
comfyanonymous 2023-09-26 16:25:34 -04:00
parent 446caf711c
commit 1d6dd83184

View File

@ -549,7 +549,7 @@ class Sampler:
pass
def max_denoise(self, model_wrap, sigmas):
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]))
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
class DDIM(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
@ -631,6 +631,13 @@ def ksampler(sampler_name):
return samples
return KSAMPLER
def wrap_model(model):
model_denoise = CFGNoisePredictor(model)
if model.model_type == model_base.ModelType.V_PREDICTION:
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
else:
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
return model_wrap
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
positive = positive[:]
@ -639,11 +646,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
model_denoise = CFGNoisePredictor(model)
if model.model_type == model_base.ModelType.V_PREDICTION:
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
else:
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
model_wrap = wrap_model(model)
calculate_start_end_timesteps(model_wrap, negative)
calculate_start_end_timesteps(model_wrap, positive)
@ -687,19 +690,33 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas_scheduler(model, scheduler_name, steps):
model_wrap = wrap_model(model)
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
elif scheduler_name == "normal":
sigmas = model_wrap.get_sigmas(steps)
elif scheduler_name == "simple":
sigmas = simple_scheduler(model_wrap, steps)
elif scheduler_name == "ddim_uniform":
sigmas = ddim_scheduler(model_wrap, steps)
elif scheduler_name == "sgm_uniform":
sigmas = sgm_scheduler(model_wrap, steps)
else:
print("error invalid scheduler", self.scheduler)
return sigmas
class KSampler:
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLERS = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
SCHEDULERS = SCHEDULER_NAMES
SAMPLERS = SAMPLER_NAMES
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
self.model_denoise = CFGNoisePredictor(self.model)
if self.model.model_type == model_base.ModelType.V_PREDICTION:
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
else:
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
self.model_k = KSamplerX0Inpaint(self.model_wrap)
self.device = device
if scheduler not in self.SCHEDULERS:
scheduler = self.SCHEDULERS[0]
@ -707,8 +724,6 @@ class KSampler:
sampler = self.SAMPLERS[0]
self.scheduler = scheduler
self.sampler = sampler
self.sigma_min=float(self.model_wrap.sigma_min)
self.sigma_max=float(self.model_wrap.sigma_max)
self.set_steps(steps, denoise)
self.denoise = denoise
self.model_options = model_options
@ -721,20 +736,7 @@ class KSampler:
steps += 1
discard_penultimate_sigma = True
if self.scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps)
elif self.scheduler == "simple":
sigmas = simple_scheduler(self.model_wrap, steps)
elif self.scheduler == "ddim_uniform":
sigmas = ddim_scheduler(self.model_wrap, steps)
elif self.scheduler == "sgm_uniform":
sigmas = sgm_scheduler(self.model_wrap, steps)
else:
print("error invalid scheduler", self.scheduler)
sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps)
if discard_penultimate_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
@ -752,10 +754,8 @@ class KSampler:
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
if sigmas is None:
sigmas = self.sigmas
sigma_min = self.sigma_min
if last_step is not None and last_step < (len(sigmas) - 1):
sigma_min = sigmas[last_step]
sigmas = sigmas[:last_step + 1]
if force_full_denoise:
sigmas[-1] = 0