From acf95191ff742edbe3dfe8ffe9e0c372d7773f2f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Jul 2023 21:10:12 -0400 Subject: [PATCH] Properly support SDXL diffusers loras for unet. --- comfy/sd.py | 127 +++++-------------------------------------------- comfy/utils.py | 117 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 116 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 5001d497..360f2962 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -59,35 +59,6 @@ LORA_CLIP_MAP = { "self_attn.out_proj": "self_attn_out_proj", } -LORA_UNET_MAP_ATTENTIONS = { - "proj_in": "proj_in", - "proj_out": "proj_out", -} - -transformer_lora_blocks = { - "transformer_blocks.{}.attn1.to_q": "transformer_blocks_{}_attn1_to_q", - "transformer_blocks.{}.attn1.to_k": "transformer_blocks_{}_attn1_to_k", - "transformer_blocks.{}.attn1.to_v": "transformer_blocks_{}_attn1_to_v", - "transformer_blocks.{}.attn1.to_out.0": "transformer_blocks_{}_attn1_to_out_0", - "transformer_blocks.{}.attn2.to_q": "transformer_blocks_{}_attn2_to_q", - "transformer_blocks.{}.attn2.to_k": "transformer_blocks_{}_attn2_to_k", - "transformer_blocks.{}.attn2.to_v": "transformer_blocks_{}_attn2_to_v", - "transformer_blocks.{}.attn2.to_out.0": "transformer_blocks_{}_attn2_to_out_0", - "transformer_blocks.{}.ff.net.0.proj": "transformer_blocks_{}_ff_net_0_proj", - "transformer_blocks.{}.ff.net.2": "transformer_blocks_{}_ff_net_2", -} - -for i in range(10): - for k in transformer_lora_blocks: - LORA_UNET_MAP_ATTENTIONS[k.format(i)] = transformer_lora_blocks[k].format(i) - - -LORA_UNET_MAP_RESNET = { - "in_layers.2": "resnets_{}_conv1", - "emb_layers.1": "resnets_{}_time_emb_proj", - "out_layers.3": "resnets_{}_conv2", - "skip_connection": "resnets_{}_conv_shortcut" -} def load_lora(lora, to_load): patch_dict = {} @@ -188,39 +159,9 @@ def load_lora(lora, to_load): print("lora key not loaded", x) return patch_dict -def model_lora_keys(model, key_map={}): +def model_lora_keys_clip(model, key_map={}): sdk = model.state_dict().keys() - counter = 0 - for b in range(12): - tk = "diffusion_model.input_blocks.{}.1".format(b) - up_counter = 0 - for c in LORA_UNET_MAP_ATTENTIONS: - k = "{}.{}.weight".format(tk, c) - if k in sdk: - lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c]) - key_map[lora_key] = k - up_counter += 1 - if up_counter >= 4: - counter += 1 - for c in LORA_UNET_MAP_ATTENTIONS: - k = "diffusion_model.middle_block.1.{}.weight".format(c) - if k in sdk: - lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c]) - key_map[lora_key] = k - counter = 3 - for b in range(12): - tk = "diffusion_model.output_blocks.{}.1".format(b) - up_counter = 0 - for c in LORA_UNET_MAP_ATTENTIONS: - k = "{}.{}.weight".format(tk, c) - if k in sdk: - lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c]) - key_map[lora_key] = k - up_counter += 1 - if up_counter >= 4: - counter += 1 - counter = 0 text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False for b in range(32): @@ -244,69 +185,23 @@ def model_lora_keys(model, key_map={}): lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner key_map[lora_key] = k + return key_map - #Locon stuff - ds_counter = 0 - counter = 0 - for b in range(12): - tk = "diffusion_model.input_blocks.{}.0".format(b) - key_in = False - for c in LORA_UNET_MAP_RESNET: - k = "{}.{}.weight".format(tk, c) - if k in sdk: - lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2)) - key_map[lora_key] = k - key_in = True - for bb in range(3): - k = "{}.{}.op.weight".format(tk[:-2], bb) - if k in sdk: - lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter) - key_map[lora_key] = k - ds_counter += 1 - if key_in: - counter += 1 - - counter = 0 - for b in range(3): - tk = "diffusion_model.middle_block.{}".format(b) - key_in = False - for c in LORA_UNET_MAP_RESNET: - k = "{}.{}.weight".format(tk, c) - if k in sdk: - lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter)) - key_map[lora_key] = k - key_in = True - if key_in: - counter += 1 - - counter = 0 - us_counter = 0 - for b in range(12): - tk = "diffusion_model.output_blocks.{}.0".format(b) - key_in = False - for c in LORA_UNET_MAP_RESNET: - k = "{}.{}.weight".format(tk, c) - if k in sdk: - lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3)) - key_map[lora_key] = k - key_in = True - for bb in range(3): - k = "{}.{}.conv.weight".format(tk[:-2], bb) - if k in sdk: - lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter) - key_map[lora_key] = k - us_counter += 1 - if key_in: - counter += 1 +def model_lora_keys_unet(model, key_map={}): + sdk = model.state_dict().keys() for k in sdk: if k.startswith("diffusion_model.") and k.endswith(".weight"): key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") key_map["lora_unet_{}".format(key_lora)] = k + diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config) + for k in diffusers_keys: + if k.endswith(".weight"): + key_lora = k[:-len(".weight")].replace(".", "_") + key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k]) return key_map - class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0): self.size = size @@ -506,8 +401,8 @@ class ModelPatcher: self.backup = {} def load_lora_for_models(model, clip, lora, strength_model, strength_clip): - key_map = model_lora_keys(model.model) - key_map = model_lora_keys(clip.cond_stage_model, key_map) + key_map = model_lora_keys_unet(model.model) + key_map = model_lora_keys_clip(clip.cond_stage_model, key_map) loaded = load_lora(lora, key_map) new_modelpatcher = model.clone() k = new_modelpatcher.add_patches(loaded, strength_model) diff --git a/comfy/utils.py b/comfy/utils.py index b6434905..25ccd944 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -70,6 +70,123 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +UNET_MAP_ATTENTIONS = { + "proj_in.weight", + "proj_in.bias", + "proj_out.weight", + "proj_out.bias", + "norm.weight", + "norm.bias", +} + +TRANSFORMER_BLOCKS = { + "norm1.weight", + "norm1.bias", + "norm2.weight", + "norm2.bias", + "norm3.weight", + "norm3.bias", + "attn1.to_q.weight", + "attn1.to_k.weight", + "attn1.to_v.weight", + "attn1.to_out.0.weight", + "attn1.to_out.0.bias", + "attn2.to_q.weight", + "attn2.to_k.weight", + "attn2.to_v.weight", + "attn2.to_out.0.weight", + "attn2.to_out.0.bias", + "ff.net.0.proj.weight", + "ff.net.0.proj.bias", + "ff.net.2.weight", + "ff.net.2.bias", +} + +UNET_MAP_RESNET = { + "in_layers.2.weight": "conv1.weight", + "in_layers.2.bias": "conv1.bias", + "emb_layers.1.weight": "time_emb_proj.weight", + "emb_layers.1.bias": "time_emb_proj.bias", + "out_layers.3.weight": "conv2.weight", + "out_layers.3.bias": "conv2.bias", + "skip_connection.weight": "conv_shortcut.weight", + "skip_connection.bias": "conv_shortcut.bias", + "in_layers.0.weight": "norm1.weight", + "in_layers.0.bias": "norm1.bias", + "out_layers.0.weight": "norm2.weight", + "out_layers.0.bias": "norm2.bias", +} + +def unet_to_diffusers(unet_config): + num_res_blocks = unet_config["num_res_blocks"] + attention_resolutions = unet_config["attention_resolutions"] + channel_mult = unet_config["channel_mult"] + transformer_depth = unet_config["transformer_depth"] + num_blocks = len(channel_mult) + if not isinstance(num_res_blocks, list): + num_res_blocks = [num_res_blocks] * num_blocks + + transformers_per_layer = [] + res = 1 + for i in range(num_blocks): + transformers = 0 + if res in attention_resolutions: + transformers = transformer_depth[i] + transformers_per_layer.append(transformers) + res *= 2 + + transformers_mid = unet_config.get("transformer_depth_middle", transformers_per_layer[-1]) + + diffusers_unet_map = {} + for x in range(num_blocks): + n = 1 + (num_res_blocks[x] + 1) * x + for i in range(num_res_blocks[x]): + for b in UNET_MAP_RESNET: + diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b) + if transformers_per_layer[x] > 0: + for b in UNET_MAP_ATTENTIONS: + diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b) + for t in range(transformers_per_layer[x]): + for b in TRANSFORMER_BLOCKS: + diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) + n += 1 + for k in ["weight", "bias"]: + diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k) + + i = 0 + for b in UNET_MAP_ATTENTIONS: + diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b) + for t in range(transformers_mid): + for b in TRANSFORMER_BLOCKS: + diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b) + + for i, n in enumerate([0, 2]): + for b in UNET_MAP_RESNET: + diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b) + + num_res_blocks = list(reversed(num_res_blocks)) + transformers_per_layer = list(reversed(transformers_per_layer)) + for x in range(num_blocks): + n = (num_res_blocks[x] + 1) * x + l = num_res_blocks[x] + 1 + for i in range(l): + c = 0 + for b in UNET_MAP_RESNET: + diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b) + c += 1 + if transformers_per_layer[x] > 0: + c += 1 + for b in UNET_MAP_ATTENTIONS: + diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b) + for t in range(transformers_per_layer[x]): + for b in TRANSFORMER_BLOCKS: + diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) + if i == l - 1: + for k in ["weight", "bias"]: + diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k) + n += 1 + return diffusers_unet_map + def convert_sd_to(state_dict, dtype): keys = list(state_dict.keys()) for k in keys: