From 420beeeb05ef59e887f8731f615f8a9ec6eb0a4c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 14 Nov 2023 00:39:34 -0500 Subject: [PATCH] Clean up and refactor sampler code. This should make it much easier to write custom nodes with kdiffusion type samplers. --- comfy/samplers.py | 89 +++++++++++++++++----------- comfy_extras/nodes_custom_sampler.py | 6 +- 2 files changed, 56 insertions(+), 39 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 5340dd01..65c44791 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -522,42 +522,59 @@ KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral" "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] +class KSAMPLER(Sampler): + def __init__(self, sampler_function, extra_options={}, inpaint_options={}): + self.sampler_function = sampler_function + self.extra_options = extra_options + self.inpaint_options = inpaint_options + + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + extra_args["denoise_mask"] = denoise_mask + model_k = KSamplerX0Inpaint(model_wrap) + model_k.latent_image = latent_image + if self.inpaint_options.get("random", False): #TODO: Should this be the default? + generator = torch.manual_seed(extra_args.get("seed", 41) + 1) + model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device) + else: + model_k.noise = noise + + if self.max_denoise(model_wrap, sigmas): + noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) + else: + noise = noise * sigmas[0] + + k_callback = None + total_steps = len(sigmas) - 1 + if callback is not None: + k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) + + if latent_image is not None: + noise += latent_image + + samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) + return samples + + def ksampler(sampler_name, extra_options={}, inpaint_options={}): - class KSAMPLER(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - extra_args["denoise_mask"] = denoise_mask - model_k = KSamplerX0Inpaint(model_wrap) - model_k.latent_image = latent_image - if inpaint_options.get("random", False): #TODO: Should this be the default? - generator = torch.manual_seed(extra_args.get("seed", 41) + 1) - model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device) - else: - model_k.noise = noise - - if self.max_denoise(model_wrap, sigmas): - noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) - else: - noise = noise * sigmas[0] - - k_callback = None - total_steps = len(sigmas) - 1 - if callback is not None: - k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) - + if sampler_name == "dpm_fast": + def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable): sigma_min = sigmas[-1] if sigma_min == 0: sigma_min = sigmas[-2] + total_steps = len(sigmas) - 1 + return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable) + sampler_function = dpm_fast_function + elif sampler_name == "dpm_adaptive": + def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable): + sigma_min = sigmas[-1] + if sigma_min == 0: + sigma_min = sigmas[-2] + return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable) + sampler_function = dpm_adaptive_function + else: + sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) - if latent_image is not None: - noise += latent_image - if sampler_name == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) - elif sampler_name == "dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar) - else: - samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options) - return samples - return KSAMPLER + return KSAMPLER(sampler_function, extra_options, inpaint_options) def wrap_model(model): model_denoise = CFGNoisePredictor(model) @@ -618,11 +635,11 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps): print("error invalid scheduler", self.scheduler) return sigmas -def sampler_class(name): +def sampler_object(name): if name == "uni_pc": - sampler = UNIPC + sampler = UNIPC() elif name == "uni_pc_bh2": - sampler = UNIPCBH2 + sampler = UNIPCBH2() elif name == "ddim": sampler = ksampler("euler", inpaint_options={"random": True}) else: @@ -687,6 +704,6 @@ class KSampler: else: return torch.zeros_like(noise) - sampler = sampler_class(self.sampler) + sampler = sampler_object(self.sampler) - return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) + return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index f0576946..d3c1d4a2 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -149,7 +149,7 @@ class KSamplerSelect: FUNCTION = "get_sampler" def get_sampler(self, sampler_name): - sampler = comfy.samplers.sampler_class(sampler_name)() + sampler = comfy.samplers.sampler_object(sampler_name) return (sampler, ) class SamplerDPMPP_2M_SDE: @@ -172,7 +172,7 @@ class SamplerDPMPP_2M_SDE: sampler_name = "dpmpp_2m_sde" else: sampler_name = "dpmpp_2m_sde_gpu" - sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})() + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) return (sampler, ) @@ -196,7 +196,7 @@ class SamplerDPMPP_SDE: sampler_name = "dpmpp_sde" else: sampler_name = "dpmpp_sde_gpu" - sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})() + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) return (sampler, ) class SamplerCustom: