Mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix issues with #4302 and support loading diffusers format flux.
parent 1765f1c60c
commit 75b9b55b22
@@ -495,7 +495,12 @@ def model_config_from_diffusers_unet(state_dict):
 def convert_diffusers_mmdit(state_dict, output_prefix=""):
     out_sd = {}
 
-    if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+    if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
+        depth = count_blocks(state_dict, 'transformer_blocks.{}.')
+        depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+        hidden_size = state_dict["x_embedder.bias"].shape[0]
+        sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
+    elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
         num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
         sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
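
The new branch detects diffusers-format Flux by the `attn.norm_added_k` key (absent from SD3 checkpoints) and derives the model dimensions directly from the state dict. A minimal sketch of the `count_blocks` idiom the detection relies on (illustrative semantics, not ComfyUI's exact implementation):

    # Sketch: count numbered block prefixes like 'transformer_blocks.0.',
    # 'transformer_blocks.1.', ... until the first missing index.
    def count_blocks(state_dict, prefix_template):
        i = 0
        while any(k.startswith(prefix_template.format(i)) for k in state_dict):
            i += 1
        return i

    # Hypothetical toy state dict: two double blocks, one single block.
    toy_sd = {
        "transformer_blocks.0.attn.norm_added_k.weight": 0,
        "transformer_blocks.1.attn.norm_added_k.weight": 0,
        "single_transformer_blocks.0.proj_out.weight": 0,
    }
    assert count_blocks(toy_sd, "transformer_blocks.{}.") == 2
    assert count_blocks(toy_sd, "single_transformer_blocks.{}.") == 1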
@ -521,7 +526,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
|||||||
old_weight = out_sd.get(t[0], None)
|
old_weight = out_sd.get(t[0], None)
|
||||||
if old_weight is None:
|
if old_weight is None:
|
||||||
old_weight = torch.empty_like(weight)
|
old_weight = torch.empty_like(weight)
|
||||||
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
|
if old_weight.shape[offset[0]] < offset[1] + offset[2]:
|
||||||
|
exp = list(weight.shape)
|
||||||
|
exp[offset[0]] = offset[1] + offset[2]
|
||||||
|
new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
|
||||||
|
new[:old_weight.shape[0]] = old_weight
|
||||||
|
old_weight = new
|
||||||
|
|
||||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||||
else:
|
else:
|
||||||
|
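
The removed `repeat([3] + ...)` line assumed every fused destination tensor is exactly three stacked copies of one source shape (true for fused q/k/v, but not in general, e.g. where differently sized pieces are fused into one tensor). The replacement grows the destination on demand before `narrow` writes each slice. A self-contained sketch of that pattern (the helper name `write_slice` is hypothetical, not ComfyUI code):

    import torch

    # Hypothetical helper mirroring the grow-then-narrow pattern above.
    def write_slice(out_sd, key, weight, offset):
        dim, start, length = offset
        old_weight = out_sd.get(key, None)
        if old_weight is None:
            old_weight = torch.empty_like(weight)
        if old_weight.shape[dim] < start + length:
            exp = list(weight.shape)
            exp[dim] = start + length
            new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
            new[:old_weight.shape[0]] = old_weight  # keep slices written so far
            old_weight = new
        old_weight.narrow(dim, start, length).copy_(weight)
        out_sd[key] = old_weight

    out_sd = {}
    h = 8  # hypothetical hidden size
    for i in range(3):  # q, k, v pieces fused along dim 0
        write_slice(out_sd, "qkv.weight", torch.randn(h, h), (0, i * h, h))
    assert out_sd["qkv.weight"].shape == (3 * h, h)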
@ -474,6 +474,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
|||||||
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
|
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
|
||||||
"ff_context.net.2.weight": "txt_mlp.2.weight",
|
"ff_context.net.2.weight": "txt_mlp.2.weight",
|
||||||
"ff_context.net.2.bias": "txt_mlp.2.bias",
|
"ff_context.net.2.bias": "txt_mlp.2.bias",
|
||||||
|
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||||
|
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||||
|
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||||
|
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
||||||
}
|
}
|
||||||
|
|
||||||
for k in block_map:
|
for k in block_map:
|
||||||
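
These four entries map diffusers' per-block QK-norm scales onto the img/txt attention norms of the Flux double blocks. Each `block_map` entry is expanded once per block with numbered prefixes on both sides; a sketch of that expansion, assuming the `transformer_blocks.{}.` to `double_blocks.{}.` prefix pair used by `flux_to_diffusers` (a depth of 2 is hypothetical):

    block_map = {
        "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
        "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
    }
    key_map = {}
    for i in range(2):  # hypothetical depth
        for dk, ck in block_map.items():
            key_map["transformer_blocks.{}.{}".format(i, dk)] = "double_blocks.{}.{}".format(i, ck)
    assert key_map["transformer_blocks.1.attn.norm_q.weight"] == "double_blocks.1.img_attn.norm.query_norm.scale"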
@ -496,6 +500,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
|||||||
"norm.linear.bias": "modulation.lin.bias",
|
"norm.linear.bias": "modulation.lin.bias",
|
||||||
"proj_out.weight": "linear2.weight",
|
"proj_out.weight": "linear2.weight",
|
||||||
"proj_out.bias": "linear2.bias",
|
"proj_out.bias": "linear2.bias",
|
||||||
|
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||||
|
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||||
}
|
}
|
||||||
|
|
||||||
for k in block_map:
|
for k in block_map:
|
||||||
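
Same idea for the single blocks, whose QK norm sits directly under `norm.`. The expansion there presumably pairs `single_transformer_blocks.{}.` with `single_blocks.{}.` (prefixes assumed from the surrounding function):

    block_map = {
        "attn.norm_q.weight": "norm.query_norm.scale",
        "attn.norm_k.weight": "norm.key_norm.scale",
    }
    # One block shown; prefix pair assumed, not quoted from the diff.
    key_map = {"single_transformer_blocks.0." + dk: "single_blocks.0." + ck
               for dk, ck in block_map.items()}
    assert key_map["single_transformer_blocks.0.attn.norm_k.weight"] == "single_blocks.0.norm.key_norm.scale"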
@ -514,18 +520,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
|||||||
("txt_in.weight", "context_embedder.weight"),
|
("txt_in.weight", "context_embedder.weight"),
|
||||||
("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
|
("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
|
||||||
("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
|
("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
|
||||||
("vector_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
|
("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
||||||
("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
||||||
("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
|
("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
|
||||||
("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
|
("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
|
||||||
("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
|
("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
|
||||||
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
||||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||||
|
|
||||||
# TODO: the values of these weights are different in Diffusers
|
|
||||||
("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
|
|
||||||
("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for k in MAP_BASIC:
|
for k in MAP_BASIC:
|
||||||
|
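
Two fixes here: `vector_in.out_layer.bias` now targets `text_embedder.linear_2.bias` rather than `timestep_embedder.linear_2.bias`, and `guidance_in.out_layer.bias` targets `linear_2` rather than `linear_1`. The deleted TODO entries had duplicated both checkpoint keys with different diffusers keys, routing two source tensors to one destination. A minimal repro of that failure mode (placeholder values, not real tensors):

    # The old set paired one checkpoint key with two diffusers keys:
    MAP_BASIC = {
        ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
        ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
    }
    state_dict = {
        "time_text_embed.guidance_embedder.linear_1.bias": "linear_1 tensor",
        "time_text_embed.guidance_embedder.linear_2.bias": "linear_2 tensor",
    }
    out_sd = {}
    for comfy_key, diffusers_key in MAP_BASIC:
        out_sd[comfy_key] = state_dict[diffusers_key]
    # Whichever entry iterates last wins; set order is arbitrary, so the
    # loaded bias could silently come from the wrong layer.
    print(out_sd["guidance_in.out_layer.bias"])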