Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-01-25 15:55:18 +00:00
Add some more transformer hooks and move tomesd to comfy_extras.
Tomesd now uses q instead of x to decide which tokens to merge because it seems to give better results.
This commit is contained in:
parent fa28d7334b
commit 05676942b7
comfy/ldm/modules/attention.py

@@ -12,8 +12,6 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management
 import comfy.ops

-from . import tomesd
-
 if model_management.xformers_enabled():
     import xformers
     import xformers.ops
@@ -519,23 +517,39 @@ class BasicTransformerBlock(nn.Module):
         self.norm2 = nn.LayerNorm(dim, dtype=dtype)
         self.norm3 = nn.LayerNorm(dim, dtype=dtype)
         self.checkpoint = checkpoint
+        self.n_heads = n_heads
+        self.d_head = d_head

     def forward(self, x, context=None, transformer_options={}):
         return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

     def _forward(self, x, context=None, transformer_options={}):
         extra_options = {}
+        block = None
+        block_index = 0
         if "current_index" in transformer_options:
             extra_options["transformer_index"] = transformer_options["current_index"]
         if "block_index" in transformer_options:
-            extra_options["block_index"] = transformer_options["block_index"]
+            block_index = transformer_options["block_index"]
+        extra_options["block_index"] = block_index
         if "original_shape" in transformer_options:
             extra_options["original_shape"] = transformer_options["original_shape"]
+        if "block" in transformer_options:
+            block = transformer_options["block"]
+        extra_options["block"] = block
         if "patches" in transformer_options:
             transformer_patches = transformer_options["patches"]
         else:
             transformer_patches = {}

+        extra_options["n_heads"] = self.n_heads
+        extra_options["dim_head"] = self.d_head
+
+        if "patches_replace" in transformer_options:
+            transformer_patches_replace = transformer_options["patches_replace"]
+        else:
+            transformer_patches_replace = {}
+
         n = self.norm1(x)
         if self.disable_self_attn:
             context_attn1 = context
@@ -551,12 +565,29 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

-        if "tomesd" in transformer_options:
-            m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
-            n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
+        transformer_block = (block[0], block[1], block_index)
+        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
+        block_attn1 = transformer_block
+        if block_attn1 not in attn1_replace_patch:
+            block_attn1 = block
+
+        if block_attn1 in attn1_replace_patch:
+            if context_attn1 is None:
+                context_attn1 = n
+                value_attn1 = n
+            n = self.attn1.to_q(n)
+            context_attn1 = self.attn1.to_k(context_attn1)
+            value_attn1 = self.attn1.to_v(value_attn1)
+            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
+            n = self.attn1.to_out(n)
         else:
             n = self.attn1(n, context=context_attn1, value=value_attn1)

+        if "attn1_output_patch" in transformer_patches:
+            patch = transformer_patches["attn1_output_patch"]
+            for p in patch:
+                n = p(n, extra_options)
+
         x += n
         if "middle_patch" in transformer_patches:
             patch = transformer_patches["middle_patch"]
@@ -573,7 +604,21 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

-        n = self.attn2(n, context=context_attn2, value=value_attn2)
+        attn2_replace_patch = transformer_patches_replace.get("attn2", {})
+        block_attn2 = transformer_block
+        if block_attn2 not in attn2_replace_patch:
+            block_attn2 = block
+
+        if block_attn2 in attn2_replace_patch:
+            if value_attn2 is None:
+                value_attn2 = context_attn2
+            n = self.attn2.to_q(n)
+            context_attn2 = self.attn2.to_k(context_attn2)
+            value_attn2 = self.attn2.to_v(value_attn2)
+            n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
+            n = self.attn2.to_out(n)
+        else:
+            n = self.attn2(n, context=context_attn2, value=value_attn2)

         if "attn2_output_patch" in transformer_patches:
             patch = transformer_patches["attn2_output_patch"]
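Note (illustrative, not part of the commit): an attn1/attn2 replace patch registered under "patches_replace" receives q, k and v after they have already gone through to_q/to_k/to_v, and whatever it returns is fed into to_out. A minimal sketch of such a function, assuming PyTorch 2.x; the name my_attn1_replace is hypothetical:

import torch

def my_attn1_replace(q, k, v, extra_options):
    # q, k, v arrive already projected; the return value goes through attn1.to_out.
    heads = extra_options["n_heads"]
    dim_head = extra_options["dim_head"]
    b, seq_q, _ = q.shape
    def split(t):
        # (batch, tokens, heads * dim_head) -> (batch, heads, tokens, dim_head)
        return t.reshape(t.shape[0], t.shape[1], heads, dim_head).transpose(1, 2)
    q, k, v = split(q), split(k), split(v)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, seq_q, heads * dim_head)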
comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -830,17 +830,20 @@ class UNetModel(nn.Module):

         h = x.type(self.dtype)
         for id, module in enumerate(self.input_blocks):
+            transformer_options["block"] = ("input", id)
             h = forward_timestep_embed(module, h, emb, context, transformer_options)
             if control is not None and 'input' in control and len(control['input']) > 0:
                 ctrl = control['input'].pop()
                 if ctrl is not None:
                     h += ctrl
             hs.append(h)
+        transformer_options["block"] = ("middle", 0)
         h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
         if control is not None and 'middle' in control and len(control['middle']) > 0:
             h += control['middle'].pop()

-        for module in self.output_blocks:
+        for id, module in enumerate(self.output_blocks):
+            transformer_options["block"] = ("output", id)
             hsp = hs.pop()
             if control is not None and 'output' in control and len(control['output']) > 0:
                 ctrl = control['output'].pop()
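Note (illustrative, not part of the commit): with the ("input", id), ("middle", 0) and ("output", id) tags now written into transformer_options, any patch can inspect extra_options["block"] to restrict itself to part of the UNet. A hedged sketch of an attn1 output patch that only touches decoder blocks; the function name and the 0.9 scale are made up for illustration:

def scale_decoder_attn1(n, extra_options):
    # Hypothetical attn1 output patch: only act on UNet "output" (decoder) blocks,
    # identified via the ("output", id) tag set in UNetModel.forward above.
    block = extra_options.get("block")
    if block is not None and block[0] == "output":
        n = n * 0.9
    return n

# registered via the new hook: patcher.set_model_attn1_output_patch(scale_decoder_attn1)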
comfy/sd.py (27 lines changed)
@@ -315,9 +315,6 @@ class ModelPatcher:
         n.model_keys = self.model_keys
         return n

-    def set_model_tomesd(self, ratio):
-        self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
-
     def set_model_sampler_cfg_function(self, sampler_cfg_function):
         if len(inspect.signature(sampler_cfg_function).parameters) == 3:
             self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
@@ -330,12 +327,29 @@ class ModelPatcher:
             to["patches"] = {}
         to["patches"][name] = to["patches"].get(name, []) + [patch]

+    def set_model_patch_replace(self, patch, name, block_name, number):
+        to = self.model_options["transformer_options"]
+        if "patches_replace" not in to:
+            to["patches_replace"] = {}
+        if name not in to["patches_replace"]:
+            to["patches_replace"][name] = {}
+        to["patches_replace"][name][(block_name, number)] = patch
+
     def set_model_attn1_patch(self, patch):
         self.set_model_patch(patch, "attn1_patch")

     def set_model_attn2_patch(self, patch):
         self.set_model_patch(patch, "attn2_patch")

+    def set_model_attn1_replace(self, patch, block_name, number):
+        self.set_model_patch_replace(patch, "attn1", block_name, number)
+
+    def set_model_attn2_replace(self, patch, block_name, number):
+        self.set_model_patch_replace(patch, "attn2", block_name, number)
+
+    def set_model_attn1_output_patch(self, patch):
+        self.set_model_patch(patch, "attn1_output_patch")
+
     def set_model_attn2_output_patch(self, patch):
         self.set_model_patch(patch, "attn2_output_patch")

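Note (illustrative, not part of the commit): set_model_patch_replace keys each replacement by (block_name, number), which the transformer block matches against its ("input"/"middle"/"output", id) tag. A hedged usage sketch on a cloned patcher, reusing the hypothetical my_attn1_replace from above; model_patcher is assumed to be an existing ModelPatcher:

patched = model_patcher.clone()
# replace self-attention for every transformer sub-block inside input block 4
patched.set_model_attn1_replace(my_attn1_replace, "input", 4)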
@@ -348,6 +362,13 @@ class ModelPatcher:
                 for i in range(len(patch_list)):
                     if hasattr(patch_list[i], "to"):
                         patch_list[i] = patch_list[i].to(device)
+        if "patches_replace" in to:
+            patches = to["patches_replace"]
+            for name in patches:
+                patch_list = patches[name]
+                for k in patch_list:
+                    if hasattr(patch_list[k], "to"):
+                        patch_list[k] = patch_list[k].to(device)

     def model_dtype(self):
         return self.model.get_dtype()
comfy_extras/nodes_tomesd.py

@@ -142,3 +142,36 @@ def get_functions(x, ratio, original_shape):

     nothing = lambda y: y
     return nothing, nothing
+
+
+
+class TomePatchModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "_for_testing"
+
+    def patch(self, model, ratio):
+        self.u = None
+        def tomesd_m(q, k, v, extra_options):
+            #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
+            #however from my basic testing it seems that using q instead gives better results
+            m, self.u = get_functions(q, ratio, extra_options["original_shape"])
+            return m(q), k, v
+        def tomesd_u(n, extra_options):
+            return self.u(n)
+
+        m = model.clone()
+        m.set_model_attn1_patch(tomesd_m)
+        m.set_model_attn1_output_patch(tomesd_u)
+        return (m, )
+
+
+NODE_CLASS_MAPPINGS = {
+    "TomePatchModel": TomePatchModel,
+}
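Note (illustrative, not part of the commit): the relocated node now builds on the generic attn1 hooks (set_model_attn1_patch / set_model_attn1_output_patch) instead of the removed set_model_tomesd. A hedged sketch of calling it directly, assuming model is an already-loaded MODEL (ModelPatcher):

# merge roughly 30% of the attn1 query tokens, unmerging them on the attention output
(patched_model,) = TomePatchModel().patch(model, ratio=0.3)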
nodes.py (18 lines changed)
@@ -437,22 +437,6 @@ class LoraLoader:
         model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
         return (model_lora, clip_lora)

-class TomePatchModel:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
-                              }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "patch"
-
-    CATEGORY = "_for_testing"
-
-    def patch(self, model, ratio):
-        m = model.clone()
-        m.set_model_tomesd(ratio)
-        return (m, )
-
 class VAELoader:
     @classmethod
     def INPUT_TYPES(s):
@@ -1341,7 +1325,6 @@ NODE_CLASS_MAPPINGS = {
     "CLIPVisionLoader": CLIPVisionLoader,
     "VAEDecodeTiled": VAEDecodeTiled,
     "VAEEncodeTiled": VAEEncodeTiled,
-    "TomePatchModel": TomePatchModel,
     "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
     "GLIGENLoader": GLIGENLoader,
     "GLIGENTextBoxApply": GLIGENTextBoxApply,
@@ -1466,4 +1449,5 @@ def init_custom_nodes():
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
     load_custom_nodes()