mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-10 18:05:16 +00:00
Fix unknown sampler error handling in calculate_sigmas function (#6280)
Modernize calculate_sigmas function
This commit is contained in:
parent
1c99734e5a
commit
c0338a46a4
@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
import torch
|
||||
from functools import partial
|
||||
import collections
|
||||
from comfy import model_management
|
||||
import math
|
||||
@ -920,31 +921,37 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic", "kl_optimal"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
if scheduler_name == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
elif scheduler_name == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
elif scheduler_name == "normal":
|
||||
sigmas = normal_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "simple":
|
||||
sigmas = simple_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||
elif scheduler_name == "beta":
|
||||
sigmas = beta_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "linear_quadratic":
|
||||
sigmas = linear_quadratic_schedule(model_sampling, steps)
|
||||
elif scheduler_name == "kl_optimal":
|
||||
sigmas = kl_optimal_scheduler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
else:
|
||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||
return sigmas
|
||||
class SchedulerHandler(NamedTuple):
|
||||
handler: Callable[..., torch.Tensor]
|
||||
# Boolean indicates whether to call the handler like:
|
||||
# scheduler_function(model_sampling, steps) or
|
||||
# scheduler_function(n, sigma_min: float, sigma_max: float)
|
||||
use_ms: bool = True
|
||||
|
||||
SCHEDULER_HANDLERS = {
|
||||
"normal": SchedulerHandler(normal_scheduler),
|
||||
"karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
|
||||
"exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
|
||||
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
|
||||
"simple": SchedulerHandler(simple_scheduler),
|
||||
"ddim_uniform": SchedulerHandler(ddim_scheduler),
|
||||
"beta": SchedulerHandler(beta_scheduler),
|
||||
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
|
||||
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
|
||||
}
|
||||
SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
|
||||
|
||||
def calculate_sigmas(model_sampling: object, scheduler_name: str, steps: int) -> torch.Tensor:
|
||||
handler = SCHEDULER_HANDLERS.get(scheduler_name)
|
||||
if handler is None:
|
||||
err = f"error invalid scheduler {scheduler_name}"
|
||||
logging.error(err)
|
||||
raise ValueError(err)
|
||||
if handler.use_ms:
|
||||
return handler.handler(model_sampling, steps)
|
||||
return handler.handler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
|
||||
def sampler_object(name):
|
||||
if name == "uni_pc":
|
||||
|
Loading…
Reference in New Issue
Block a user