Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-04-19 10:53:29 +00:00
Support loading diffusers SD3 model format with UNETLoader node.
Commit 0d6a57938e (parent b08a9dd04b)
comfy/model_detection.py
@@ -1,7 +1,9 @@
 import comfy.supported_models
 import comfy.supported_models_base
+import comfy.utils
 import math
 import logging
+import torch
 
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
@@ -431,3 +433,38 @@ def model_config_from_diffusers_unet(state_dict):
     if unet_config is not None:
         return model_config_from_unet_config(unet_config)
     return None
+
+def convert_diffusers_mmdit(state_dict, output_prefix=""):
+    depth = count_blocks(state_dict, 'transformer_blocks.{}.')
+    if depth > 0:
+        out_sd = {}
+        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth}, output_prefix=output_prefix)
+        for k in sd_map:
+            weight = state_dict.get(k, None)
+            if weight is not None:
+                t = sd_map[k]
+
+                if not isinstance(t, str):
+                    if len(t) > 2:
+                        fun = t[2]
+                    else:
+                        fun = lambda a: a
+                    offset = t[1]
+                    if offset is not None:
+                        old_weight = out_sd.get(t[0], None)
+                        if old_weight is None:
+                            old_weight = torch.empty_like(weight)
+                            old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
+
+                        w = old_weight.narrow(offset[0], offset[1], offset[2])
+                    else:
+                        old_weight = weight
+                        w = weight
+                    w[:] = fun(weight)
+                    t = t[0]
+                    out_sd[t] = old_weight
+                else:
+                    out_sd[t] = weight
+                state_dict.pop(k)
+
+        return out_sd
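Note: convert_diffusers_mmdit walks the key map returned by comfy.utils.mmdit_to_diffusers, where a plain string value means a simple rename and a tuple value means (target_key, offset_or_None, optional_transform). The offset path exists because the reference MMDiT fuses Q/K/V into one qkv tensor while diffusers keeps them as separate projections; narrow() returns a view, so writing through it fills one third of the fused tensor in place. A minimal standalone sketch of that fusing step (toy key names and sizes, not the commit's code):

import torch

# Toy map in the same shape the converter consumes: (target, (dim, start, length)).
hidden = 4
sd_map = {
    "q.weight": ("qkv.weight", (0, 0 * hidden, hidden)),
    "k.weight": ("qkv.weight", (0, 1 * hidden, hidden)),
    "v.weight": ("qkv.weight", (0, 2 * hidden, hidden)),
}
state_dict = {k: torch.randn(hidden, hidden) for k in sd_map}

out_sd = {}
for k, (target, offset) in sd_map.items():
    weight = state_dict[k]
    old_weight = out_sd.get(target)
    if old_weight is None:
        # Allocate the fused tensor: 3x the source along dim 0, like the
        # repeat([3, 1, ...]) call in convert_diffusers_mmdit.
        old_weight = torch.empty_like(weight).repeat(3, 1)
    # narrow() is a view, so this writes the q/k/v slice into the fused tensor.
    old_weight.narrow(offset[0], offset[1], offset[2])[:] = weight
    out_sd[target] = old_weight

assert torch.equal(out_sd["qkv.weight"][hidden:2 * hidden], state_dict["k.weight"])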
comfy/sd.py
@@ -568,7 +568,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
     unet_dtype = model_management.unet_dtype(model_params=parameters)
     load_device = model_management.get_torch_device()
 
-    if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
+    if 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
+        new_sd = model_detection.convert_diffusers_mmdit(sd, "")
+        if new_sd is None:
+            return None
+        model_config = model_detection.model_config_from_unet(new_sd, "")
+        if model_config is None:
+            return None
+    elif "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
         model_config = model_detection.model_config_from_unet(sd, "")
         if model_config is None:
             return None
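Note: the loader distinguishes formats by probing for keys unique to each layout, and the SD3 check runs first so an MMDiT state dict is converted before the generic paths see it. A compact restatement of the branching; detect_unet_format is a hypothetical helper for illustration, the real checks are inline in load_unet_state_dict as shown above:

def detect_unet_format(sd):
    # Key names are the real sentinels used above; each is unique to one layout.
    if 'transformer_blocks.0.attn.add_q_proj.weight' in sd:  # joint-attention context proj, MMDiT only
        return "diffusers SD3: convert with convert_diffusers_mmdit"
    elif "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd:
        return "ldm or stable cascade: load as-is"
    return "diffusers unet: fall through to the diffusers conversion path"

print(detect_unet_format({'transformer_blocks.0.attn.add_q_proj.weight': 0}))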
comfy/utils.py
@@ -249,6 +249,11 @@ def unet_to_diffusers(unet_config):
 
     return diffusers_unet_map
 
+def swap_scale_shift(weight):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
 MMDIT_MAP_BASIC = {
     ("context_embedder.bias", "context_embedder.bias"),
     ("context_embedder.weight", "context_embedder.weight"),
@@ -263,8 +268,8 @@ MMDIT_MAP_BASIC = {
     ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
     ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
     ("pos_embed", "pos_embed.pos_embed"),
-    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias"),
-    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight"),
+    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
+    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
     ("final_layer.linear.bias", "proj_out.bias"),
     ("final_layer.linear.weight", "proj_out.weight"),
 }
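Note: swap_scale_shift exists because the reference MMDiT's adaLN_modulation layers emit (shift, scale) while the corresponding diffusers layers store (scale, shift), so the two halves of the weight and bias must be swapped along dim 0 during conversion. The function is its own inverse, so it works in either conversion direction. A tiny standalone check:

import torch

def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

w = torch.tensor([1., 2., 3., 4.])            # halves: [1, 2] and [3, 4]
print(swap_scale_shift(w))                    # tensor([3., 4., 1., 2.])
assert torch.equal(swap_scale_shift(swap_scale_shift(w)), w)  # involution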
comfy/utils.py (continued)
@@ -313,7 +318,14 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
         for k in MMDIT_MAP_BLOCK:
             key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
 
-    for k in MMDIT_MAP_BASIC:
-        key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+    map_basic = MMDIT_MAP_BASIC.copy()
+    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
+    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))
+
+    for k in map_basic:
+        if len(k) > 2:
+            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
+        else:
+            key_map[k[1]] = "{}{}".format(output_prefix, k[0])
 
     return key_map
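Note: two things change in this hunk. First, key_map values are no longer always strings: a weight needing swap_scale_shift gets a (target_key, None, transform) tuple, matching what convert_diffusers_mmdit expects (the None slot is the unused offset). Second, the two map_basic.add() calls target only the last joint block (depth - 1): in the reference SD3 implementation the final context block is a reduced, output-only block whose modulation produces just a shift/scale pair, stored in the opposite order from diffusers' norm1_context (this rationale is inferred from the SD3 reference code, not stated in the commit). A toy illustration of the two entry shapes:

# Toy dict, not the commit's code; swap_scale_shift here is a stand-in
# for the real helper defined in the previous hunk.
swap_scale_shift = lambda w: w
key_map = {
    "proj_out.bias": "final_layer.linear.bias",  # plain rename: diffusers key -> mmdit key
    "norm_out.linear.bias": ("final_layer.adaLN_modulation.1.bias", None, swap_scale_shift),
}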
comfy_extras/nodes_model_merging_model_specific.py
@@ -52,9 +52,32 @@ class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
 
         return {"required": arg_dict}
 
+class ModelMergeSD3(comfy_extras.nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                     "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["pos_embed."] = argument
+        arg_dict["x_embedder."] = argument
+        arg_dict["context_embedder."] = argument
+        arg_dict["y_embedder."] = argument
+        arg_dict["t_embedder."] = argument
+
+        for i in range(38):
+            arg_dict["joint_blocks.{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
 
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSD1": ModelMergeSD1,
     "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
     "ModelMergeSDXL": ModelMergeSDXL,
+    "ModelMergeSD3": ModelMergeSD3,
 }
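Note: ModelMergeSD3 exposes one merge-ratio slider per weight-name prefix, and the base ModelMergeBlocks class applies each ratio to the weights whose names start with that prefix, so e.g. the "joint_blocks.3." ratio scales everything in joint block 3. range(38) presumably covers the deepest SD3 variant; smaller SD3 checkpoints simply have no keys matching the higher block indices, so those sliders are no-ops. A quick count of the generated inputs (illustrative only):

# 2 MODEL inputs plus one FLOAT merge ratio per prefix below.
prefixes = (["pos_embed.", "x_embedder.", "context_embedder.", "y_embedder.", "t_embedder."]
            + ["joint_blocks.{}.".format(i) for i in range(38)]
            + ["final_layer."])
print(len(prefixes))  # 44 ratio inputs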