Add missing samplers to KSamplerSelect.

This commit is contained in:
comfyanonymous 2023-09-28 00:17:03 -04:00
parent 1d7dfc07d5
commit d234ca558a
2 changed files with 14 additions and 10 deletions

View File

@ -711,6 +711,17 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
print("error invalid scheduler", self.scheduler) print("error invalid scheduler", self.scheduler)
return sigmas return sigmas
def sampler_class(name):
if name == "uni_pc":
sampler = UNIPC
elif name == "uni_pc_bh2":
sampler = UNIPCBH2
elif name == "ddim":
sampler = DDIM
else:
sampler = ksampler(name)
return sampler
class KSampler: class KSampler:
SCHEDULERS = SCHEDULER_NAMES SCHEDULERS = SCHEDULER_NAMES
SAMPLERS = SAMPLER_NAMES SAMPLERS = SAMPLER_NAMES
@ -769,13 +780,6 @@ class KSampler:
else: else:
return torch.zeros_like(noise) return torch.zeros_like(noise)
if self.sampler == "uni_pc": sampler = sampler_class(self.sampler)
sampler = UNIPC
elif self.sampler == "uni_pc_bh2":
sampler = UNIPCBH2
elif self.sampler == "ddim":
sampler = DDIM
else:
sampler = ksampler(self.sampler)
return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)

View File

@ -28,7 +28,7 @@ class KSamplerSelect:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": return {"required":
{"sampler_name": (comfy.samplers.KSAMPLER_NAMES, ), {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
} }
} }
RETURN_TYPES = ("SAMPLER",) RETURN_TYPES = ("SAMPLER",)
@ -37,7 +37,7 @@ class KSamplerSelect:
FUNCTION = "get_sampler" FUNCTION = "get_sampler"
def get_sampler(self, sampler_name): def get_sampler(self, sampler_name):
sampler = comfy.samplers.ksampler(sampler_name)() sampler = comfy.samplers.sampler_class(sampler_name)()
return (sampler, ) return (sampler, )
class SamplerCustom: class SamplerCustom: