Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-01-25 15:55:18 +00:00)
Support SDXL inpaint models.

commit 7931ff0fd9 (parent c335fdf200)
comfy/model_base.py
@@ -111,6 +111,9 @@ class BaseModel(torch.nn.Module):
         return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
 
+    def set_inpaint(self):
+        self.concat_keys = ("mask", "masked_image")
+
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
     adm_inputs = []
     weights = []
@@ -148,12 +151,6 @@ class SD21UNCLIP(BaseModel):
         else:
             return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))
 
-
-class SDInpaint(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
-        super().__init__(model_config, model_type, device=device)
-        self.concat_keys = ("mask", "masked_image")
-
 def sdxl_pooled(args, noise_augmentor):
     if "unclip_conditioning" in args:
         return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
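With the standalone SDInpaint class removed, inpaint support now lives on BaseModel itself: set_inpaint() simply records which extra latent inputs the UNet expects. Below is a minimal sketch (illustrative only, not ComfyUI's actual sampling code; names are made up) of what concat_keys = ("mask", "masked_image") ultimately means for the UNet input of SD/SDXL inpaint checkpoints:

import torch

# Illustrative sketch: how a model with concat_keys = ("mask", "masked_image")
# could assemble the 9-channel input that inpaint UNets were trained on.
def build_inpaint_unet_input(noisy_latent, cond, concat_keys=("mask", "masked_image")):
    # noisy_latent:         [B, 4, H, W] latent currently being denoised
    # cond["mask"]:         [B, 1, H, W] region to repaint, at latent resolution
    # cond["masked_image"]: [B, 4, H, W] VAE-encoded image with the masked area blanked
    extras = [cond[k] for k in concat_keys]            # order must match training
    return torch.cat([noisy_latent] + extras, dim=1)   # -> [B, 9, H, W]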
comfy/model_detection.py
@@ -183,8 +183,12 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
                   'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
                   'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
 
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet]
+    SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                              'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
+                              'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
+                              'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint]
 
     for unet_config in supported_models:
         matches = True
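The new SDXL_diffusers_inpaint entry differs from the regular SDXL configuration mainly in 'in_channels': 9 (4 latent channels plus 1 mask channel plus 4 masked-image channels), which is what lets the matching loop above recognize diffusers-format SDXL inpaint UNets. As a hypothetical illustration (this helper is not part of the commit), the same property can be read directly off a diffusers state dict:

# Hypothetical helper, not part of this commit: a diffusers-format UNet is an
# inpaint variant when its input convolution takes 9 channels instead of 4.
def is_inpaint_unet(state_dict):
    conv_in_weight = state_dict["conv_in.weight"]   # shape [out_ch, in_ch, 3, 3]
    return conv_in_weight.shape[1] == 9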
comfy/sd.py
@@ -355,13 +355,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True
         model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
         model_config.unet_config = unet_config
 
-    if config['model']["target"].endswith("LatentInpaintDiffusion"):
-        model = model_base.SDInpaint(model_config, model_type=model_type)
-    elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
+    if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
         model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
     else:
         model = model_base.BaseModel(model_config, model_type=model_type)
 
+    if config['model']["target"].endswith("LatentInpaintDiffusion"):
+        model.set_inpaint()
+
     if fp16:
         model = model.half()
 
comfy/supported_models.py
@@ -153,7 +153,10 @@ class SDXL(supported_models_base.BASE):
         return model_base.ModelType.EPS
 
     def get_model(self, state_dict, prefix="", device=None):
-        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
+        out = model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
+        if self.inpaint_model():
+            out.set_inpaint()
+        return out
 
     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
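SDXL checkpoints and their inpaint variants now share a single get_model path: build the model first, then switch it into inpaint mode when the detected UNet takes extra input channels. A sketch of what the inpaint_model() check amounts to (assumed from the configurations above; its actual implementation is not shown in this diff):

# Assumed behaviour of the inpaint_model() check, inferred from the configs
# above; the real implementation is not part of this diff.
def inpaint_model(self):
    return self.unet_config["in_channels"] > 4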
comfy/supported_models_base.py
@@ -57,12 +57,13 @@ class BASE:
             self.unet_config[x] = self.unet_extra_config[x]
 
     def get_model(self, state_dict, prefix="", device=None):
-        if self.inpaint_model():
-            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
-        elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
+        if self.noise_aug_config is not None:
+            out = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
         else:
-            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
+            out = model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
+        if self.inpaint_model():
+            out.set_inpaint()
+        return out
 
     def process_clip_state_dict(self, state_dict):
         return state_dict