From 844dbf97a71b398301e1a6318c6776bc5b1f5b7e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 7 Nov 2023 03:28:53 -0500 Subject: [PATCH] Add: advanced->model->ModelSamplingDiscrete node. This allows changing the sampling parameters of the model (eps or vpred) or set the model to use zsnr. --- comfy/model_patcher.py | 17 +++++++++ comfy/model_sampling.py | 2 + comfy_extras/nodes_model_advanced.py | 57 ++++++++++++++++++++++++++++ nodes.py | 1 + 4 files changed, 77 insertions(+) create mode 100644 comfy_extras/nodes_model_advanced.py diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0efdf46e..0f538559 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -11,6 +11,8 @@ class ModelPatcher: self.model = model self.patches = {} self.backup = {} + self.object_patches = {} + self.object_patches_backup = {} self.model_options = {"transformer_options":{}} self.model_size() self.load_device = load_device @@ -91,6 +93,9 @@ class ModelPatcher: def set_model_output_block_patch(self, patch): self.set_model_patch(patch, "output_block_patch") + def add_object_patch(self, name, obj): + self.object_patches[name] = obj + def model_patches_to(self, device): to = self.model_options["transformer_options"] if "patches" in to: @@ -150,6 +155,12 @@ class ModelPatcher: return sd def patch_model(self, device_to=None): + for k in self.object_patches: + old = getattr(self.model, k) + if k not in self.object_patches_backup: + self.object_patches_backup[k] = old + setattr(self.model, k, self.object_patches[k]) + model_sd = self.model_state_dict() for key in self.patches: if key not in model_sd: @@ -290,3 +301,9 @@ class ModelPatcher: if device_to is not None: self.model.to(device_to) self.current_device = device_to + + keys = list(self.object_patches_backup.keys()) + for k in keys: + setattr(self.model, k, self.object_patches_backup[k]) + + self.object_patches_backup = {} diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 5e229323..a2935d47 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -48,7 +48,9 @@ class ModelSamplingDiscrete(torch.nn.Module): # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + self.set_sigmas(sigmas) + def set_sigmas(self, sigmas): self.register_buffer('sigmas', sigmas) self.register_buffer('log_sigmas', sigmas.log()) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py new file mode 100644 index 00000000..c02cfb05 --- /dev/null +++ b/comfy_extras/nodes_model_advanced.py @@ -0,0 +1,57 @@ +import folder_paths +import comfy.sd +import comfy.model_sampling + + +def rescale_zero_terminal_snr_sigmas(sigmas): + alphas_cumprod = 1 / ((sigmas * sigmas) + 1) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return ((1 - alphas_bar) / alphas_bar) ** 0.5 + +class ModelSamplingDiscrete: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "sampling": (["eps", "v_prediction"],), + "zsnr": ("BOOLEAN", {"default": False}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, sampling, zsnr): + m = model.clone() + + if sampling == "eps": + sampling_type = comfy.model_sampling.EPS + elif sampling == "v_prediction": + sampling_type = comfy.model_sampling.V_PREDICTION + + class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced() + if zsnr: + model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "ModelSamplingDiscrete": ModelSamplingDiscrete, +} diff --git a/nodes.py b/nodes.py index 61ebbb8b..5ed01544 100644 --- a/nodes.py +++ b/nodes.py @@ -1798,6 +1798,7 @@ def init_custom_nodes(): "nodes_freelunch.py", "nodes_custom_sampler.py", "nodes_hypertile.py", + "nodes_model_advanced.py", ] for node_file in extras_files: