Support loras in diffusers format.

This commit is contained in:
comfyanonymous 2023-08-05 01:40:24 -04:00
parent 5a90d3cea5
commit c5d7593ccf

View File

@ -70,13 +70,22 @@ def load_lora(lora, to_load):
alpha = lora[alpha_name].item() alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name) loaded_keys.add(alpha_name)
A_name = "{}.lora_up.weight".format(x) regular_lora = "{}.lora_up.weight".format(x)
B_name = "{}.lora_down.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x) A_name = None
if A_name in lora.keys(): if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
if A_name is not None:
mid = None mid = None
if mid_name in lora.keys(): if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name] mid = lora[mid_name]
loaded_keys.add(mid_name) loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
@ -202,6 +211,11 @@ def model_lora_keys_unet(model, key_map={}):
if k.endswith(".weight"): if k.endswith(".weight"):
key_lora = k[:-len(".weight")].replace(".", "_") key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k]) key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
diffusers_lora_key = "unet.{}".format(k[:-len(".weight")].replace(".to_", ".processor.to_"))
if diffusers_lora_key.endswith(".to_out.0"):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = "diffusion_model.{}".format(diffusers_keys[k])
return key_map return key_map
def set_attr(obj, attr, value): def set_attr(obj, attr, value):