From d234ca558a7777b607a4f81aeb9e8703ef020977 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 28 Sep 2023 00:17:03 -0400 Subject: [PATCH] Add missing samplers to KSamplerSelect. --- comfy/samplers.py | 20 ++++++++++++-------- comfy_extras/nodes_custom_sampler.py | 4 ++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 7668d791..a7c240f4 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -711,6 +711,17 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps): print("error invalid scheduler", self.scheduler) return sigmas +def sampler_class(name): + if name == "uni_pc": + sampler = UNIPC + elif name == "uni_pc_bh2": + sampler = UNIPCBH2 + elif name == "ddim": + sampler = DDIM + else: + sampler = ksampler(name) + return sampler + class KSampler: SCHEDULERS = SCHEDULER_NAMES SAMPLERS = SAMPLER_NAMES @@ -769,13 +780,6 @@ class KSampler: else: return torch.zeros_like(noise) - if self.sampler == "uni_pc": - sampler = UNIPC - elif self.sampler == "uni_pc_bh2": - sampler = UNIPCBH2 - elif self.sampler == "ddim": - sampler = DDIM - else: - sampler = ksampler(self.sampler) + sampler = sampler_class(self.sampler) return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 842a9de4..1c587dbd 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -28,7 +28,7 @@ class KSamplerSelect: @classmethod def INPUT_TYPES(s): return {"required": - {"sampler_name": (comfy.samplers.KSAMPLER_NAMES, ), + {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ), } } RETURN_TYPES = ("SAMPLER",) @@ -37,7 +37,7 @@ class KSamplerSelect: FUNCTION = "get_sampler" def get_sampler(self, sampler_name): - sampler = comfy.samplers.ksampler(sampler_name)() + sampler = comfy.samplers.sampler_class(sampler_name)() return (sampler, ) class SamplerCustom: