diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 7908d1313..ede506463 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -471,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): if skip_reshape: b, _, _, dim_head = q.shape - tensor_layout="HND" + tensor_layout = "HND" else: b, _, dim_head = q.shape dim_head //= heads @@ -479,7 +479,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= lambda t: t.view(b, -1, heads, dim_head), (q, k, v), ) - tensor_layout="NHD" + tensor_layout = "NHD" if mask is not None: # add a batch dimension if there isn't already one @@ -489,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= if mask.ndim == 3: mask = mask.unsqueeze(1) - out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + try: + out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + except Exception as e: + logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) + if tensor_layout == "NHD": + q, k, v = map( + lambda t: t.transpose(1, 2), + (q, k, v), + ) + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape) + if tensor_layout == "HND": if not skip_output_reshape: out = ( diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py index 3e37f70d4..dc3411947 100644 --- a/comfy_extras/nodes_model_merging_model_specific.py +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -244,6 +244,30 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks): return {"required": arg_dict} +class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb." + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["patch_embedding."] = argument + arg_dict["time_embedding."] = argument + arg_dict["time_projection."] = argument + arg_dict["text_embedding."] = argument + arg_dict["img_emb."] = argument + + for i in range(40): + arg_dict["blocks.{}.".format(i)] = argument + + arg_dict["head."] = argument + + return {"required": arg_dict} + NODE_CLASS_MAPPINGS = { "ModelMergeSD1": ModelMergeSD1, "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks @@ -256,4 +280,5 @@ NODE_CLASS_MAPPINGS = { "ModelMergeLTXV": ModelMergeLTXV, "ModelMergeCosmos7B": ModelMergeCosmos7B, "ModelMergeCosmos14B": ModelMergeCosmos14B, + "ModelMergeWAN2_1": ModelMergeWAN2_1, }