move code to model_lora_keys_unet

This commit is contained in:
rickard 2024-12-23 09:20:22 +01:00
parent 31b6852f19
commit 21f20638bd

View File

@ -64,7 +64,6 @@ def load_lora(lora, to_load, log_missing=True):
diffusers3_lora = "{}.lora.up.weight".format(x) diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x) mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
ltx_lora = "transformer.{}.lora_B.weight".format(x)
A_name = None A_name = None
if regular_lora in lora.keys(): if regular_lora in lora.keys():
@ -91,10 +90,6 @@ def load_lora(lora, to_load, log_missing=True):
A_name = transformers_lora A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x) B_name ="{}.lora_linear_layer.down.weight".format(x)
mid_name = None mid_name = None
elif ltx_lora in lora.keys():
A_name = ltx_lora
B_name = "transformer.{}.lora_A.weight".format(x)
mid_name = None
if A_name is not None: if A_name is not None:
mid = None mid = None
@ -404,6 +399,12 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(key_lora)] = k key_map["transformer.{}".format(key_lora)] = k
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
if isinstance(model, comfy.model_base.LTXV):
for k in sdk:
if k.startswith("transformer.") and k.endswith(".weight"): #Official Mochi lora format
key_lora = k[len("transformer."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map return key_map