From ef4f6037cbbbd4150c44862eb398428b70f19263 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 3 Jan 2024 12:16:30 -0500 Subject: [PATCH] Fix model patches not working in custom sampling scheduler nodes. --- comfy/model_patcher.py | 47 ++++++++++++++-------------- comfy_extras/nodes_custom_sampler.py | 8 +++-- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b1b5ea6a..a88b737c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -174,40 +174,41 @@ class ModelPatcher: sd.pop(k) return sd - def patch_model(self, device_to=None): + def patch_model(self, device_to=None, patch_weights=True): for k in self.object_patches: old = getattr(self.model, k) if k not in self.object_patches_backup: self.object_patches_backup[k] = old setattr(self.model, k, self.object_patches[k]) - model_sd = self.model_state_dict() - for key in self.patches: - if key not in model_sd: - print("could not patch. key doesn't exist in model:", key) - continue + if patch_weights: + model_sd = self.model_state_dict() + for key in self.patches: + if key not in model_sd: + print("could not patch. key doesn't exist in model:", key) + continue - weight = model_sd[key] + weight = model_sd[key] - inplace_update = self.weight_inplace_update + inplace_update = self.weight_inplace_update - if key not in self.backup: - self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + if key not in self.backup: + self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + + if device_to is not None: + temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + if inplace_update: + comfy.utils.copy_to_param(self.model, key, out_weight) + else: + comfy.utils.set_attr(self.model, key, out_weight) + del temp_weight if device_to is not None: - temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) - else: - temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - if inplace_update: - comfy.utils.copy_to_param(self.model, key, out_weight) - else: - comfy.utils.set_attr(self.model, key, out_weight) - del temp_weight - - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to + self.model.to(device_to) + self.current_device = device_to return self.model diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index d5f9ba00..bb0ed57b 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -26,7 +26,9 @@ class BasicScheduler: if denoise < 1.0: total_steps = int(steps/denoise) - sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu() + inner_model = model.patch_model(patch_weights=False) + sigmas = comfy.samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu() + model.unpatch_model() sigmas = sigmas[-(steps + 1):] return (sigmas, ) @@ -104,7 +106,9 @@ class SDTurboScheduler: def get_sigmas(self, model, steps, denoise): start_step = 10 - int(10 * denoise) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] - sigmas = model.model.model_sampling.sigma(timesteps) + inner_model = model.patch_model(patch_weights=False) + sigmas = inner_model.model_sampling.sigma(timesteps) + model.unpatch_model() sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) return (sigmas, )