mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
575acb69e4
This is the code to load the model and run inference with it using only a text prompt. This commit does not contain the nodes needed to use it properly with an image input. It supports both the original SD1 InstructPix2Pix model and the diffusers SDXL one.
95 lines
3.3 KiB
Python
95 lines
3.3 KiB
Python
import torch
|
|
from . import model_base
|
|
from . import utils
|
|
from . import latent_formats
|
|
|
|
class ClipTarget:
    """Bundle the ``tokenizer`` and ``clip`` objects handed to the constructor.

    Each instance also gets its own, initially empty, ``params`` dict.
    """

    def __init__(self, tokenizer, clip):
        # Same attribute order as before: clip, tokenizer, then params.
        self.clip, self.tokenizer = clip, tokenizer
        self.params = dict()
|
|
|
|
class BASE:
    """Base description of a supported model configuration.

    Subclasses override the class-level attributes below.  ``matches`` decides
    whether a detected UNet config (and optionally a state dict) belongs to
    this configuration, and the ``process_*`` helpers rename state-dict keys
    when loading and saving.
    """

    # Entries a detected UNet config must contain with equal values (see matches()).
    unet_config = {}
    # Entries merged into each instance's UNet config in __init__ (overriding any existing value).
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    # Keys that must all be present in the state dict for matches() to succeed
    # (only checked when a state dict is supplied).
    required_keys = {}

    clip_prefix = []
    # Prefix prepended to every key when saving a CLIP-vision state dict; None means no prefix.
    clip_vision_prefix = None
    # When not None, get_model() builds model_base.SD21UNCLIP with this config instead of BaseModel.
    noise_aug_config = None
    sampling_settings = {}
    # A class at class level; __init__ replaces it with an instance of that class.
    latent_format = latent_formats.LatentFormat
    # First entry of each list is the prefix used when saving.
    vae_key_prefix = ["first_stage_model."]
    text_encoder_key_prefix = ["cond_stage_model."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    manual_cast_dtype = None

    @classmethod
    def matches(cls, unet_config, state_dict=None):
        """Return True if ``unet_config`` (and ``state_dict``, if given) fit this class.

        Every key in ``cls.unet_config`` must be present in ``unet_config``
        with an equal value.  When ``state_dict`` is not None, every key in
        ``cls.required_keys`` must also be present in it.
        """
        for k in cls.unet_config:
            if k not in unet_config or cls.unet_config[k] != unet_config[k]:
                return False
        if state_dict is not None:
            for k in cls.required_keys:
                if k not in state_dict:
                    return False
        return True

    def model_type(self, state_dict, prefix=""):
        """Return the model type; the base implementation always returns model_base.ModelType.EPS."""
        return model_base.ModelType.EPS

    def inpaint_model(self):
        """Return True when the UNet config declares more than 4 input channels.

        A config without an ``in_channels`` entry is treated as not-inpaint
        rather than raising KeyError.
        """
        return self.unet_config.get("in_channels", -1) > 4

    def __init__(self, unet_config):
        """Build an instance from a detected UNet config dict.

        Args:
            unet_config: mapping of UNet settings.  It is shallow-copied, so the
                caller's dict is not modified by the merge below or by
                set_inference_dtype().
        """
        # Copy so merging unet_extra_config (and later writing 'dtype') does not
        # leak into the dict the caller passed in.
        self.unet_config = unet_config.copy()
        # Instantiate the latent-format class declared at class level.
        self.latent_format = self.latent_format()
        self.unet_config.update(self.unet_extra_config)

    def get_model(self, state_dict, prefix="", device=None):
        """Construct the model object for this configuration.

        Uses model_base.SD21UNCLIP when ``noise_aug_config`` is set, otherwise
        model_base.BaseModel; calls ``set_inpaint()`` on it when
        ``inpaint_model()`` is True.
        """
        model_type = self.model_type(state_dict, prefix)
        if self.noise_aug_config is not None:
            out = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=model_type, device=device)
        else:
            out = model_base.BaseModel(self, model_type=model_type, device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        """Strip the text-encoder key prefixes from ``state_dict`` for loading.

        Passes filter_keys=True to utils.state_dict_prefix_replace; presumably
        this drops keys without any of the prefixes — confirm in utils.
        """
        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
        return state_dict

    def process_unet_state_dict(self, state_dict):
        """Return the UNet state dict unchanged (hook for subclasses)."""
        return state_dict

    def process_vae_state_dict(self, state_dict):
        """Return the VAE state dict unchanged (hook for subclasses)."""
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Prepend the first text-encoder prefix to every key for saving."""
        replace_prefix = {"": self.text_encoder_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_clip_vision_state_dict_for_saving(self, state_dict):
        """Prepend ``clip_vision_prefix`` to every key for saving, if one is set."""
        replace_prefix = {}
        if self.clip_vision_prefix is not None:
            replace_prefix[""] = self.clip_vision_prefix
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_unet_state_dict_for_saving(self, state_dict):
        """Prepend "model.diffusion_model." to every UNet key for saving."""
        replace_prefix = {"": "model.diffusion_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_vae_state_dict_for_saving(self, state_dict):
        """Prepend the first VAE prefix to every key for saving."""
        replace_prefix = {"": self.vae_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def set_inference_dtype(self, dtype, manual_cast_dtype):
        """Record the UNet dtype in this instance's unet_config and store manual_cast_dtype."""
        self.unet_config['dtype'] = dtype
        self.manual_cast_dtype = manual_cast_dtype
|