Add locon support.

This commit is contained in:
comfyanonymous 2023-03-09 21:41:24 -05:00
parent d0b195c7d0
commit cd64111c83

View File

@ -99,7 +99,7 @@ LORA_CLIP_MAP = {
"self_attn.out_proj": "self_attn_out_proj", "self_attn.out_proj": "self_attn_out_proj",
} }
LORA_UNET_MAP = { LORA_UNET_MAP_ATTENTIONS = {
"proj_in": "proj_in", "proj_in": "proj_in",
"proj_out": "proj_out", "proj_out": "proj_out",
"transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q", "transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
@ -114,6 +114,12 @@ LORA_UNET_MAP = {
"transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2", "transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
} }
LORA_UNET_MAP_RESNET = {
"in_layers.2": "resnets_{}_conv1",
"emb_layers.1": "resnets_{}_time_emb_proj",
"out_layers.3": "resnets_{}_conv2",
"skip_connection": "resnets_{}_conv_shortcut"
}
def load_lora(path, to_load): def load_lora(path, to_load):
lora = load_torch_file(path) lora = load_torch_file(path)
@ -143,27 +149,27 @@ def model_lora_keys(model, key_map={}):
for b in range(12): for b in range(12):
tk = "model.diffusion_model.input_blocks.{}.1".format(b) tk = "model.diffusion_model.input_blocks.{}.1".format(b)
up_counter = 0 up_counter = 0
for c in LORA_UNET_MAP: for c in LORA_UNET_MAP_ATTENTIONS:
k = "{}.{}.weight".format(tk, c) k = "{}.{}.weight".format(tk, c)
if k in sdk: if k in sdk:
lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c]) lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c])
key_map[lora_key] = k key_map[lora_key] = k
up_counter += 1 up_counter += 1
if up_counter >= 4: if up_counter >= 4:
counter += 1 counter += 1
for c in LORA_UNET_MAP: for c in LORA_UNET_MAP_ATTENTIONS:
k = "model.diffusion_model.middle_block.1.{}.weight".format(c) k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
if k in sdk: if k in sdk:
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c]) lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
key_map[lora_key] = k key_map[lora_key] = k
counter = 3 counter = 3
for b in range(12): for b in range(12):
tk = "model.diffusion_model.output_blocks.{}.1".format(b) tk = "model.diffusion_model.output_blocks.{}.1".format(b)
up_counter = 0 up_counter = 0
for c in LORA_UNET_MAP: for c in LORA_UNET_MAP_ATTENTIONS:
k = "{}.{}.weight".format(tk, c) k = "{}.{}.weight".format(tk, c)
if k in sdk: if k in sdk:
lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c]) lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c])
key_map[lora_key] = k key_map[lora_key] = k
up_counter += 1 up_counter += 1
if up_counter >= 4: if up_counter >= 4:
@ -177,6 +183,61 @@ def model_lora_keys(model, key_map={}):
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k key_map[lora_key] = k
#Locon stuff
ds_counter = 0
counter = 0
for b in range(12):
tk = "model.diffusion_model.input_blocks.{}.0".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
if k in sdk:
lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2))
key_map[lora_key] = k
key_in = True
for bb in range(3):
k = "{}.{}.op.weight".format(tk[:-2], bb)
if k in sdk:
lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter)
key_map[lora_key] = k
ds_counter += 1
if key_in:
counter += 1
counter = 0
for b in range(3):
tk = "model.diffusion_model.middle_block.{}".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
if k in sdk:
lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter))
key_map[lora_key] = k
key_in = True
if key_in:
counter += 1
counter = 0
us_counter = 0
for b in range(12):
tk = "model.diffusion_model.output_blocks.{}.0".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
if k in sdk:
lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3))
key_map[lora_key] = k
key_in = True
for bb in range(3):
k = "{}.{}.conv.weight".format(tk[:-2], bb)
if k in sdk:
lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter)
key_map[lora_key] = k
us_counter += 1
if key_in:
counter += 1
return key_map return key_map
class ModelPatcher: class ModelPatcher: