mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
17030fd4c0
The keys are simply `model.full.model.key.name.lora_up.weight`. This format is supported by every model ComfyUI supports, so people can now convert their LoRAs to it instead of having to ask me to implement support for each new format.
301 lines
13 KiB
Python
301 lines
13 KiB
Python
import comfy.utils
|
|
import logging
|
|
|
|
# CLIP text-encoder sub-module paths (dotted, as they appear in state-dict
# keys) mapped to the underscore-joined spelling used when building the
# underscore-separated LoRA key names for the text encoders.
LORA_CLIP_MAP = {
    module_path: module_path.replace(".", "_")
    for module_path in (
        "mlp.fc1",
        "mlp.fc2",
        "self_attn.k_proj",
        "self_attn.q_proj",
        "self_attn.v_proj",
        "self_attn.out_proj",
    )
}
|
|
|
|
|
|
def load_lora(lora, to_load):
    """Turn an adapter state dict into a dict of weight patches.

    For each ``prefix -> target`` entry in ``to_load`` the tensors stored under
    ``prefix`` in ``lora`` are inspected and, when found, a
    ``(patch_type, params)`` tuple is stored under ``target`` (or under the
    ``.bias`` key derived from it). Formats are checked in this order, and a
    later match for the same target overwrites an earlier one:

    * ``"lora"``  - up/down (+ optional mid) matrices in one of five naming schemes
    * ``"loha"``  - ``hada_w1_a/w1_b/w2_a/w2_b`` (+ optional ``hada_t1/t2``)
    * ``"lokr"``  - any of the ``lokr_*`` tensors
    * ``"glora"`` - ``a1/a2/b1/b2`` weights
    * ``"diff"``  - ``w_norm`` (+ ``b_norm``), then ``diff`` and ``diff_b``

    Every key of ``lora`` that is not consumed is reported via ``logging.warning``.

    Args:
        lora: dict-like of key -> tensor. ``<prefix>.alpha`` values must provide
            ``.item()``.
        to_load: dict mapping key prefixes in ``lora`` to target model keys.
            Targets are assumed to end in ``.weight`` when a bias patch is
            produced (the last 7 characters are replaced by ``.bias``).

    Returns:
        dict mapping target keys to patch tuples.
    """
    # (up, down, mid) key templates for each plain-LoRA naming scheme, in
    # priority order: the first scheme whose "up" key exists is used.
    lora_schemes = (
        ("{}.lora_up.weight", "{}.lora_down.weight", "{}.lora_mid.weight"),  # regular
        ("{}_lora.up.weight", "{}_lora.down.weight", None),  # diffusers
        ("{}.lora_B.weight", "{}.lora_A.weight", None),  # diffusers (lora_A/lora_B)
        ("{}.lora.up.weight", "{}.lora.down.weight", None),  # diffusers (lora.up/lora.down)
        ("{}.lora_linear_layer.up.weight", "{}.lora_linear_layer.down.weight", None),  # transformers
    )
    lokr_parts = ("lokr_w1", "lokr_w2", "lokr_w1_a", "lokr_w1_b", "lokr_w2_a", "lokr_w2_b", "lokr_t2")

    patches = {}
    consumed = set()

    def take(name):
        # Optional tensor: return it (marking it consumed) or None if absent.
        if name in lora:
            consumed.add(name)
            return lora[name]
        return None

    def bias_key(weight_key):
        # Replace the trailing ".weight" of a target key with ".bias".
        return "{}.bias".format(weight_key[:-len(".weight")])

    for prefix, target in to_load.items():
        alpha_key = f"{prefix}.alpha"
        alpha = None
        if alpha_key in lora:
            alpha = lora[alpha_key].item()
            consumed.add(alpha_key)

        dora_scale = take(f"{prefix}.dora_scale")

        # ---- plain LoRA
        for up_tpl, down_tpl, mid_tpl in lora_schemes:
            up_key = up_tpl.format(prefix)
            if up_key not in lora:
                continue
            down_key = down_tpl.format(prefix)
            mid = None if mid_tpl is None else take(mid_tpl.format(prefix))
            patches[target] = ("lora", (lora[up_key], lora[down_key], alpha, mid, dora_scale))
            consumed.update((up_key, down_key))
            break

        # ---- LoHa
        w1_a = f"{prefix}.hada_w1_a"
        if w1_a in lora:
            w1_b, w2_a, w2_b = (f"{prefix}.hada_{part}" for part in ("w1_b", "w2_a", "w2_b"))
            t1_key, t2_key = f"{prefix}.hada_t1", f"{prefix}.hada_t2"
            t1 = t2 = None
            if t1_key in lora:
                # hada_t2 is expected whenever hada_t1 exists (KeyError otherwise).
                t1, t2 = lora[t1_key], lora[t2_key]
                consumed.update((t1_key, t2_key))
            patches[target] = ("loha", (lora[w1_a], lora[w1_b], alpha, lora[w2_a], lora[w2_b], t1, t2, dora_scale))
            consumed.update((w1_a, w1_b, w2_a, w2_b))

        # ---- LoKr: every present part is consumed, even if no patch results.
        lokr = {part: take(f"{prefix}.{part}") for part in lokr_parts}
        if any(lokr[part] is not None for part in ("lokr_w1", "lokr_w2", "lokr_w1_a", "lokr_w2_a")):
            patches[target] = ("lokr", (
                lokr["lokr_w1"], lokr["lokr_w2"], alpha,
                lokr["lokr_w1_a"], lokr["lokr_w1_b"],
                lokr["lokr_w2_a"], lokr["lokr_w2_b"],
                lokr["lokr_t2"], dora_scale,
            ))

        # ---- GLoRA
        glora_keys = tuple(f"{prefix}.{part}.weight" for part in ("a1", "a2", "b1", "b2"))
        if glora_keys[0] in lora:
            patches[target] = ("glora", tuple(lora[name] for name in glora_keys) + (alpha, dora_scale))
            consumed.update(glora_keys)

        # ---- full-weight replacements (norm layers); b_norm only with w_norm
        w_norm_key, b_norm_key = f"{prefix}.w_norm", f"{prefix}.b_norm"
        w_norm = lora.get(w_norm_key, None)
        b_norm = lora.get(b_norm_key, None)
        if w_norm is not None:
            consumed.add(w_norm_key)
            patches[target] = ("diff", (w_norm,))
            if b_norm is not None:
                consumed.add(b_norm_key)
                patches[bias_key(target)] = ("diff", (b_norm,))

        # ---- generic weight / bias differences
        diff_key = f"{prefix}.diff"
        diff_weight = lora.get(diff_key, None)
        if diff_weight is not None:
            patches[target] = ("diff", (diff_weight,))
            consumed.add(diff_key)

        diff_b_key = f"{prefix}.diff_b"
        diff_bias = lora.get(diff_b_key, None)
        if diff_bias is not None:
            patches[bias_key(target)] = ("diff", (diff_bias,))
            consumed.add(diff_b_key)

    for name in lora.keys():
        if name not in consumed:
            logging.warning("lora key not loaded: {}".format(name))

    return patches
|
|
|
|
def model_lora_keys_clip(model, key_map=None):
    """Map text-encoder LoRA key names onto keys of ``model.state_dict()``.

    For every CLIP-H / CLIP-L / CLIP-G encoder-layer weight listed in
    ``LORA_CLIP_MAP`` that exists in the state dict, entries are added for
    several LoRA naming conventions (``lora_te*`` underscore keys, diffusers
    ``text_encoder*`` dotted keys, ``lora_prior_te`` keys). T5-XXL weights and
    the CLIP text projections get ``lora_te3_*`` / ``lora_te*_text_projection``
    style entries.

    Args:
        model: object exposing ``state_dict()`` whose keys are prefixed with
            ``clip_h.``, ``clip_l.``, ``clip_g.`` or ``t5xxl.``.
        key_map: dict to extend in place. When omitted, a new dict is created
            for this call (no state is shared between calls).

    Returns:
        ``key_map`` mapping LoRA key names to state-dict keys.
    """
    if key_map is None:
        key_map = {}

    sdk = model.state_dict().keys()

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    for b in range(32):  # TODO: clean up; assumes no encoder has more than 32 layers
        for c in LORA_CLIP_MAP:
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                key_map[lora_key] = k

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # SDXL base
                key_map[lora_key] = k
                # Once any CLIP-L layer is seen, CLIP-G layers are mapped as the
                # second text encoder ("te2") for all later lookups.
                clip_l_present = True
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                key_map[lora_key] = k

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                if clip_l_present:
                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # SDXL base
                    key_map[lora_key] = k
                    lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                    key_map[lora_key] = k
                else:
                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # TODO: test if this is correct for SDXL-Refiner
                    key_map[lora_key] = k
                    lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                    key_map[lora_key] = k
                    lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # cascade lora: TODO put lora key prefix in the model config
                    key_map[lora_key] = k

    for k in sdk:  # OneTrainer SD3 lora
        if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
            l_key = k[len("t5xxl.transformer."):-len(".weight")]
            lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
            key_map[lora_key] = k

    k = "clip_g.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_prior_te_text_projection"] = k  # cascade lora?
        # TODO: check whether other LoRA formats also use "text_encoder.text_projection".
        key_map["lora_te2_text_projection"] = k  # OneTrainer SD3 lora

    k = "clip_l.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_te1_text_projection"] = k  # OneTrainer SD3 lora, not necessary but omits warning

    return key_map
|
|
|
|
def model_lora_keys_unet(model, key_map=None):
    """Map diffusion-model LoRA key names onto keys of ``model.state_dict()``.

    Every ``diffusion_model.*.weight`` key gets ``lora_unet_*``,
    ``lora_prior_unet_*`` and the generic ``model.<full key without .weight>``
    entries. Diffusers-style names from ``comfy.utils.unet_to_diffusers`` are
    added as well. Architecture-specific names are added for SD3, AuraFlow,
    HunyuanDiT and Flux models.

    Args:
        model: model object exposing ``state_dict()`` and
            ``model_config.unet_config``.
        key_map: dict to extend in place. When omitted, a new dict is created
            for this call (no state is shared between calls).

    Returns:
        ``key_map`` mapping LoRA key names to state-dict keys.
    """
    # The module-level imports only load comfy.utils. Bind model_base
    # explicitly instead of relying on another module having imported it.
    # The import is local to avoid any import cycle at module load time.
    from comfy import model_base

    if key_map is None:
        key_map = {}

    sdk = model.state_dict().keys()

    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k  # cascade lora: TODO put lora key prefix in the model config
            key_map["model.{}".format(k[:-len(".weight")])] = k  # generic lora format without any weird key names

    diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key

            for p in ("", "unet."):
                diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    diffusers_lora_key = diffusers_lora_key[:-2]  # drop the trailing ".0"
                key_map[diffusers_lora_key] = unet_key

    if isinstance(model, model_base.SD3):  # Diffusers lora SD3
        diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")])  # regular diffusers sd3 lora format
                key_map[key_lora] = to

                key_lora = "base_model.model.{}".format(k[:-len(".weight")])  # format for flash-sd3 lora and others?
                key_map[key_lora] = to

                key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))  # OneTrainer lora
                key_map[key_lora] = to

    if isinstance(model, model_base.AuraFlow):  # Diffusers lora AuraFlow
        diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")])  # simpletrainer and probably regular diffusers lora format
                key_map[key_lora] = to

    if isinstance(model, model_base.HunyuanDiT):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["base_model.model.{}".format(key_lora)] = k  # official hunyuan lora format

    if isinstance(model, model_base.Flux):  # Diffusers lora Flux
        diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")])  # simpletrainer and probably regular diffusers flux lora format
                key_map[key_lora] = to

    return key_map
|