diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index cf764e0b..691d4c6c 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -118,3 +118,57 @@ def model_config_from_unet_config(unet_config):
 def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
     unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
     return model_config_from_unet_config(unet_config)
+
+
+def model_config_from_diffusers_unet(state_dict, use_fp16):
+    match = {}
+    match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+    match["model_channels"] = state_dict["conv_in.weight"].shape[0]
+    match["in_channels"] = state_dict["conv_in.weight"].shape[1]
+    match["adm_in_channels"] = None
+    if "class_embedding.linear_1.weight" in state_dict:
+        match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
+    elif "add_embedding.linear_1.weight" in state_dict:
+        match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
+
+    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
+            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+
+    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
+                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
+
+    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
+
+    for unet_config in supported_models:
+        matches = True
+        for k in match:
+            if match[k] != unet_config[k]:
+                matches = False
+                break
+        if matches:
+            return model_config_from_unet_config(unet_config)
+    return None
diff --git a/comfy/sd.py b/comfy/sd.py
index 61b59383..7f7c06bc 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -765,6 +765,51 @@ class ControlNet:
 
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
+
+    controlnet_config = None
+    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
+        use_fp16 = model_management.should_use_fp16()
+        controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config
+        diffusers_keys = utils.unet_to_diffusers(controlnet_config)
+        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
+        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
+
+        count = 0
+        loop = True
+        while loop:
+            suffix = [".weight", ".bias"]
+            for s in suffix:
+                k_in = "controlnet_down_blocks.{}{}".format(count, s)
+                k_out = "zero_convs.{}.0{}".format(count, s)
+                if k_in not in controlnet_data:
+                    loop = False
+                    break
+                diffusers_keys[k_in] = k_out
+            count += 1
+
+        count = 0
+        loop = True
+        while loop:
+            suffix = [".weight", ".bias"]
+            for s in suffix:
+                if count == 0:
+                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
+                else:
+                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
+                k_out = "input_hint_block.{}{}".format(count * 2, s)
+                if k_in not in controlnet_data:
+                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
+                    loop = False
+                diffusers_keys[k_in] = k_out
+            count += 1
+
+        new_sd = {}
+        for k in diffusers_keys:
+            if k in controlnet_data:
+                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
+
+        controlnet_data = new_sd
+
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
     key = 'zero_convs.0.0.weight'
@@ -780,9 +825,9 @@ def load_controlnet(ckpt_path, model=None):
         print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
         return net
 
-    use_fp16 = model_management.should_use_fp16()
-
-    controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
+    if controlnet_config is None:
+        use_fp16 = model_management.should_use_fp16()
+        controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = 3
     control_model = cldm.ControlNet(**controlnet_config)
@@ -1140,69 +1185,24 @@ def load_unet(unet_path): #load unet in diffusers format
     parameters = calculate_parameters(sd, "")
     fp16 = model_management.should_use_fp16(model_params=parameters)
 
-    match = {}
-    match["context_dim"] = sd["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
-    match["model_channels"] = sd["conv_in.weight"].shape[0]
-    match["in_channels"] = sd["conv_in.weight"].shape[1]
-    match["adm_in_channels"] = None
-    if "class_embedding.linear_1.weight" in sd:
-        match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
-    elif "add_embedding.linear_1.weight" in sd:
-        match["adm_in_channels"] = sd["add_embedding.linear_1.weight"].shape[1]
+    model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
+    if model_config is None:
+        print("ERROR UNSUPPORTED UNET", unet_path)
+        return None
 
-    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
-            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
-            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+    diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
 
-    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
-                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
-
-    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
-            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
-                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
-                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
-            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
-
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
-    print("match", match)
-    for unet_config in supported_models:
-        matches = True
-        for k in match:
-            if match[k] != unet_config[k]:
-                matches = False
-                break
-        if matches:
-            diffusers_keys = utils.unet_to_diffusers(unet_config)
-            new_sd = {}
-            for k in diffusers_keys:
-                if k in sd:
-                    new_sd[diffusers_keys[k]] = sd.pop(k)
-                else:
-                    print(diffusers_keys[k], k)
-            offload_device = model_management.unet_offload_device()
-            model_config = model_detection.model_config_from_unet_config(unet_config)
-            model = model_config.get_model(new_sd, "")
-            model = model.to(offload_device)
-            model.load_model_weights(new_sd, "")
-            return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
-    print("ERROR UNSUPPORTED UNET", unet_path)
+    new_sd = {}
+    for k in diffusers_keys:
+        if k in sd:
+            new_sd[diffusers_keys[k]] = sd.pop(k)
+        else:
+            print(diffusers_keys[k], k)
+    offload_device = model_management.unet_offload_device()
+    model = model_config.get_model(new_sd, "")
+    model = model.to(offload_device)
+    model.load_model_weights(new_sd, "")
+    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
 
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     try:
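Usage sketch (not part of the patch): a minimal example of how the refactored diffusers-format path might be exercised from a script, assuming a ComfyUI checkout with this change applied and run from the repository root. The .safetensors paths below are hypothetical placeholders for diffusers-format UNet and controlnet files.

    # Minimal sketch, assuming this patch is applied; paths are hypothetical placeholders.
    import comfy.utils
    import comfy.model_management
    import comfy.model_detection
    import comfy.sd

    # Ask the new helper which supported architecture a diffusers-format UNet state dict matches.
    unet_sd = comfy.utils.load_torch_file("unet/diffusion_pytorch_model.safetensors", safe_load=True)
    model_config = comfy.model_detection.model_config_from_diffusers_unet(
        unet_sd, comfy.model_management.should_use_fp16())
    print(model_config)  # None when the UNet matches none of the SD1.5 / SD2.1 / unCLIP / SDXL layouts above

    # load_controlnet() now detects the "controlnet_cond_embedding.*" keys, remaps them to the
    # cldm.ControlNet layout via unet_to_diffusers(), then builds the model as before.
    control = comfy.sd.load_controlnet("controlnet/diffusion_pytorch_model.safetensors")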