mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add locon support.
This commit is contained in:
parent
d0b195c7d0
commit
cd64111c83
75
comfy/sd.py
75
comfy/sd.py
@ -99,7 +99,7 @@ LORA_CLIP_MAP = {
|
|||||||
"self_attn.out_proj": "self_attn_out_proj",
|
"self_attn.out_proj": "self_attn_out_proj",
|
||||||
}
|
}
|
||||||
|
|
||||||
LORA_UNET_MAP = {
|
LORA_UNET_MAP_ATTENTIONS = {
|
||||||
"proj_in": "proj_in",
|
"proj_in": "proj_in",
|
||||||
"proj_out": "proj_out",
|
"proj_out": "proj_out",
|
||||||
"transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
|
"transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
|
||||||
@ -114,6 +114,12 @@ LORA_UNET_MAP = {
|
|||||||
"transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
|
"transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LORA_UNET_MAP_RESNET = {
|
||||||
|
"in_layers.2": "resnets_{}_conv1",
|
||||||
|
"emb_layers.1": "resnets_{}_time_emb_proj",
|
||||||
|
"out_layers.3": "resnets_{}_conv2",
|
||||||
|
"skip_connection": "resnets_{}_conv_shortcut"
|
||||||
|
}
|
||||||
|
|
||||||
def load_lora(path, to_load):
|
def load_lora(path, to_load):
|
||||||
lora = load_torch_file(path)
|
lora = load_torch_file(path)
|
||||||
@ -143,27 +149,27 @@ def model_lora_keys(model, key_map={}):
|
|||||||
for b in range(12):
|
for b in range(12):
|
||||||
tk = "model.diffusion_model.input_blocks.{}.1".format(b)
|
tk = "model.diffusion_model.input_blocks.{}.1".format(b)
|
||||||
up_counter = 0
|
up_counter = 0
|
||||||
for c in LORA_UNET_MAP:
|
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||||
k = "{}.{}.weight".format(tk, c)
|
k = "{}.{}.weight".format(tk, c)
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
|
lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c])
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
up_counter += 1
|
up_counter += 1
|
||||||
if up_counter >= 4:
|
if up_counter >= 4:
|
||||||
counter += 1
|
counter += 1
|
||||||
for c in LORA_UNET_MAP:
|
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||||
k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
|
k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
|
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
counter = 3
|
counter = 3
|
||||||
for b in range(12):
|
for b in range(12):
|
||||||
tk = "model.diffusion_model.output_blocks.{}.1".format(b)
|
tk = "model.diffusion_model.output_blocks.{}.1".format(b)
|
||||||
up_counter = 0
|
up_counter = 0
|
||||||
for c in LORA_UNET_MAP:
|
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||||
k = "{}.{}.weight".format(tk, c)
|
k = "{}.{}.weight".format(tk, c)
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
|
lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c])
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
up_counter += 1
|
up_counter += 1
|
||||||
if up_counter >= 4:
|
if up_counter >= 4:
|
||||||
@ -177,6 +183,61 @@ def model_lora_keys(model, key_map={}):
|
|||||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
|
|
||||||
|
|
||||||
|
#Locon stuff
|
||||||
|
ds_counter = 0
|
||||||
|
counter = 0
|
||||||
|
for b in range(12):
|
||||||
|
tk = "model.diffusion_model.input_blocks.{}.0".format(b)
|
||||||
|
key_in = False
|
||||||
|
for c in LORA_UNET_MAP_RESNET:
|
||||||
|
k = "{}.{}.weight".format(tk, c)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2))
|
||||||
|
key_map[lora_key] = k
|
||||||
|
key_in = True
|
||||||
|
for bb in range(3):
|
||||||
|
k = "{}.{}.op.weight".format(tk[:-2], bb)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter)
|
||||||
|
key_map[lora_key] = k
|
||||||
|
ds_counter += 1
|
||||||
|
if key_in:
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
counter = 0
|
||||||
|
for b in range(3):
|
||||||
|
tk = "model.diffusion_model.middle_block.{}".format(b)
|
||||||
|
key_in = False
|
||||||
|
for c in LORA_UNET_MAP_RESNET:
|
||||||
|
k = "{}.{}.weight".format(tk, c)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter))
|
||||||
|
key_map[lora_key] = k
|
||||||
|
key_in = True
|
||||||
|
if key_in:
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
counter = 0
|
||||||
|
us_counter = 0
|
||||||
|
for b in range(12):
|
||||||
|
tk = "model.diffusion_model.output_blocks.{}.0".format(b)
|
||||||
|
key_in = False
|
||||||
|
for c in LORA_UNET_MAP_RESNET:
|
||||||
|
k = "{}.{}.weight".format(tk, c)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3))
|
||||||
|
key_map[lora_key] = k
|
||||||
|
key_in = True
|
||||||
|
for bb in range(3):
|
||||||
|
k = "{}.{}.conv.weight".format(tk[:-2], bb)
|
||||||
|
if k in sdk:
|
||||||
|
lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter)
|
||||||
|
key_map[lora_key] = k
|
||||||
|
us_counter += 1
|
||||||
|
if key_in:
|
||||||
|
counter += 1
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
|
Loading…
Reference in New Issue
Block a user