Make the load checkpoint with config function call the regular one.

I was going to completely remove this function because it is unmaintainable
but I think this is the best compromise.

The clip skip and v_prediction parts of the configs should still work but
not the fp16 vs fp32.
This commit is contained in:
comfyanonymous 2024-05-06 20:04:39 -04:00
parent 3787b4f246
commit c61eadf69a

View File

@ -418,6 +418,8 @@ def load_gligen(ckpt_path):
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
#TODO: this function is a mess and should be removed eventually
if config is None:
with open(config_path, 'r') as stream:
@ -425,81 +427,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
fp16 = False
if "unet_config" in model_config_params:
if "params" in model_config_params["unet_config"]:
unet_config = model_config_params["unet_config"]["params"]
if "use_fp16" in unet_config:
fp16 = unet_config.pop("use_fp16")
if fp16:
unet_config["dtype"] = torch.float16
noise_aug_config = None
if "noise_aug_config" in model_config_params:
noise_aug_config = model_config_params["noise_aug_config"]
model_type = model_base.ModelType.EPS
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":
model_type = model_base.ModelType.V_PREDICTION
m = model.clone()
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.model_sampling.V_PREDICTION):
pass
m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config))
model = m
clip = None
vae = None
layer_idx = clip_config.get("params", {}).get("layer_idx", None)
if layer_idx is not None:
clip.clip_layer(layer_idx)
class WeightsLoader(torch.nn.Module):
pass
if state_dict is None:
state_dict = comfy.utils.load_torch_file(ckpt_path)
class EmptyClass:
pass
model_config = comfy.supported_models_base.BASE({})
from . import latent_formats
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
model_config.unet_config = model_detection.convert_config(unet_config)
if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
else:
model = model_base.BaseModel(model_config, model_type=model_type)
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model.set_inpaint()
if fp16:
model = model.half()
offload_device = model_management.unet_offload_device()
model = model.to(offload_device)
model.load_model_weights(state_dict, "model.diffusion_model.")
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
vae = VAE(sd=vae_sd, config=vae_config)
if output_clip:
w = WeightsLoader()
clip_target = EmptyClass()
clip_target.params = clip_config.get("params", {})
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model.clip_h
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model.clip_l
load_clip_weights(w, state_dict)
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
sd = comfy.utils.load_torch_file(ckpt_path)