diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 0b3e4bcb..bb8ae214 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -68,7 +68,7 @@ class SD20(supported_models_base.BASE): def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} replace_prefix[""] = "cond_stage_model.model." - state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix) + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) return state_dict @@ -120,7 +120,7 @@ class SDXLRefiner(supported_models_base.BASE): keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection" keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" - state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace) + state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -129,7 +129,7 @@ class SDXLRefiner(supported_models_base.BASE): if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g: state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") replace_prefix["clip_g"] = "conditioner.embedders.0.model" - state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) + state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) return state_dict_g def clip_target(self): @@ -167,8 +167,8 @@ class SDXL(supported_models_base.BASE): keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" - state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix) - state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace) + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) + state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -183,7 +183,7 @@ class SDXL(supported_models_base.BASE): replace_prefix["clip_g"] = "conditioner.embedders.1.model" replace_prefix["clip_l"] = "conditioner.embedders.0" - state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) + state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) return state_dict_g def clip_target(self): diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 395a90ab..88a1d7fd 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -3,21 +3,6 @@ from . import model_base from . import utils from . import latent_formats - -def state_dict_key_replace(state_dict, keys_to_replace): - for x in keys_to_replace: - if x in state_dict: - state_dict[keys_to_replace[x]] = state_dict.pop(x) - return state_dict - -def state_dict_prefix_replace(state_dict, replace_prefix): - for rp in replace_prefix: - replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys()))) - for x in replace: - state_dict[x[1]] = state_dict.pop(x[0]) - return state_dict - - class ClipTarget: def __init__(self, tokenizer, clip): self.clip = clip @@ -70,13 +55,13 @@ class BASE: def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {"": "cond_stage_model."} - return state_dict_prefix_replace(state_dict, replace_prefix) + return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_unet_state_dict_for_saving(self, state_dict): replace_prefix = {"": "model.diffusion_model."} - return state_dict_prefix_replace(state_dict, replace_prefix) + return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_vae_state_dict_for_saving(self, state_dict): replace_prefix = {"": "first_stage_model."} - return state_dict_prefix_replace(state_dict, replace_prefix) + return utils.state_dict_prefix_replace(state_dict, replace_prefix) diff --git a/comfy/utils.py b/comfy/utils.py index 47f4b970..3ed32e37 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -39,6 +39,20 @@ def calculate_parameters(sd, prefix=""): params += sd[k].nelement() return params +def state_dict_key_replace(state_dict, keys_to_replace): + for x in keys_to_replace: + if x in state_dict: + state_dict[keys_to_replace[x]] = state_dict.pop(x) + return state_dict + +def state_dict_prefix_replace(state_dict, replace_prefix): + for rp in replace_prefix: + replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys()))) + for x in replace: + state_dict[x[1]] = state_dict.pop(x[0]) + return state_dict + + def transformers_convert(sd, prefix_from, prefix_to, number): keys_to_replace = { "{}positional_embedding": "{}embeddings.position_embedding.weight",