From 7e3fe3ad28fad4dede2893d77093a086344b81b6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Nov 2023 15:26:28 -0500 Subject: [PATCH] Make deep shrink behave like it should. --- comfy/ldm/modules/diffusionmodules/openaimodel.py | 4 ++++ comfy/model_patcher.py | 3 +++ comfy_extras/nodes_model_downscale.py | 8 ++++++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 504b79ed..10eb68d7 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -633,6 +633,10 @@ class UNetModel(nn.Module): h = p(h, transformer_options) hs.append(h) + if "input_block_patch_after_skip" in transformer_patches: + patch = transformer_patches["input_block_patch_after_skip"] + for p in patch: + h = p(h, transformer_options) transformer_options["block"] = ("middle", 0) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7f5ed45f..a3cffc3b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -99,6 +99,9 @@ class ModelPatcher: def set_model_input_block_patch(self, patch): self.set_model_patch(patch, "input_block_patch") + def set_model_input_block_patch_after_skip(self, patch): + self.set_model_patch(patch, "input_block_patch_after_skip") + def set_model_output_block_patch(self, patch): self.set_model_patch(patch, "output_block_patch") diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index f1b2d3ff..8850d094 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -8,13 +8,14 @@ class PatchModelAddDownscale: "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), + "downscale_after_skip": ("BOOLEAN", {"default": True}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "_for_testing" - def patch(self, model, block_number, downscale_factor, start_percent, end_percent): + def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip): sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item() sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item() @@ -31,7 +32,10 @@ class PatchModelAddDownscale: return h, hsp m = model.clone() - m.set_model_input_block_patch(input_block_patch) + if downscale_after_skip: + m.set_model_input_block_patch_after_skip(input_block_patch) + else: + m.set_model_input_block_patch(input_block_patch) m.set_model_output_block_patch(output_block_patch) return (m, )