From d91f45ef280a5acbdc22f3cc757f8fdbb254261b Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 19 Feb 2024 10:29:18 -0500
Subject: [PATCH] Some cleanups to how the text encoders are loaded.

---
 comfy/sd.py                    | 23 +++++++++++++----------
 comfy/supported_models.py      | 31 +++++++++++++++----------------
 comfy/supported_models_base.py |  6 ++++--
 3 files changed, 32 insertions(+), 28 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index 00633e10..7a77bb17 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -138,8 +138,11 @@ class CLIP:
         tokens = self.tokenize(text)
         return self.encode_from_tokens(tokens)
 
-    def load_sd(self, sd):
-        return self.cond_stage_model.load_sd(sd)
+    def load_sd(self, sd, full_model=False):
+        if full_model:
+            return self.cond_stage_model.load_state_dict(sd, strict=False)
+        else:
+            return self.cond_stage_model.load_sd(sd)
 
     def get_sd(self):
         return self.cond_stage_model.state_dict()
@@ -494,9 +497,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
     load_device = model_management.get_torch_device()
 
-    class WeightsLoader(torch.nn.Module):
-        pass
-
     model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
     unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
@@ -521,14 +521,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         vae = VAE(sd=vae_sd)
 
     if output_clip:
-        w = WeightsLoader()
         clip_target = model_config.clip_target()
         if clip_target is not None:
-            sd = model_config.process_clip_state_dict(sd)
-            if any(k.startswith('cond_stage_model.') for k in sd):
+            clip_sd = model_config.process_clip_state_dict(sd)
+            if len(clip_sd) > 0:
                 clip = CLIP(clip_target, embedding_directory=embedding_directory)
-                w.cond_stage_model = clip.cond_stage_model
-                load_model_weights(w, sd)
+                m, u = clip.load_sd(clip_sd, full_model=True)
+                if len(m) > 0:
+                    print("clip missing:", m)
+
+                if len(u) > 0:
+                    print("clip unexpected:", u)
             else:
                 print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
 
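Note: with full_model=True, CLIP.load_sd() above delegates to torch's
Module.load_state_dict(strict=False), which returns the (missing_keys,
unexpected_keys) pair that load_checkpoint_guess_config() now prints. A
minimal sketch of that torch behavior (TinyEncoder is a hypothetical
stand-in for cond_stage_model, not ComfyUI code):

    import torch

    class TinyEncoder(torch.nn.Module):  # hypothetical stand-in module
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(4, 4)

    sd = {"proj.weight": torch.zeros(4, 4), "extra.bias": torch.zeros(4)}
    m, u = TinyEncoder().load_state_dict(sd, strict=False)
    print("clip missing:", m)     # ['proj.bias'] (absent from sd)
    print("clip unexpected:", u)  # ['extra.bias'] (no such parameter)
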
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 1a673646..f29f7f3d 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -40,8 +40,8 @@ class SD15(supported_models_base.BASE):
             state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
 
         replace_prefix = {}
-        replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
-        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
+        replace_prefix["cond_stage_model."] = "clip_l."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
         return state_dict
 
     def process_clip_state_dict_for_saving(self, state_dict):
@@ -72,10 +72,10 @@ class SD20(supported_models_base.BASE):
 
     def process_clip_state_dict(self, state_dict):
         replace_prefix = {}
-        replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
-        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
-
-        state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
+        replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
+        replace_prefix["cond_stage_model.model."] = "clip_h."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
+        state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24)
         return state_dict
 
     def process_clip_state_dict_for_saving(self, state_dict):
@@ -131,11 +131,10 @@ class SDXLRefiner(supported_models_base.BASE):
     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
         replace_prefix = {}
+        replace_prefix["conditioner.embedders.0.model."] = "clip_g."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
 
-        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
-        keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
-        keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
-
+        state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
         state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict
 
@@ -179,13 +178,13 @@ class SDXL(supported_models_base.BASE):
         keys_to_replace = {}
         replace_prefix = {}
 
-        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
-        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
-        keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
-        keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
-        keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
+        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
+        replace_prefix["conditioner.embedders.1.model."] = "clip_g."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
+
+        state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
+        keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection"
 
-        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
         state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict
 
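Note: the hunks above lean on the new filter_keys=True argument to
utils.state_dict_prefix_replace(). Judging from these call sites, the assumed
semantics are: keys matching a prefix are renamed, and with filter_keys=True
every non-matching key is dropped, so only text encoder weights survive under
the short clip_l./clip_g./clip_h. prefixes. A simplified sketch of those
semantics (not the actual ComfyUI helper):

    def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
        # Start empty when filtering; otherwise keep unmatched keys too.
        out = {} if filter_keys else dict(state_dict)
        for old, new in replace_prefix.items():
            for key in list(state_dict.keys()):
                if key.startswith(old):
                    out.pop(key, None)  # drop the old key name
                    out[new + key[len(old):]] = state_dict[key]
        return out

    sd = {"cond_stage_model.transformer.w": 1, "model.diffusion_model.w": 2}
    print(state_dict_prefix_replace(sd, {"cond_stage_model.": "clip_l."}, filter_keys=True))
    # {'clip_l.transformer.w': 1} (the UNet key is filtered out)
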
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 3bd4f9c6..4d7e2593 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -22,6 +22,7 @@ class BASE:
     sampling_settings = {}
     latent_format = latent_formats.LatentFormat
     vae_key_prefix = ["first_stage_model."]
+    text_encoder_key_prefix = ["cond_stage_model."]
     supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
 
     manual_cast_dtype = None
@@ -55,6 +56,7 @@ class BASE:
         return out
 
     def process_clip_state_dict(self, state_dict):
+        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
         return state_dict
 
     def process_unet_state_dict(self, state_dict):
@@ -64,7 +66,7 @@ class BASE:
         return state_dict
 
     def process_clip_state_dict_for_saving(self, state_dict):
-        replace_prefix = {"": "cond_stage_model."}
+        replace_prefix = {"": self.text_encoder_key_prefix[0]}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
     def process_clip_vision_state_dict_for_saving(self, state_dict):
@@ -78,7 +80,7 @@ class BASE:
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
     def process_vae_state_dict_for_saving(self, state_dict):
-        replace_prefix = {"": "first_stage_model."}
+        replace_prefix = {"": self.vae_key_prefix[0]}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
     def set_inference_dtype(self, dtype, manual_cast_dtype):
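Note: taken together, BASE now strips text_encoder_key_prefix when extracting
text encoder weights and re-adds it when saving, so the CLIP state dict is
handled prefix-free in between. A hedged round-trip sketch of that flow
(simplified, not the actual class):

    class Base:
        text_encoder_key_prefix = ["cond_stage_model."]

        def process_clip_state_dict(self, sd):
            # Keep only text encoder keys, with the prefix stripped.
            p = self.text_encoder_key_prefix[0]
            return {k[len(p):]: v for k, v in sd.items() if k.startswith(p)}

        def process_clip_state_dict_for_saving(self, sd):
            # Restore the prefix for checkpoint saving.
            return {self.text_encoder_key_prefix[0] + k: v for k, v in sd.items()}

    b = Base()
    ckpt = {"cond_stage_model.token_embedding.weight": 0, "first_stage_model.w": 1}
    clip_sd = b.process_clip_state_dict(ckpt)
    print(clip_sd)                                        # {'token_embedding.weight': 0}
    print(b.process_clip_state_dict_for_saving(clip_sd))  # prefix restored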