mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Make deep shrink behave like it should.
This commit is contained in:
parent
9f00a18095
commit
7e3fe3ad28
@ -633,6 +633,10 @@ class UNetModel(nn.Module):
|
|||||||
h = p(h, transformer_options)
|
h = p(h, transformer_options)
|
||||||
|
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
|
if "input_block_patch_after_skip" in transformer_patches:
|
||||||
|
patch = transformer_patches["input_block_patch_after_skip"]
|
||||||
|
for p in patch:
|
||||||
|
h = p(h, transformer_options)
|
||||||
|
|
||||||
transformer_options["block"] = ("middle", 0)
|
transformer_options["block"] = ("middle", 0)
|
||||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||||
|
@ -99,6 +99,9 @@ class ModelPatcher:
|
|||||||
def set_model_input_block_patch(self, patch):
|
def set_model_input_block_patch(self, patch):
|
||||||
self.set_model_patch(patch, "input_block_patch")
|
self.set_model_patch(patch, "input_block_patch")
|
||||||
|
|
||||||
|
def set_model_input_block_patch_after_skip(self, patch):
|
||||||
|
self.set_model_patch(patch, "input_block_patch_after_skip")
|
||||||
|
|
||||||
def set_model_output_block_patch(self, patch):
|
def set_model_output_block_patch(self, patch):
|
||||||
self.set_model_patch(patch, "output_block_patch")
|
self.set_model_patch(patch, "output_block_patch")
|
||||||
|
|
||||||
|
@ -8,13 +8,14 @@ class PatchModelAddDownscale:
|
|||||||
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
|
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"downscale_after_skip": ("BOOLEAN", {"default": True}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent):
|
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip):
|
||||||
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item()
|
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item()
|
||||||
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item()
|
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item()
|
||||||
|
|
||||||
@ -31,7 +32,10 @@ class PatchModelAddDownscale:
|
|||||||
return h, hsp
|
return h, hsp
|
||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_input_block_patch(input_block_patch)
|
if downscale_after_skip:
|
||||||
|
m.set_model_input_block_patch_after_skip(input_block_patch)
|
||||||
|
else:
|
||||||
|
m.set_model_input_block_patch(input_block_patch)
|
||||||
m.set_model_output_block_patch(output_block_patch)
|
m.set_model_output_block_patch(output_block_patch)
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user