diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0b678480..783a03eb 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -105,6 +105,9 @@ def detect_unet_config(state_dict, key_prefix): unet_config["audio_model"] = "dit1.0" return unet_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: + return None + unet_config = { "use_checkpoint": False, "image_size": 32, @@ -239,6 +242,8 @@ def model_config_from_unet_config(unet_config, state_dict=None): def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): unet_config = detect_unet_config(state_dict, unet_key_prefix) + if unet_config is None: + return None model_config = model_config_from_unet_config(unet_config, state_dict) if model_config is None and use_base_if_no_match: return comfy.supported_models_base.BASE(unet_config) diff --git a/comfy/sd.py b/comfy/sd.py index d0832ad3..66033b63 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -546,21 +546,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o def load_unet_state_dict(sd): #load unet in diffusers or regular format #Allow loading unets from checkpoint files - checkpoint = False diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) if len(temp_sd) > 0: sd = temp_sd - checkpoint = True parameters = comfy.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() + model_config = model_detection.model_config_from_unet(sd, "") - if checkpoint or "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade - model_config = model_detection.model_config_from_unet(sd, "") - if model_config is None: - return None + if model_config is not None: new_sd = sd elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3 new_sd = model_detection.convert_diffusers_mmdit(sd, "")