Repository: https://github.com/comfyanonymous/ComfyUI.git

commit d91f45ef28
parent dbe0979b3f

Some cleanups to how the text encoders are loaded.
comfy/sd.py (21 changed lines)
@@ -138,7 +138,10 @@ class CLIP:
         tokens = self.tokenize(text)
         return self.encode_from_tokens(tokens)
 
-    def load_sd(self, sd):
-        return self.cond_stage_model.load_sd(sd)
+    def load_sd(self, sd, full_model=False):
+        if full_model:
+            return self.cond_stage_model.load_state_dict(sd, strict=False)
+        else:
+            return self.cond_stage_model.load_sd(sd)
 
     def get_sd(self):
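For context on the new full_model path: torch.nn.Module.load_state_dict(..., strict=False) tolerates mismatched keys and returns them as a (missing_keys, unexpected_keys) pair, which is what callers of load_sd(..., full_model=True) can now inspect. A minimal runnable illustration (the Linear module is just a stand-in, not ComfyUI code):

import torch

# Stand-in for the text encoder; the real cond_stage_model is much larger.
encoder = torch.nn.Linear(4, 4)

state = {"weight": torch.zeros(4, 4)}  # "bias" is deliberately absent
m, u = encoder.load_state_dict(state, strict=False)
print("missing:", m)      # ['bias']
print("unexpected:", u)   # []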
@@ -494,9 +497,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
     load_device = model_management.get_torch_device()
 
-    class WeightsLoader(torch.nn.Module):
-        pass
-
     model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
     unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
@@ -521,14 +521,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         vae = VAE(sd=vae_sd)
 
     if output_clip:
-        w = WeightsLoader()
         clip_target = model_config.clip_target()
         if clip_target is not None:
-            sd = model_config.process_clip_state_dict(sd)
-            if any(k.startswith('cond_stage_model.') for k in sd):
+            clip_sd = model_config.process_clip_state_dict(sd)
+            if len(clip_sd) > 0:
                 clip = CLIP(clip_target, embedding_directory=embedding_directory)
-                w.cond_stage_model = clip.cond_stage_model
-                load_model_weights(w, sd)
+                m, u = clip.load_sd(clip_sd, full_model=True)
+                if len(m) > 0:
+                    print("clip missing:", m)
+
+                if len(u) > 0:
+                    print("clip unexpected:", u)
             else:
                 print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
 
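The new flow removes the WeightsLoader indirection: process_clip_state_dict now returns only the text-encoder keys (with prefixes already rewritten), so an empty dict signals that the checkpoint carries no text encoder, and load_sd(..., full_model=True) surfaces partial-load details. A sketch of that control flow (load_text_encoder and make_clip are hypothetical names, not ComfyUI API):

def load_text_encoder(sd, model_config, make_clip):
    # Filter the full checkpoint down to text-encoder weights.
    clip_sd = model_config.process_clip_state_dict(sd)
    if len(clip_sd) == 0:
        print("no CLIP/text encoder weights in checkpoint, "
              "the text encoder model will not be loaded.")
        return None
    clip = make_clip()
    # full_model=True -> load_state_dict(strict=False) under the hood.
    m, u = clip.load_sd(clip_sd, full_model=True)
    if len(m) > 0:
        print("clip missing:", m)
    if len(u) > 0:
        print("clip unexpected:", u)
    return clip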
comfy/supported_models.py

@@ -40,8 +40,8 @@ class SD15(supported_models_base.BASE):
             state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
 
         replace_prefix = {}
-        replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
-        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
+        replace_prefix["cond_stage_model."] = "clip_l."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
         return state_dict
 
     def process_clip_state_dict_for_saving(self, state_dict):
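The filter_keys=True argument is what makes the len(clip_sd) > 0 check above possible: only keys matching a known text-encoder prefix survive the rename pass. A behavioral sketch of utils.state_dict_prefix_replace as used in this commit (an approximation, not ComfyUI's exact implementation):

def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    # filter_keys=True: return only keys that matched some prefix;
    # filter_keys=False: rename matches but keep every other key too.
    out = {} if filter_keys else dict(state_dict)
    for old, new in replace_prefix.items():
        for k in list(state_dict.keys()):
            if k.startswith(old):
                out.pop(k, None)
                out[new + k[len(old):]] = state_dict[k]
    return out

sd = {"cond_stage_model.transformer.x": 1, "model.diffusion_model.y": 2}
print(state_dict_prefix_replace(sd, {"cond_stage_model.": "clip_l."}, filter_keys=True))
# {'clip_l.transformer.x': 1} -- the UNet key is dropped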
@@ -72,10 +72,10 @@ class SD20(supported_models_base.BASE):
 
     def process_clip_state_dict(self, state_dict):
         replace_prefix = {}
-        replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
-        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
-        state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
+        replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
+        replace_prefix["cond_stage_model.model."] = "clip_h."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
+        state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24)
         return state_dict
 
     def process_clip_state_dict_for_saving(self, state_dict):
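After the prefix pass, the SD2 weights still carry open_clip-style names; utils.transformers_convert then rewrites them into transformers-style names under the same clip_h. prefix (24 is the layer count of the ViT-H text encoder). An illustrative before/after key pair, based on the usual open_clip-to-transformers naming convention (not verified against ComfyUI's exact mapping table):

# One renamed pair produced by
# transformers_convert(sd, "clip_h.", "clip_h.transformer.text_model.", 24):
before = "clip_h.transformer.resblocks.0.mlp.c_fc.weight"
after = "clip_h.transformer.text_model.encoder.layers.0.mlp.fc1.weight"
print(before, "->", after)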
@@ -131,11 +131,10 @@ class SDXLRefiner(supported_models_base.BASE):
     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
         replace_prefix = {}
+        replace_prefix["conditioner.embedders.0.model."] = "clip_g."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
 
-        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
-        keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
-        keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
-
+        state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
         state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict
 
@@ -179,13 +178,13 @@ class SDXL(supported_models_base.BASE):
         keys_to_replace = {}
         replace_prefix = {}
 
-        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
-        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
-        keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
-        keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
-        keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
+        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
+        replace_prefix["conditioner.embedders.1.model."] = "clip_g."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
+
+        state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
+        keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection"
 
-        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
         state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict
 
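For SDXL, the remaining exact-name fix-up (mapping clip_g.text_projection.weight to clip_g.text_projection) goes through utils.state_dict_key_replace. A behavioral sketch (an approximation, not the library's exact code):

def state_dict_key_replace(state_dict, keys_to_replace):
    # Exact-name renames, applied after prefix rewriting has run.
    for old, new in keys_to_replace.items():
        if old in state_dict:
            state_dict[new] = state_dict.pop(old)
    return state_dict

sd = {"clip_g.text_projection.weight": "tensor..."}
print(state_dict_key_replace(sd, {"clip_g.text_projection.weight": "clip_g.text_projection"}))
# {'clip_g.text_projection': 'tensor...'}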
comfy/supported_models_base.py

@@ -22,6 +22,7 @@ class BASE:
     sampling_settings = {}
     latent_format = latent_formats.LatentFormat
     vae_key_prefix = ["first_stage_model."]
+    text_encoder_key_prefix = ["cond_stage_model."]
     supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
 
     manual_cast_dtype = None
@@ -55,6 +56,7 @@ class BASE:
         return out
 
     def process_clip_state_dict(self, state_dict):
+        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
         return state_dict
 
     def process_unet_state_dict(self, state_dict):
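With this default in BASE, every model config strips its text_encoder_key_prefix and filters in one step, so subclasses only add model-specific renames on top. A small runnable illustration of the strip-and-filter effect (the sample keys are made up):

text_encoder_key_prefix = ["cond_stage_model."]

sd = {
    "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": 0,
    "first_stage_model.decoder.conv_in.weight": 1,  # VAE key, filtered out
}
clip_sd = {
    k[len(p):]: v
    for k, v in sd.items()
    for p in text_encoder_key_prefix
    if k.startswith(p)
}
print(clip_sd)
# {'transformer.text_model.embeddings.token_embedding.weight': 0}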
@@ -64,7 +66,7 @@ class BASE:
         return state_dict
 
     def process_clip_state_dict_for_saving(self, state_dict):
-        replace_prefix = {"": "cond_stage_model."}
+        replace_prefix = {"": self.text_encoder_key_prefix[0]}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
     def process_clip_vision_state_dict_for_saving(self, state_dict):
@@ -78,7 +80,7 @@ class BASE:
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
     def process_vae_state_dict_for_saving(self, state_dict):
-        replace_prefix = {"": "first_stage_model."}
+        replace_prefix = {"": self.vae_key_prefix[0]}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
     def set_inference_dtype(self, dtype, manual_cast_dtype):
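The saving helpers are the inverse of loading: prefixes stripped by process_clip_state_dict are re-attached from the same class attributes, so a subclass that overrides text_encoder_key_prefix or vae_key_prefix stays consistent in both directions. A minimal sketch of the round trip (made-up key):

text_encoder_key_prefix = ["cond_stage_model."]

clip_sd = {"transformer.text_model.final_layer_norm.weight": 0}  # as loaded
saved = {text_encoder_key_prefix[0] + k: v for k, v in clip_sd.items()}
print(saved)
# {'cond_stage_model.transformer.text_model.final_layer_norm.weight': 0}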