From 75b9b55b221fc95f7137a91e2349e45693e342b8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 10 Aug 2024 21:28:24 -0400 Subject: [PATCH] Fix issues with #4302 and support loading diffusers format flux. --- comfy/model_detection.py | 14 ++++++++++++-- comfy/utils.py | 14 ++++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 15e6b735..c05975cc 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -495,7 +495,12 @@ def model_config_from_diffusers_unet(state_dict): def convert_diffusers_mmdit(state_dict, output_prefix=""): out_sd = {} - if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3 + if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux + depth = count_blocks(state_dict, 'transformer_blocks.{}.') + depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') + hidden_size = state_dict["x_embedder.bias"].shape[0] + sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix) + elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3 num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) @@ -521,7 +526,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): old_weight = out_sd.get(t[0], None) if old_weight is None: old_weight = torch.empty_like(weight) - old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1)) + if old_weight.shape[offset[0]] < offset[1] + offset[2]: + exp = list(weight.shape) + exp[offset[0]] = offset[1] + offset[2] + new = torch.empty(exp, device=weight.device, dtype=weight.dtype) + new[:old_weight.shape[0]] = old_weight + old_weight = new w = old_weight.narrow(offset[0], offset[1], offset[2]) else: diff --git a/comfy/utils.py b/comfy/utils.py index a1e9213f..d0d410d9 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -474,6 +474,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "ff_context.net.0.proj.bias": "txt_mlp.0.bias", "ff_context.net.2.weight": "txt_mlp.2.weight", "ff_context.net.2.bias": "txt_mlp.2.bias", + "attn.norm_q.weight": "img_attn.norm.query_norm.scale", + "attn.norm_k.weight": "img_attn.norm.key_norm.scale", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", } for k in block_map: @@ -496,6 +500,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "norm.linear.bias": "modulation.lin.bias", "proj_out.weight": "linear2.weight", "proj_out.bias": "linear2.bias", + "attn.norm_q.weight": "norm.query_norm.scale", + "attn.norm_k.weight": "norm.key_norm.scale", } for k in block_map: @@ -514,18 +520,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): ("txt_in.weight", "context_embedder.weight"), ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"), ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"), - ("vector_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"), + ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"), ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"), ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"), ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"), - ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"), + ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"), ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"), ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), - - # TODO: the values of these weights are different in Diffusers - ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"), - ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"), } for k in MAP_BASIC: