Support SD3.5 medium diffusers format weights and loras.

This commit is contained in:
comfyanonymous 2024-10-30 04:24:00 -04:00
parent 65a8659182
commit 09fdb2b269
2 changed files with 19 additions and 5 deletions

View File

@ -540,7 +540,11 @@ def model_config_from_diffusers_unet(state_dict):
def convert_diffusers_mmdit(state_dict, output_prefix=""): def convert_diffusers_mmdit(state_dict, output_prefix=""):
out_sd = {} out_sd = {}
if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
elif 'single_transformer_blocks.0.attn.norm_q.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0] hidden_size = state_dict["x_embedder.bias"].shape[0]
@ -549,10 +553,6 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
else: else:
return None return None

View File

@ -316,10 +316,18 @@ MMDIT_MAP_BLOCK = {
("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"), ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
("context_block.mlp.fc2.bias", "ff_context.net.2.bias"), ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
("context_block.mlp.fc2.weight", "ff_context.net.2.weight"), ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"),
("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"),
("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"), ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"), ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
("x_block.attn.proj.bias", "attn.to_out.0.bias"), ("x_block.attn.proj.bias", "attn.to_out.0.bias"),
("x_block.attn.proj.weight", "attn.to_out.0.weight"), ("x_block.attn.proj.weight", "attn.to_out.0.weight"),
("x_block.attn.ln_q.weight", "attn.norm_q.weight"),
("x_block.attn.ln_k.weight", "attn.norm_k.weight"),
("x_block.attn2.proj.bias", "attn2.to_out.0.bias"),
("x_block.attn2.proj.weight", "attn2.to_out.0.weight"),
("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"),
("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"),
("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"), ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"), ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
("x_block.mlp.fc2.bias", "ff.net.2.bias"), ("x_block.mlp.fc2.bias", "ff.net.2.bias"),
@ -349,6 +357,12 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset)) key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
k = "{}.attn2.".format(block_from)
qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end)
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
for k in MMDIT_MAP_BLOCK: for k in MMDIT_MAP_BLOCK:
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0]) key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])