Mirror of https://github.com/comfyanonymous/ComfyUI.git (commit 9b93b920be)
The created checkpoints contain workflow metadata, so the workflow can be restored by dragging the checkpoint onto the UI or loading it with the "Load" button. Checkpoints are saved in fp16 or fp32 depending on the format ComfyUI is using for inference on your hardware; to force fp32, use --force-fp32. Anything that patches the model weights, such as merges or LoRAs, will be baked into the saved checkpoint. The output directory is currently set to output/checkpoints, but that might change in the future.
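For illustration only (this is not part of the file below): a minimal sketch of how a saved checkpoint's embedded metadata and weight precision could be inspected with the safetensors library. The file name is a placeholder and the exact metadata keys depend on how the checkpoint was written, so treat this as an assumption rather than the project's API.

from safetensors import safe_open

# Hypothetical path; checkpoints are currently written under output/checkpoints.
path = "output/checkpoints/example.safetensors"

with safe_open(path, framework="pt") as f:
    # Any workflow metadata lives in the safetensors header as string key/value pairs.
    print(list((f.metadata() or {}).keys()))
    # Check whether the weights were stored as fp16 or fp32.
    name = next(iter(f.keys()))
    print(name, f.get_tensor(name).dtype)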
179 lines
6.6 KiB
Python
import torch
from . import model_base
from . import utils

from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip

from . import supported_models_base
from . import latent_formats

from . import diffusers_convert

class SD15(supported_models_base.BASE):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD15

    def process_clip_state_dict(self, state_dict):
        k = list(state_dict.keys())
        for x in k:
            # Move keys from the old cond_stage_model.transformer.* layout under transformer.text_model.*
            if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
                y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
                state_dict[y] = state_dict.pop(x)

        # Some checkpoints store position_ids as float32 with non-integer values; round them.
        if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in state_dict:
            ids = state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids']
            if ids.dtype == torch.float32:
                state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)

class SD20(supported_models_base.BASE):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
    }

    latent_format = latent_formats.SD15

    # Heuristic check for v-prediction checkpoints based on the statistics of one bias tensor.
    def v_prediction(self, state_dict):
        if self.unet_config["in_channels"] == 4: # SD2.0 inpainting models are not v prediction
            k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
            out = state_dict[k]
            if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return True
        return False

    def process_clip_state_dict(self, state_dict):
        # Convert the open_clip key layout used by SD2.x checkpoints to the transformers layout.
        state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        # Map back to the original SD2.x (open_clip) key layout when saving.
        replace_prefix = {}
        replace_prefix[""] = "cond_stage_model.model."
        state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

class SD21UnclipL(SD20):
    # SD2.1 unCLIP variant conditioned on 768-dim CLIP ViT-L image embeddings.
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 1536,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 768}

class SD21UnclipH(SD20):
    # SD2.1 unCLIP variant conditioned on 1024-dim CLIP ViT-H image embeddings.
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 2048,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}

class SDXLRefiner(supported_models_base.BASE):
    unet_config = {
        "model_channels": 384,
        "use_linear_in_transformer": True,
        "context_dim": 1280,
        "adm_in_channels": 2560,
        "transformer_depth": [0, 4, 4, 0],
    }

    latent_format = latent_formats.SDXL

    def get_model(self, state_dict):
        return model_base.SDXLRefiner(self)

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}

        # The refiner ships only the larger OpenCLIP text encoder; rename its keys to ComfyUI's clip_g layout.
        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
        keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"

        state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        # Convert clip_g back to the original conditioner.embedders.0.model layout for saving.
        replace_prefix = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

class SDXL(supported_models_base.BASE):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 2, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816
    }

    latent_format = latent_formats.SDXL

    def get_model(self, state_dict):
        return model_base.SDXL(self)

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}

        # SDXL uses two text encoders: clip_l (conditioner.embedders.0) and clip_g (conditioner.embedders.1).
        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
        keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"

        state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        keys_to_replace = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        # Carry the clip_l weights over unchanged, then restore the original prefixes for both encoders.
        for k in state_dict:
            if k.startswith("clip_l"):
                state_dict_g[k] = state_dict[k]

        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
        replace_prefix["clip_l"] = "conditioner.embedders.0"
        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL]
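As a usage illustration (not part of the file above), here is a rough sketch of how a list like `models` can be used to pick the config class whose declared `unet_config` values all match a UNet config detected from a checkpoint. The helper name is hypothetical and the real detection code in the repository handles more cases than this.

def match_model_config(detected_unet_config):
    # Hypothetical helper: return the first class whose unet_config entries
    # all agree with the detected values, or None if nothing matches.
    for model_config in models:
        if all(detected_unet_config.get(k) == v for k, v in model_config.unet_config.items()):
            return model_config
    return None

# An SDXL-shaped config matches the SDXL class defined above.
print(match_model_config({"model_channels": 320, "use_linear_in_transformer": True,
                          "transformer_depth": [0, 2, 10], "context_dim": 2048,
                          "adm_in_channels": 2816}))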