Add a way to set output block patches to modify the h and hsp.

This commit is contained in:
comfyanonymous 2023-09-22 20:26:47 -04:00
parent 29ccf9f471
commit afa2399f79
2 changed files with 9 additions and 0 deletions

View File

@ -608,6 +608,7 @@ class UNetModel(nn.Module):
"""
transformer_options["original_shape"] = list(x.shape)
transformer_options["current_index"] = 0
transformer_patches = transformer_options.get("patches", {})
assert (y is not None) == (
self.num_classes is not None
@ -644,6 +645,11 @@ class UNetModel(nn.Module):
if ctrl is not None:
hsp += ctrl
if "output_block_patch" in transformer_patches:
patch = transformer_patches["output_block_patch"]
for p in patch:
h, hsp = p(h, hsp, transformer_options)
h = th.cat([h, hsp], dim=1)
del hsp
if len(hs) > 0:

View File

@ -88,6 +88,9 @@ class ModelPatcher:
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to: