From 1765f1c60c862f8fc9f3346384a37c5d6d13d35b Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 10 Aug 2024 19:26:41 -0600 Subject: [PATCH] FLUX: Added full diffusers mapping for FLUX.1 schnell and dev. Adds full LoRA support from diffusers LoRAs. (#4302) --- comfy/utils.py | 53 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index e6736dbd..a1e9213f 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -457,8 +457,23 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) - block_map = {"attn.to_out.0.weight": "img_attn.proj.weight", - "attn.to_out.0.bias": "img_attn.proj.bias", + block_map = { + "attn.to_out.0.weight": "img_attn.proj.weight", + "attn.to_out.0.bias": "img_attn.proj.bias", + "norm1.linear.weight": "img_mod.lin.weight", + "norm1.linear.bias": "img_mod.lin.bias", + "norm1_context.linear.weight": "txt_mod.lin.weight", + "norm1_context.linear.bias": "txt_mod.lin.bias", + "attn.to_add_out.weight": "txt_attn.proj.weight", + "attn.to_add_out.bias": "txt_attn.proj.bias", + "ff.net.0.proj.weight": "img_mlp.0.weight", + "ff.net.0.proj.bias": "img_mlp.0.bias", + "ff.net.2.weight": "img_mlp.2.weight", + "ff.net.2.bias": "img_mlp.2.bias", + "ff_context.net.0.proj.weight": "txt_mlp.0.weight", + "ff_context.net.0.proj.bias": "txt_mlp.0.bias", + "ff_context.net.2.weight": "txt_mlp.2.weight", + "ff_context.net.2.bias": "txt_mlp.2.bias", } for k in block_map: @@ -474,15 +489,43 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) - key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size)) + key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4)) - block_map = {#TODO + block_map = { + "norm.linear.weight": "modulation.lin.weight", + "norm.linear.bias": "modulation.lin.bias", + "proj_out.weight": "linear2.weight", + "proj_out.bias": "linear2.bias", } for k in block_map: key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) - MAP_BASIC = { #TODO + MAP_BASIC = { + ("final_layer.linear.bias", "proj_out.bias"), + ("final_layer.linear.weight", "proj_out.weight"), + ("img_in.bias", "x_embedder.bias"), + ("img_in.weight", "x_embedder.weight"), + ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"), + ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"), + ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"), + ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"), + ("txt_in.bias", "context_embedder.bias"), + ("txt_in.weight", "context_embedder.weight"), + ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"), + ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"), + ("vector_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"), + ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"), + ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"), + ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"), + ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"), + ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"), + ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), + ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), + + # TODO: the values of these weights are different in Diffusers + ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"), + ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"), } for k in MAP_BASIC: