From 1a0486bb96fb1ff10f4ea3c0d62eb815e9630585 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 4 Apr 2024 22:08:49 -0400 Subject: [PATCH] Fix model needing to be loaded on GPU to generate the sigmas. --- comfy/model_patcher.py | 6 ++++++ comfy/samplers.py | 28 ++++++++++++++-------------- comfy_extras/nodes_custom_sampler.py | 3 +-- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 8dda84cf..97fabd4f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -150,6 +150,12 @@ class ModelPatcher: def add_object_patch(self, name, obj): self.object_patches[name] = obj + def get_model_object(self, name): + if name in self.object_patches: + return self.object_patches[name] + else: + return comfy.utils.get_attr(self.model, name) + def model_patches_to(self, device): to = self.model_options["transformer_options"] if "patches" in to: diff --git a/comfy/samplers.py b/comfy/samplers.py index 475b1aad..08cbab75 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -274,8 +274,8 @@ class KSamplerX0Inpaint: out = out * denoise_mask + self.latent_image * latent_mask return out -def simple_scheduler(model, steps): - s = model.model_sampling +def simple_scheduler(model_sampling, steps): + s = model_sampling sigs = [] ss = len(s.sigmas) / steps for x in range(steps): @@ -283,8 +283,8 @@ def simple_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) -def ddim_scheduler(model, steps): - s = model.model_sampling +def ddim_scheduler(model_sampling, steps): + s = model_sampling sigs = [] ss = max(len(s.sigmas) // steps, 1) x = 1 @@ -295,8 +295,8 @@ def ddim_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) -def normal_scheduler(model, steps, sgm=False, floor=False): - s = model.model_sampling +def normal_scheduler(model_sampling, steps, sgm=False, floor=False): + s = model_sampling start = s.timestep(s.sigma_max) end = s.timestep(s.sigma_min) @@ -660,19 +660,19 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] -def calculate_sigmas_scheduler(model, scheduler_name, steps): +def calculate_sigmas(model_sampling, scheduler_name, steps): if scheduler_name == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max)) elif scheduler_name == "exponential": - sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max)) elif scheduler_name == "normal": - sigmas = normal_scheduler(model, steps) + sigmas = normal_scheduler(model_sampling, steps) elif scheduler_name == "simple": - sigmas = simple_scheduler(model, steps) + sigmas = simple_scheduler(model_sampling, steps) elif scheduler_name == "ddim_uniform": - sigmas = ddim_scheduler(model, steps) + sigmas = ddim_scheduler(model_sampling, steps) elif scheduler_name == "sgm_uniform": - sigmas = normal_scheduler(model, steps, sgm=True) + sigmas = normal_scheduler(model_sampling, steps, sgm=True) else: logging.error("error invalid scheduler {}".format(scheduler_name)) return sigmas @@ -714,7 +714,7 @@ class KSampler: steps += 1 discard_penultimate_sigma = True - sigmas = calculate_sigmas_scheduler(self.model.model, self.scheduler, steps) + sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps) if discard_penultimate_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index fbd8cd25..fa113192 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -28,8 +28,7 @@ class BasicScheduler: return (torch.FloatTensor([]),) total_steps = int(steps/denoise) - comfy.model_management.load_models_gpu([model]) - sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu() + sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() sigmas = sigmas[-(steps + 1):] return (sigmas, )