From 678105fade833ff36664432ce97d3f9f73f725bd Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 5 Feb 2023 01:54:09 -0500
Subject: [PATCH] SD2.x CLIP support for Loras.

---
 comfy/sd.py | 46 +++++++++++++++++++++++++++++++++++-----------
 1 file changed, 35 insertions(+), 11 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index bb7023f7..23078e2e 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -62,6 +62,12 @@ LORA_CLIP_MAP = {
     "self_attn.out_proj": "self_attn_out_proj",
 }
 
+LORA_CLIP2_MAP = {
+    "mlp.c_fc": "mlp_fc1",
+    "mlp.c_proj": "mlp_fc2",
+    "attn.out_proj": "self_attn_out_proj",
+}
+
 LORA_UNET_MAP = {
     "proj_in": "proj_in",
     "proj_out": "proj_out",
@@ -110,7 +116,7 @@ def model_lora_keys(model, key_map={}):
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
                 lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
-                key_map[lora_key] = k
+                key_map[lora_key] = (k, 0)
                 up_counter += 1
         if up_counter >= 4:
             counter += 1
@@ -118,7 +124,7 @@ def model_lora_keys(model, key_map={}):
         k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
         if k in sdk:
             lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
-            key_map[lora_key] = k
+            key_map[lora_key] = (k, 0)
     counter = 3
     for b in range(12):
         tk = "model.diffusion_model.output_blocks.{}.1".format(b)
@@ -127,17 +133,30 @@ def model_lora_keys(model, key_map={}):
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
                 lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
-                key_map[lora_key] = k
+                key_map[lora_key] = (k, 0)
                 up_counter += 1
         if up_counter >= 4:
             counter += 1
     counter = 0
+    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     for b in range(12):
         for c in LORA_CLIP_MAP:
             k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
-                lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
-                key_map[lora_key] = k
+                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
+                key_map[lora_key] = (k, 0)
+    for b in range(24):
+        for c in LORA_CLIP2_MAP:
+            k = "model.transformer.resblocks.{}.{}.weight".format(b, c)
+            if k in sdk:
+                lora_key = text_model_lora_key.format(b, LORA_CLIP2_MAP[c])
+                key_map[lora_key] = (k, 0)
+        k = "model.transformer.resblocks.{}.attn.in_proj_weight".format(b)
+        if k in sdk:
+            key_map[text_model_lora_key.format(b, "self_attn_q_proj")] = (k, 0)
+            key_map[text_model_lora_key.format(b, "self_attn_k_proj")] = (k, 1)
+            key_map[text_model_lora_key.format(b, "self_attn_v_proj")] = (k, 2)
+
     return key_map
 
 class ModelPatcher:
@@ -155,7 +174,7 @@ class ModelPatcher:
         p = {}
         model_sd = self.model.state_dict()
         for k in patches:
-            if k in model_sd:
+            if k[0] in model_sd:
                 p[k] = patches[k]
         self.patches += [(strength, p)]
         return p.keys()
@@ -165,20 +184,25 @@ class ModelPatcher:
         for p in self.patches:
             for k in p[1]:
                 v = p[1][k]
-                if k not in model_sd:
+                key = k[0]
+                index = k[1]
+                if key not in model_sd:
                     print("could not patch. key doesn't exist in model:", k)
                     continue
 
-                weight = model_sd[k]
-                if k not in self.backup:
-                    self.backup[k] = weight.clone()
+                weight = model_sd[key]
+                if key not in self.backup:
+                    self.backup[key] = weight.clone()
                 alpha = p[0]
                 mat1 = v[0]
                 mat2 = v[1]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
-                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                calc = (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float()))
+                if len(weight.shape) > 2:
+                    calc = calc.reshape(weight.shape)
+                weight[index * mat1.shape[0]:(index + 1) * mat1.shape[0]] += calc.type(weight.dtype).to(weight.device)
         return self.model
 
     def unpatch_model(self):
         model_sd = self.model.state_dict()
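
Note on the (key, index) scheme introduced above: SD2.x CLIP (OpenCLIP) packs the q, k and v
projections of each resblock into a single attn.in_proj_weight of shape (3 * dim, dim), so one
state dict key now maps to three LoRA keys, and patch_model() adds each low-rank product only
to the row slice selected by the index. The sketch below is illustrative only: the layer width
dim, the dummy state dict and the helper apply_patch are made up for the example and are not
part of comfy/sd.py.

import torch

dim, rank = 8, 4

# A stand-in state dict with one fused SD2.x CLIP attention weight.
in_proj = "model.transformer.resblocks.0.attn.in_proj_weight"
model_sd = {in_proj: torch.zeros(3 * dim, dim)}

# What model_lora_keys() now emits for that weight: the same key three times,
# with the index picking the q (0), k (1) or v (2) rows of in_proj_weight.
key_map = {
    "lora_te_text_model_encoder_layers_0_self_attn_q_proj": (in_proj, 0),
    "lora_te_text_model_encoder_layers_0_self_attn_k_proj": (in_proj, 1),
    "lora_te_text_model_encoder_layers_0_self_attn_v_proj": (in_proj, 2),
}

def apply_patch(lora_key, mat1, mat2, alpha=1.0):
    # Same arithmetic as the new patch_model() body: the low-rank product lands
    # in rows [index * dim, (index + 1) * dim) of the fused weight.
    key, index = key_map[lora_key]
    weight = model_sd[key]
    calc = alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())
    weight[index * mat1.shape[0]:(index + 1) * mat1.shape[0]] += calc.type(weight.dtype)

# Patching only the k projection leaves the q and v slices untouched.
up, down = torch.randn(dim, rank), torch.randn(rank, dim)
apply_patch("lora_te_text_model_encoder_layers_0_self_attn_k_proj", up, down)
print(model_sd[in_proj].abs().sum(dim=1))  # nonzero only for rows dim .. 2*dim - 1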