mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Move some functions to utils.py
This commit is contained in:
parent
766c7b3815
commit
a74c5dbf37
@ -68,7 +68,7 @@ class SD20(supported_models_base.BASE):
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix[""] = "cond_stage_model.model."
|
||||
state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||
return state_dict
|
||||
|
||||
@ -120,7 +120,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
||||
|
||||
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
@ -129,7 +129,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
|
||||
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
|
||||
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
|
||||
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
return state_dict_g
|
||||
|
||||
def clip_target(self):
|
||||
@ -167,8 +167,8 @@ class SDXL(supported_models_base.BASE):
|
||||
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
||||
|
||||
state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
@ -183,7 +183,7 @@ class SDXL(supported_models_base.BASE):
|
||||
|
||||
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
|
||||
replace_prefix["clip_l"] = "conditioner.embedders.0"
|
||||
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
return state_dict_g
|
||||
|
||||
def clip_target(self):
|
||||
|
@ -3,21 +3,6 @@ from . import model_base
|
||||
from . import utils
|
||||
from . import latent_formats
|
||||
|
||||
|
||||
def state_dict_key_replace(state_dict, keys_to_replace):
|
||||
for x in keys_to_replace:
|
||||
if x in state_dict:
|
||||
state_dict[keys_to_replace[x]] = state_dict.pop(x)
|
||||
return state_dict
|
||||
|
||||
def state_dict_prefix_replace(state_dict, replace_prefix):
|
||||
for rp in replace_prefix:
|
||||
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
|
||||
for x in replace:
|
||||
state_dict[x[1]] = state_dict.pop(x[0])
|
||||
return state_dict
|
||||
|
||||
|
||||
class ClipTarget:
|
||||
def __init__(self, tokenizer, clip):
|
||||
self.clip = clip
|
||||
@ -70,13 +55,13 @@ class BASE:
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "cond_stage_model."}
|
||||
return state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model.diffusion_model."}
|
||||
return state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_vae_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "first_stage_model."}
|
||||
return state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
|
@ -39,6 +39,20 @@ def calculate_parameters(sd, prefix=""):
|
||||
params += sd[k].nelement()
|
||||
return params
|
||||
|
||||
def state_dict_key_replace(state_dict, keys_to_replace):
|
||||
for x in keys_to_replace:
|
||||
if x in state_dict:
|
||||
state_dict[keys_to_replace[x]] = state_dict.pop(x)
|
||||
return state_dict
|
||||
|
||||
def state_dict_prefix_replace(state_dict, replace_prefix):
|
||||
for rp in replace_prefix:
|
||||
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
|
||||
for x in replace:
|
||||
state_dict[x[1]] = state_dict.pop(x[0])
|
||||
return state_dict
|
||||
|
||||
|
||||
def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
keys_to_replace = {
|
||||
"{}positional_embedding": "{}embeddings.position_embedding.weight",
|
||||
|
Loading…
Reference in New Issue
Block a user