mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Support controlnet in diffusers format.
This commit is contained in:
parent
09386a3697
commit
78e7958d17
@ -118,3 +118,57 @@ def model_config_from_unet_config(unet_config):
|
|||||||
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
    """Detect the UNet config stored in ``state_dict`` and resolve it to a model config.

    The three arguments are forwarded unchanged to ``detect_unet_config``;
    its result is handed to ``model_config_from_unet_config``, whose return
    value is returned as-is.
    """
    return model_config_from_unet_config(
        detect_unet_config(state_dict, unet_key_prefix, use_fp16)
    )
|
||||||
|
|
||||||
|
|
||||||
|
def model_config_from_diffusers_unet(state_dict, use_fp16):
    """Pick the model config matching a UNet state dict in diffusers format.

    A handful of weight shapes are probed from ``state_dict`` (the
    cross-attention ``to_k`` input width, the ``conv_in`` output/input
    channel counts and, when present, the class/add embedding input width)
    and compared against a table of known UNet configurations.

    Args:
        state_dict: mapping from diffusers parameter names to objects that
            expose a ``shape`` attribute.
        use_fp16: value stored under ``'use_fp16'`` in every candidate config.

    Returns:
        ``model_config_from_unet_config(unet_config)`` for the first
        candidate whose probed values all match, or ``None`` when no
        candidate matches or the probe weights are absent from
        ``state_dict``.
    """
    try:
        conv_in_shape = state_dict["conv_in.weight"].shape
        match = {
            "context_dim": state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1],
            "model_channels": conv_in_shape[0],
            "in_channels": conv_in_shape[1],
            "adm_in_channels": None,
        }
    except KeyError:
        # Without the probe weights nothing can be identified; report it the
        # same way as any other unsupported UNet instead of raising.
        return None

    # The first embedding found wins: class_embedding takes precedence over
    # add_embedding.
    for embed_key in ("class_embedding.linear_1.weight", "add_embedding.linear_1.weight"):
        if embed_key in state_dict:
            match["adm_in_channels"] = state_dict[embed_key].shape[1]
            break

    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}

    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}

    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}

    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}

    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}

    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}

    # Candidates are tried in this order; the first full match is used.
    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]

    for unet_config in supported_models:
        if all(unet_config[k] == probed for k, probed in match.items()):
            return model_config_from_unet_config(unet_config)
    return None
|
||||||
|
102
comfy/sd.py
102
comfy/sd.py
@ -765,6 +765,51 @@ class ControlNet:
|
|||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
def load_controlnet(ckpt_path, model=None):
|
||||||
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
|
|
||||||
|
controlnet_config = None
|
||||||
|
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
||||||
|
use_fp16 = model_management.should_use_fp16()
|
||||||
|
controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config
|
||||||
|
diffusers_keys = utils.unet_to_diffusers(controlnet_config)
|
||||||
|
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||||
|
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
loop = True
|
||||||
|
while loop:
|
||||||
|
suffix = [".weight", ".bias"]
|
||||||
|
for s in suffix:
|
||||||
|
k_in = "controlnet_down_blocks.{}{}".format(count, s)
|
||||||
|
k_out = "zero_convs.{}.0{}".format(count, s)
|
||||||
|
if k_in not in controlnet_data:
|
||||||
|
loop = False
|
||||||
|
break
|
||||||
|
diffusers_keys[k_in] = k_out
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
loop = True
|
||||||
|
while loop:
|
||||||
|
suffix = [".weight", ".bias"]
|
||||||
|
for s in suffix:
|
||||||
|
if count == 0:
|
||||||
|
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
|
||||||
|
else:
|
||||||
|
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
|
||||||
|
k_out = "input_hint_block.{}{}".format(count * 2, s)
|
||||||
|
if k_in not in controlnet_data:
|
||||||
|
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
|
||||||
|
loop = False
|
||||||
|
diffusers_keys[k_in] = k_out
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
new_sd = {}
|
||||||
|
for k in diffusers_keys:
|
||||||
|
if k in controlnet_data:
|
||||||
|
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
||||||
|
|
||||||
|
controlnet_data = new_sd
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
pth = False
|
pth = False
|
||||||
key = 'zero_convs.0.0.weight'
|
key = 'zero_convs.0.0.weight'
|
||||||
@ -780,8 +825,8 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
|
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
|
||||||
return net
|
return net
|
||||||
|
|
||||||
|
if controlnet_config is None:
|
||||||
use_fp16 = model_management.should_use_fp16()
|
use_fp16 = model_management.should_use_fp16()
|
||||||
|
|
||||||
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
|
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = 3
|
controlnet_config["hint_channels"] = 3
|
||||||
@ -1140,56 +1185,13 @@ def load_unet(unet_path): #load unet in diffusers format
|
|||||||
parameters = calculate_parameters(sd, "")
|
parameters = calculate_parameters(sd, "")
|
||||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
fp16 = model_management.should_use_fp16(model_params=parameters)
|
||||||
|
|
||||||
match = {}
|
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
|
||||||
match["context_dim"] = sd["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
|
if model_config is None:
|
||||||
match["model_channels"] = sd["conv_in.weight"].shape[0]
|
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||||
match["in_channels"] = sd["conv_in.weight"].shape[1]
|
return None
|
||||||
match["adm_in_channels"] = None
|
|
||||||
if "class_embedding.linear_1.weight" in sd:
|
|
||||||
match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
|
|
||||||
elif "add_embedding.linear_1.weight" in sd:
|
|
||||||
match["adm_in_channels"] = sd["add_embedding.linear_1.weight"].shape[1]
|
|
||||||
|
|
||||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
|
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
|
||||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
|
|
||||||
|
|
||||||
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
|
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
|
||||||
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
|
|
||||||
|
|
||||||
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
|
||||||
'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
|
||||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
|
||||||
|
|
||||||
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
|
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
|
||||||
|
|
||||||
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
|
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
|
||||||
|
|
||||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
|
||||||
'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
|
||||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
|
|
||||||
|
|
||||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
|
|
||||||
print("match", match)
|
|
||||||
for unet_config in supported_models:
|
|
||||||
matches = True
|
|
||||||
for k in match:
|
|
||||||
if match[k] != unet_config[k]:
|
|
||||||
matches = False
|
|
||||||
break
|
|
||||||
if matches:
|
|
||||||
diffusers_keys = utils.unet_to_diffusers(unet_config)
|
|
||||||
new_sd = {}
|
new_sd = {}
|
||||||
for k in diffusers_keys:
|
for k in diffusers_keys:
|
||||||
if k in sd:
|
if k in sd:
|
||||||
@ -1197,12 +1199,10 @@ def load_unet(unet_path): #load unet in diffusers format
|
|||||||
else:
|
else:
|
||||||
print(diffusers_keys[k], k)
|
print(diffusers_keys[k], k)
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
model_config = model_detection.model_config_from_unet_config(unet_config)
|
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model = model.to(offload_device)
|
model = model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
||||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
|
||||||
|
|
||||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user