mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
IP2P model loading support.
This is the code to load the model and inference it with only a text prompt. This commit does not contain the nodes to properly use it with an image input. This supports both the original SD1 instructpix2pix model and the diffusers SDXL one.
This commit is contained in:
parent
96b4c757cf
commit
575acb69e4
@ -473,6 +473,40 @@ class SD_X4Upscaler(BaseModel):
|
|||||||
out['y'] = comfy.conds.CONDRegular(noise_level)
|
out['y'] = comfy.conds.CONDRegular(noise_level)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class IP2P:
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = {}
|
||||||
|
|
||||||
|
image = kwargs.get("concat_latent_image", None)
|
||||||
|
noise = kwargs.get("noise", None)
|
||||||
|
device = kwargs["device"]
|
||||||
|
|
||||||
|
if image is None:
|
||||||
|
image = torch.zeros_like(noise)
|
||||||
|
|
||||||
|
if image.shape[1:] != noise.shape[1:]:
|
||||||
|
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||||
|
|
||||||
|
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||||
|
|
||||||
|
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
|
||||||
|
adm = self.encode_adm(**kwargs)
|
||||||
|
if adm is not None:
|
||||||
|
out['y'] = comfy.conds.CONDRegular(adm)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class SD15_instructpix2pix(IP2P, BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device)
|
||||||
|
self.process_ip2p_image_in = lambda image: image
|
||||||
|
|
||||||
|
class SDXL_instructpix2pix(IP2P, SDXL):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device)
|
||||||
|
# self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image)
|
||||||
|
self.process_ip2p_image_in = lambda image: image
|
||||||
|
|
||||||
|
|
||||||
class StableCascade_C(BaseModel):
|
class StableCascade_C(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||||
|
@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
|
|
||||||
return unet_config
|
return unet_config
|
||||||
|
|
||||||
def model_config_from_unet_config(unet_config):
|
def model_config_from_unet_config(unet_config, state_dict=None):
|
||||||
for model_config in comfy.supported_models.models:
|
for model_config in comfy.supported_models.models:
|
||||||
if model_config.matches(unet_config):
|
if model_config.matches(unet_config, state_dict):
|
||||||
return model_config(unet_config)
|
return model_config(unet_config)
|
||||||
|
|
||||||
logging.error("no match {}".format(unet_config))
|
logging.error("no match {}".format(unet_config))
|
||||||
@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config):
|
|||||||
|
|
||||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||||
model_config = model_config_from_unet_config(unet_config)
|
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||||
if model_config is None and use_base_if_no_match:
|
if model_config is None and use_base_if_no_match:
|
||||||
return comfy.supported_models_base.BASE(unet_config)
|
return comfy.supported_models_base.BASE(unet_config)
|
||||||
else:
|
else:
|
||||||
@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
|
SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
|
||||||
|
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||||
|
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||||
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
|
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
|
||||||
@ -351,7 +357,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||||||
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
|
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
|
||||||
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
|
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
|
||||||
|
|
||||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]
|
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SDXL_diffusers_ip2p]
|
||||||
|
|
||||||
for unet_config in supported_models:
|
for unet_config in supported_models:
|
||||||
matches = True
|
matches = True
|
||||||
|
@ -334,6 +334,11 @@ class Stable_Zero123(supported_models_base.BASE):
|
|||||||
"num_head_channels": -1,
|
"num_head_channels": -1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
required_keys = {
|
||||||
|
"cc_projection.weight": None,
|
||||||
|
"cc_projection.bias": None,
|
||||||
|
}
|
||||||
|
|
||||||
clip_vision_prefix = "cond_stage_model.model.visual."
|
clip_vision_prefix = "cond_stage_model.model.visual."
|
||||||
|
|
||||||
latent_format = latent_formats.SD15
|
latent_format = latent_formats.SD15
|
||||||
@ -439,6 +444,33 @@ class Stable_Cascade_B(Stable_Cascade_C):
|
|||||||
out = model_base.StableCascade_B(self, device=device)
|
out = model_base.StableCascade_B(self, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class SD15_instructpix2pix(SD15):
|
||||||
|
unet_config = {
|
||||||
|
"context_dim": 768,
|
||||||
|
"model_channels": 320,
|
||||||
|
"use_linear_in_transformer": False,
|
||||||
|
"adm_in_channels": None,
|
||||||
|
"use_temporal_attention": False,
|
||||||
|
"in_channels": 8,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
return model_base.SD15_instructpix2pix(self, device=device)
|
||||||
|
|
||||||
|
class SDXL_instructpix2pix(SDXL):
|
||||||
|
unet_config = {
|
||||||
|
"model_channels": 320,
|
||||||
|
"use_linear_in_transformer": True,
|
||||||
|
"transformer_depth": [0, 0, 2, 2, 10, 10],
|
||||||
|
"context_dim": 2048,
|
||||||
|
"adm_in_channels": 2816,
|
||||||
|
"use_temporal_attention": False,
|
||||||
|
"in_channels": 8,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
return model_base.SDXL_instructpix2pix(self, device=device)
|
||||||
|
|
||||||
|
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
|
||||||
|
|
||||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
@ -16,6 +16,8 @@ class BASE:
|
|||||||
"num_head_channels": 64,
|
"num_head_channels": 64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
required_keys = {}
|
||||||
|
|
||||||
clip_prefix = []
|
clip_prefix = []
|
||||||
clip_vision_prefix = None
|
clip_vision_prefix = None
|
||||||
noise_aug_config = None
|
noise_aug_config = None
|
||||||
@ -28,10 +30,14 @@ class BASE:
|
|||||||
manual_cast_dtype = None
|
manual_cast_dtype = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def matches(s, unet_config):
|
def matches(s, unet_config, state_dict=None):
|
||||||
for k in s.unet_config:
|
for k in s.unet_config:
|
||||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||||
return False
|
return False
|
||||||
|
if state_dict is not None:
|
||||||
|
for k in s.required_keys:
|
||||||
|
if k not in state_dict:
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def model_type(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
|
Loading…
Reference in New Issue
Block a user