Mirror of https://github.com/comfyanonymous/ComfyUI.git
SD2.x CLIP support for Loras.
This commit is contained in:
parent 3f3d77a324
commit 678105fade

1 changed file: comfy/sd.py (46 lines changed)
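In short: LoRA keys for the text encoder used to map straight to a state-dict key; after this commit they map to a (state-dict key, index) tuple. The index is 0 for ordinary weights, and for the SD2.x text encoder, whose attention keeps the q/k/v projections fused in a single in_proj_weight tensor, it selects which slice of that tensor a patch applies to. A minimal sketch of the two kinds of entries, with keys taken from the hunks below (illustration only, not additional repository code):

example_entries = {
    # SD1.x-style CLIP layer: plain weight, slice index is always 0
    "lora_te_text_model_encoder_layers_0_self_attn_out_proj":
        ("transformer.text_model.encoder.layers.0.self_attn.out_proj.weight", 0),
    # SD2.x fused attention weight: index 1 selects the q slice of in_proj_weight
    "lora_te_text_model_encoder_layers_0_self_attn_q_proj":
        ("model.transformer.resblocks.0.attn.in_proj_weight", 1),
}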
@@ -62,6 +62,12 @@ LORA_CLIP_MAP = {
     "self_attn.out_proj": "self_attn_out_proj",
 }
 
+LORA_CLIP2_MAP = {
+    "mlp.c_fc": "mlp_fc1",
+    "mlp.c_proj": "mlp_fc2",
+    "attn.out_proj": "self_attn_out_proj",
+}
+
 LORA_UNET_MAP = {
     "proj_in": "proj_in",
     "proj_out": "proj_out",
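Note: LORA_CLIP2_MAP translates module names from the SD2.x (open_clip-style) text encoder, mlp.c_fc, mlp.c_proj and attn.out_proj, onto the same LoRA key suffixes already used for the SD1.x CLIP encoder, so LoRA files keep one naming scheme. A worked example of a single entry (illustration only; the map literal is copied from the hunk above):

LORA_CLIP2_MAP = {"mlp.c_fc": "mlp_fc1", "mlp.c_proj": "mlp_fc2", "attn.out_proj": "self_attn_out_proj"}
b, c = 3, "mlp.c_fc"
sd_key = "model.transformer.resblocks.{}.{}.weight".format(b, c)
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP2_MAP[c])
# lora_key == "lora_te_text_model_encoder_layers_3_mlp_fc1"
# model_lora_keys stores key_map[lora_key] = (sd_key, 0) for this entry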
@@ -110,7 +116,7 @@ def model_lora_keys(model, key_map={}):
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
                 lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
-                key_map[lora_key] = k
+                key_map[lora_key] = (k, 0)
                 up_counter += 1
         if up_counter >= 4:
             counter += 1
@@ -118,7 +124,7 @@ def model_lora_keys(model, key_map={}):
         k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
         if k in sdk:
             lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
-            key_map[lora_key] = k
+            key_map[lora_key] = (k, 0)
     counter = 3
     for b in range(12):
         tk = "model.diffusion_model.output_blocks.{}.1".format(b)
@@ -127,17 +133,30 @@ def model_lora_keys(model, key_map={}):
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
                 lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
-                key_map[lora_key] = k
+                key_map[lora_key] = (k, 0)
                 up_counter += 1
         if up_counter >= 4:
             counter += 1
     counter = 0
+    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     for b in range(12):
         for c in LORA_CLIP_MAP:
             k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
-                lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
-                key_map[lora_key] = k
+                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
+                key_map[lora_key] = (k, 0)
+    for b in range(24):
+        for c in LORA_CLIP2_MAP:
+            k = "model.transformer.resblocks.{}.{}.weight".format(b, c)
+            if k in sdk:
+                lora_key = text_model_lora_key.format(b, LORA_CLIP2_MAP[c])
+                key_map[lora_key] = (k, 0)
+        k = "model.transformer.resblocks.{}.attn.in_proj_weight".format(b)
+        if k in sdk:
+            key_map[text_model_lora_key.format(b, "self_attn_k_proj")] = (k, 0)
+            key_map[text_model_lora_key.format(b, "self_attn_q_proj")] = (k, 1)
+            key_map[text_model_lora_key.format(b, "self_attn_v_proj")] = (k, 2)
+
     return key_map
 
 class ModelPatcher:
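Note: the added loop walks 24 transformer blocks, matching the deeper OpenCLIP text encoder that SD2.x checkpoints ship with (the SD1.x loop above keeps its 12 layers), and every lookup is still guarded by an if k in sdk check, so models without those keys are unaffected. Because each block fuses q, k and v into one in_proj_weight, three LoRA keys share a single state-dict key and differ only in the slice index. Sketched for block 0 (illustration only):

fused = "model.transformer.resblocks.0.attn.in_proj_weight"
# key_map["lora_te_text_model_encoder_layers_0_self_attn_k_proj"] == (fused, 0)
# key_map["lora_te_text_model_encoder_layers_0_self_attn_q_proj"] == (fused, 1)
# key_map["lora_te_text_model_encoder_layers_0_self_attn_v_proj"] == (fused, 2)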
@@ -155,7 +174,7 @@ class ModelPatcher:
         p = {}
         model_sd = self.model.state_dict()
         for k in patches:
-            if k in model_sd:
+            if k[0] in model_sd:
                 p[k] = patches[k]
         self.patches += [(strength, p)]
         return p.keys()
@@ -165,20 +184,25 @@ class ModelPatcher:
         for p in self.patches:
             for k in p[1]:
                 v = p[1][k]
-                if k not in model_sd:
+                key = k[0]
+                index = k[1]
+                if key not in model_sd:
                     print("could not patch. key doesn't exist in model:", k)
                     continue
 
-                weight = model_sd[k]
-                if k not in self.backup:
-                    self.backup[k] = weight.clone()
+                weight = model_sd[key]
+                if key not in self.backup:
+                    self.backup[key] = weight.clone()
 
                 alpha = p[0]
                 mat1 = v[0]
                 mat2 = v[1]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
-                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                calc = (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float()))
+                if len(weight.shape) > 2:
+                    calc = calc.reshape(weight.shape)
+                weight[index * mat1.shape[0]:(index + 1) * mat1.shape[0]] += calc.type(weight.dtype).to(weight.device)
         return self.model
     def unpatch_model(self):
         model_sd = self.model.state_dict()
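Note: ModelPatcher.patch_model now computes the LoRA delta first and adds it into a row slice of the target weight, from index * mat1.shape[0] to (index + 1) * mat1.shape[0]. For an ordinary weight the index is 0 and mat1.shape[0] equals the full output dimension, so the slice covers the whole tensor; for a fused SD2.x in_proj_weight it touches only the k, q or v third. A runnable toy version of that arithmetic (dimensions and variable roles are assumptions for illustration, not values from the commit):

import torch

dim, rank = 8, 4
weight = torch.zeros(3 * dim, dim)   # fused in_proj_weight: k, q, v stacked row-wise
mat1 = torch.randn(dim, rank)        # LoRA "up" matrix for the q projection (assumed role)
mat2 = torch.randn(rank, dim)        # LoRA "down" matrix for the q projection (assumed role)
alpha, index = 1.0, 1                # index 1 selects the q slice, as mapped above

calc = alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())
# weight is 2-D here, so the reshape branch (for conv weights) is not needed
weight[index * mat1.shape[0]:(index + 1) * mat1.shape[0]] += calc.type(weight.dtype)
# only rows dim .. 2 * dim - 1 change; the k and v slices stay zero
assert weight[:dim].abs().sum() == 0 and weight[2 * dim:].abs().sum() == 0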