mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Cleaner support for loading different diffusion model types.
This commit is contained in:
parent
ffe0bb0a33
commit
5e1fced639
@ -105,6 +105,9 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
unet_config["audio_model"] = "dit1.0"
|
unet_config["audio_model"] = "dit1.0"
|
||||||
return unet_config
|
return unet_config
|
||||||
|
|
||||||
|
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||||
|
return None
|
||||||
|
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"use_checkpoint": False,
|
"use_checkpoint": False,
|
||||||
"image_size": 32,
|
"image_size": 32,
|
||||||
@ -239,6 +242,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
|
|||||||
|
|
||||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||||
|
if unet_config is None:
|
||||||
|
return None
|
||||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||||
if model_config is None and use_base_if_no_match:
|
if model_config is None and use_base_if_no_match:
|
||||||
return comfy.supported_models_base.BASE(unet_config)
|
return comfy.supported_models_base.BASE(unet_config)
|
||||||
|
@ -546,21 +546,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
||||||
|
|
||||||
#Allow loading unets from checkpoint files
|
#Allow loading unets from checkpoint files
|
||||||
checkpoint = False
|
|
||||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||||
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
||||||
if len(temp_sd) > 0:
|
if len(temp_sd) > 0:
|
||||||
sd = temp_sd
|
sd = temp_sd
|
||||||
checkpoint = True
|
|
||||||
|
|
||||||
parameters = comfy.utils.calculate_parameters(sd)
|
parameters = comfy.utils.calculate_parameters(sd)
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_management.get_torch_device()
|
||||||
|
model_config = model_detection.model_config_from_unet(sd, "")
|
||||||
|
|
||||||
if checkpoint or "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
|
if model_config is not None:
|
||||||
model_config = model_detection.model_config_from_unet(sd, "")
|
|
||||||
if model_config is None:
|
|
||||||
return None
|
|
||||||
new_sd = sd
|
new_sd = sd
|
||||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
|
elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
|
||||||
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
|
Loading…
Reference in New Issue
Block a user