Mirror of https://github.com/comfyanonymous/ComfyUI.git
Properly support SDXL diffusers loras for unet.

commit acf95191ff (parent 8d694cc450)

comfy/sd.py (127 changed lines)
@@ -59,35 +59,6 @@ LORA_CLIP_MAP = {
     "self_attn.out_proj": "self_attn_out_proj",
 }
 
-LORA_UNET_MAP_ATTENTIONS = {
-    "proj_in": "proj_in",
-    "proj_out": "proj_out",
-}
-
-transformer_lora_blocks = {
-    "transformer_blocks.{}.attn1.to_q": "transformer_blocks_{}_attn1_to_q",
-    "transformer_blocks.{}.attn1.to_k": "transformer_blocks_{}_attn1_to_k",
-    "transformer_blocks.{}.attn1.to_v": "transformer_blocks_{}_attn1_to_v",
-    "transformer_blocks.{}.attn1.to_out.0": "transformer_blocks_{}_attn1_to_out_0",
-    "transformer_blocks.{}.attn2.to_q": "transformer_blocks_{}_attn2_to_q",
-    "transformer_blocks.{}.attn2.to_k": "transformer_blocks_{}_attn2_to_k",
-    "transformer_blocks.{}.attn2.to_v": "transformer_blocks_{}_attn2_to_v",
-    "transformer_blocks.{}.attn2.to_out.0": "transformer_blocks_{}_attn2_to_out_0",
-    "transformer_blocks.{}.ff.net.0.proj": "transformer_blocks_{}_ff_net_0_proj",
-    "transformer_blocks.{}.ff.net.2": "transformer_blocks_{}_ff_net_2",
-}
-
-for i in range(10):
-    for k in transformer_lora_blocks:
-        LORA_UNET_MAP_ATTENTIONS[k.format(i)] = transformer_lora_blocks[k].format(i)
-
-
-LORA_UNET_MAP_RESNET = {
-    "in_layers.2": "resnets_{}_conv1",
-    "emb_layers.1": "resnets_{}_time_emb_proj",
-    "out_layers.3": "resnets_{}_conv2",
-    "skip_connection": "resnets_{}_conv_shortcut"
-}
-
 def load_lora(lora, to_load):
     patch_dict = {}
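
Note: the deleted tables above only covered the attention and resnet weights that were listed by hand, and their expansion loop had to be bumped every time a model needed more transformer blocks per attention layer (SDXL's deepest blocks use 10). A condensed sketch of what that removed expansion produced, kept only as a reference for the old key naming:

    # Condensed copy of the removed expansion loop (illustrative, one entry shown).
    transformer_lora_blocks = {
        "transformer_blocks.{}.attn1.to_q": "transformer_blocks_{}_attn1_to_q",
    }
    LORA_UNET_MAP_ATTENTIONS = {"proj_in": "proj_in", "proj_out": "proj_out"}
    for i in range(10):
        for k in transformer_lora_blocks:
            LORA_UNET_MAP_ATTENTIONS[k.format(i)] = transformer_lora_blocks[k].format(i)
    # e.g. LORA_UNET_MAP_ATTENTIONS["transformer_blocks.3.attn1.to_q"] == "transformer_blocks_3_attn1_to_q"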
@@ -188,39 +159,9 @@ def load_lora(lora, to_load):
             print("lora key not loaded", x)
     return patch_dict
 
-def model_lora_keys(model, key_map={}):
+def model_lora_keys_clip(model, key_map={}):
     sdk = model.state_dict().keys()
 
-    counter = 0
-    for b in range(12):
-        tk = "diffusion_model.input_blocks.{}.1".format(b)
-        up_counter = 0
-        for c in LORA_UNET_MAP_ATTENTIONS:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c])
-                key_map[lora_key] = k
-                up_counter += 1
-        if up_counter >= 4:
-            counter += 1
-    for c in LORA_UNET_MAP_ATTENTIONS:
-        k = "diffusion_model.middle_block.1.{}.weight".format(c)
-        if k in sdk:
-            lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
-            key_map[lora_key] = k
-    counter = 3
-    for b in range(12):
-        tk = "diffusion_model.output_blocks.{}.1".format(b)
-        up_counter = 0
-        for c in LORA_UNET_MAP_ATTENTIONS:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c])
-                key_map[lora_key] = k
-                up_counter += 1
-        if up_counter >= 4:
-            counter += 1
-    counter = 0
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False
     for b in range(32):
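
The only change to the text-encoder side is the rename to model_lora_keys_clip; the lora_te_* key construction is untouched. For illustration, the entry built for layer 0 and the "self_attn.out_proj" mapping shown in the context above (the state-dict key in the comment is an assumption about the CLIP wrapper's prefix, not taken from this diff):

    # Illustrative only.
    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(0, "self_attn_out_proj")
    # lora_key == "lora_te_text_model_encoder_layers_0_self_attn_out_proj"; key_map[lora_key]
    # would then point at a CLIP state-dict key such as (assumed prefix)
    # "transformer.text_model.encoder.layers.0.self_attn.out_proj.weight"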
@@ -244,69 +185,23 @@ def model_lora_keys(model, key_map={}):
                     lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
                     key_map[lora_key] = k
 
+    return key_map
 
-    #Locon stuff
-    ds_counter = 0
-    counter = 0
-    for b in range(12):
-        tk = "diffusion_model.input_blocks.{}.0".format(b)
-        key_in = False
-        for c in LORA_UNET_MAP_RESNET:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2))
-                key_map[lora_key] = k
-                key_in = True
-        for bb in range(3):
-            k = "{}.{}.op.weight".format(tk[:-2], bb)
-            if k in sdk:
-                lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter)
-                key_map[lora_key] = k
-                ds_counter += 1
-        if key_in:
-            counter += 1
-
-    counter = 0
-    for b in range(3):
-        tk = "diffusion_model.middle_block.{}".format(b)
-        key_in = False
-        for c in LORA_UNET_MAP_RESNET:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter))
-                key_map[lora_key] = k
-                key_in = True
-        if key_in:
-            counter += 1
-
-    counter = 0
-    us_counter = 0
-    for b in range(12):
-        tk = "diffusion_model.output_blocks.{}.0".format(b)
-        key_in = False
-        for c in LORA_UNET_MAP_RESNET:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3))
-                key_map[lora_key] = k
-                key_in = True
-        for bb in range(3):
-            k = "{}.{}.conv.weight".format(tk[:-2], bb)
-            if k in sdk:
-                lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter)
-                key_map[lora_key] = k
-                us_counter += 1
-        if key_in:
-            counter += 1
+def model_lora_keys_unet(model, key_map={}):
+    sdk = model.state_dict().keys()
 
     for k in sdk:
         if k.startswith("diffusion_model.") and k.endswith(".weight"):
             key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = k
 
+    diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
+    for k in diffusers_keys:
+        if k.endswith(".weight"):
+            key_lora = k[:-len(".weight")].replace(".", "_")
+            key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
     return key_map
 
 
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0):
         self.size = size
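
The new model_lora_keys_unet drops the hand-maintained walk over input/middle/output blocks. It now registers two names for every UNet weight: the flattened native key, and a diffusers-style alias derived from utils.unet_to_diffusers (added in comfy/utils.py below), which is what lets SDXL diffusers-format LoRAs resolve. A standalone sketch of the two string transformations, using hypothetical keys:

    # 1) Native key: strip the "diffusion_model." prefix and ".weight" suffix, dots -> underscores.
    k = "diffusion_model.input_blocks.1.1.proj_in.weight"      # hypothetical state-dict key
    key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
    # key_lora == "input_blocks_1_1_proj_in"  ->  registered as "lora_unet_input_blocks_1_1_proj_in"

    # 2) Diffusers alias: an entry like the one below (assumed, but shaped like the
    #    unet_to_diffusers output) maps a diffusers name onto the same native weight.
    diffusers_keys = {"down_blocks.0.attentions.0.proj_in.weight": "input_blocks.1.1.proj_in.weight"}
    for dk, native in diffusers_keys.items():
        if dk.endswith(".weight"):
            alias = "lora_unet_" + dk[:-len(".weight")].replace(".", "_")
            # alias == "lora_unet_down_blocks_0_attentions_0_proj_in"
            # key_map[alias] == "diffusion_model." + native
            #                == "diffusion_model.input_blocks.1.1.proj_in.weight"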
@@ -506,8 +401,8 @@ class ModelPatcher:
         self.backup = {}
 
 def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
-    key_map = model_lora_keys(model.model)
-    key_map = model_lora_keys(clip.cond_stage_model, key_map)
+    key_map = model_lora_keys_unet(model.model)
+    key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
     loaded = load_lora(lora, key_map)
     new_modelpatcher = model.clone()
     k = new_modelpatcher.add_patches(loaded, strength_model)
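
load_lora_for_models keeps its signature; only the key-map construction behind it changes, so existing callers such as the LoraLoader node are unaffected. A hedged usage sketch, assuming a model/clip pair already loaded from a checkpoint and a placeholder LoRA path:

    import comfy.sd
    import comfy.utils

    # `model` and `clip` are assumed to come from a previously loaded checkpoint (placeholders).
    lora_sd = comfy.utils.load_torch_file("loras/example_sdxl_lora.safetensors", safe_load=True)
    model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_sd, 1.0, 1.0)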

comfy/utils.py (117 changed lines)
@@ -70,6 +70,123 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
                 sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
     return sd
 
+UNET_MAP_ATTENTIONS = {
+    "proj_in.weight",
+    "proj_in.bias",
+    "proj_out.weight",
+    "proj_out.bias",
+    "norm.weight",
+    "norm.bias",
+}
+
+TRANSFORMER_BLOCKS = {
+    "norm1.weight",
+    "norm1.bias",
+    "norm2.weight",
+    "norm2.bias",
+    "norm3.weight",
+    "norm3.bias",
+    "attn1.to_q.weight",
+    "attn1.to_k.weight",
+    "attn1.to_v.weight",
+    "attn1.to_out.0.weight",
+    "attn1.to_out.0.bias",
+    "attn2.to_q.weight",
+    "attn2.to_k.weight",
+    "attn2.to_v.weight",
+    "attn2.to_out.0.weight",
+    "attn2.to_out.0.bias",
+    "ff.net.0.proj.weight",
+    "ff.net.0.proj.bias",
+    "ff.net.2.weight",
+    "ff.net.2.bias",
+}
+
+UNET_MAP_RESNET = {
+    "in_layers.2.weight": "conv1.weight",
+    "in_layers.2.bias": "conv1.bias",
+    "emb_layers.1.weight": "time_emb_proj.weight",
+    "emb_layers.1.bias": "time_emb_proj.bias",
+    "out_layers.3.weight": "conv2.weight",
+    "out_layers.3.bias": "conv2.bias",
+    "skip_connection.weight": "conv_shortcut.weight",
+    "skip_connection.bias": "conv_shortcut.bias",
+    "in_layers.0.weight": "norm1.weight",
+    "in_layers.0.bias": "norm1.bias",
+    "out_layers.0.weight": "norm2.weight",
+    "out_layers.0.bias": "norm2.bias",
+}
+
+def unet_to_diffusers(unet_config):
+    num_res_blocks = unet_config["num_res_blocks"]
+    attention_resolutions = unet_config["attention_resolutions"]
+    channel_mult = unet_config["channel_mult"]
+    transformer_depth = unet_config["transformer_depth"]
+    num_blocks = len(channel_mult)
+    if not isinstance(num_res_blocks, list):
+        num_res_blocks = [num_res_blocks] * num_blocks
+
+    transformers_per_layer = []
+    res = 1
+    for i in range(num_blocks):
+        transformers = 0
+        if res in attention_resolutions:
+            transformers = transformer_depth[i]
+        transformers_per_layer.append(transformers)
+        res *= 2
+
+    transformers_mid = unet_config.get("transformer_depth_middle", transformers_per_layer[-1])
+
+    diffusers_unet_map = {}
+    for x in range(num_blocks):
+        n = 1 + (num_res_blocks[x] + 1) * x
+        for i in range(num_res_blocks[x]):
+            for b in UNET_MAP_RESNET:
+                diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
+            if transformers_per_layer[x] > 0:
+                for b in UNET_MAP_ATTENTIONS:
+                    diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
+                for t in range(transformers_per_layer[x]):
+                    for b in TRANSFORMER_BLOCKS:
+                        diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
+            n += 1
+        for k in ["weight", "bias"]:
+            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)
+
+    i = 0
+    for b in UNET_MAP_ATTENTIONS:
+        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
+    for t in range(transformers_mid):
+        for b in TRANSFORMER_BLOCKS:
+            diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
+
+    for i, n in enumerate([0, 2]):
+        for b in UNET_MAP_RESNET:
+            diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
+
+    num_res_blocks = list(reversed(num_res_blocks))
+    transformers_per_layer = list(reversed(transformers_per_layer))
+    for x in range(num_blocks):
+        n = (num_res_blocks[x] + 1) * x
+        l = num_res_blocks[x] + 1
+        for i in range(l):
+            c = 0
+            for b in UNET_MAP_RESNET:
+                diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
+            c += 1
+            if transformers_per_layer[x] > 0:
+                c += 1
+                for b in UNET_MAP_ATTENTIONS:
+                    diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
+                for t in range(transformers_per_layer[x]):
+                    for b in TRANSFORMER_BLOCKS:
+                        diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
+            if i == l - 1:
+                for k in ["weight", "bias"]:
+                    diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
+            n += 1
+    return diffusers_unet_map
+
 def convert_sd_to(state_dict, dtype):
     keys = list(state_dict.keys())
     for k in keys:
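
unet_to_diffusers builds the diffusers-to-native name table purely from the model's unet_config, so the LoRA key mapping no longer depends on hard-coded block counts. A small usage sketch with a roughly SDXL-shaped config (illustrative values, not copied from a real model definition):

    from comfy.utils import unet_to_diffusers

    unet_config = {
        "num_res_blocks": 2,
        "channel_mult": [1, 2, 4],
        "attention_resolutions": [2, 4],
        "transformer_depth": [0, 2, 10],
        "transformer_depth_middle": 10,
    }
    mapping = unet_to_diffusers(unet_config)
    print(mapping["down_blocks.1.attentions.0.proj_in.weight"])   # input_blocks.4.1.proj_in.weight
    print(mapping["mid_block.resnets.0.conv1.weight"])            # middle_block.0.in_layers.2.weight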