From 575acb69e46a4c24ddba4e0bdb895d6e85dc9354 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 31 Mar 2024 01:25:16 -0400
Subject: [PATCH] IP2P model loading support.

This is the code to load the model and run inference on it with only a
text prompt. This commit does not contain the nodes needed to use it
properly with an image input.

This supports both the original SD1 InstructPix2Pix model and the
diffusers SDXL one.
---
 comfy/model_base.py            | 34 ++++++++++++++++++++++++++++++++++
 comfy/model_detection.py       | 14 ++++++++++----
 comfy/supported_models.py      | 34 +++++++++++++++++++++++++++++++++-
 comfy/supported_models_base.py |  8 +++++++-
 4 files changed, 84 insertions(+), 6 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 6f530d2f..898adb66 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -473,6 +473,40 @@ class SD_X4Upscaler(BaseModel):
         out['y'] = comfy.conds.CONDRegular(noise_level)
         return out
 
+class IP2P:
+    def extra_conds(self, **kwargs):
+        out = {}
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        if image.shape[1:] != noise.shape[1:]:
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['y'] = comfy.conds.CONDRegular(adm)
+        return out
+
+class SD15_instructpix2pix(IP2P, BaseModel):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.process_ip2p_image_in = lambda image: image
+
+class SDXL_instructpix2pix(IP2P, SDXL):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        # self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image)
+        self.process_ip2p_image_in = lambda image: image
+
+
 class StableCascade_C(BaseModel):
     def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=StageC)
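
Note on the model_base.py hunk: everything in IP2P.extra_conds is shape
plumbing. The edit image must match the noise latent in both spatial size
and batch size before it can be handed to the sampler as 'c_concat', and
when no image is supplied the model falls back to an all-zeros latent,
which is what makes text-only inference possible. A minimal standalone
sketch of that logic (build_c_concat is a made-up name, and F.interpolate
plus a repeat stand in for the utils.common_upscale and
utils.resize_to_batch_size helpers used above):

    import torch
    import torch.nn.functional as F

    def build_c_concat(noise, image=None):
        # No edit image supplied: condition on zeros (text-only inference).
        if image is None:
            image = torch.zeros_like(noise)

        # Match the spatial size (H, W) of the noise latent.
        if image.shape[1:] != noise.shape[1:]:
            image = F.interpolate(image, size=noise.shape[-2:], mode="bilinear")

        # Match the batch size of the noise latent.
        if image.shape[0] != noise.shape[0]:
            image = image[:1].repeat(noise.shape[0], 1, 1, 1)

        return image

    noise = torch.randn(2, 4, 64, 64)
    print(build_c_concat(noise).shape)                             # [2, 4, 64, 64]
    print(build_c_concat(noise, torch.randn(1, 4, 32, 32)).shape)  # [2, 4, 64, 64]
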
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index bddbe2a4..795af828 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix):
 
     return unet_config
 
-def model_config_from_unet_config(unet_config):
+def model_config_from_unet_config(unet_config, state_dict=None):
     for model_config in comfy.supported_models.models:
-        if model_config.matches(unet_config):
+        if model_config.matches(unet_config, state_dict):
             return model_config(unet_config)
 
     logging.error("no match {}".format(unet_config))
@@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config):
 
 def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
     unet_config = detect_unet_config(state_dict, unet_key_prefix)
-    model_config = model_config_from_unet_config(unet_config)
+    model_config = model_config_from_unet_config(unet_config, state_dict)
     if model_config is None and use_base_if_no_match:
         return comfy.supported_models_base.BASE(unet_config)
     else:
@@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                         'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                         'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], 'use_temporal_attention': False, 'use_temporal_resblock': False}
 
+    SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                        'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
+                        'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
+                        'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
+                        'use_temporal_attention': False, 'use_temporal_resblock': False}
+
     SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
               'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
               'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
@@ -351,7 +357,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
                'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
 
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SDXL_diffusers_ip2p]
 
     for unet_config in supported_models:
         matches = True
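
Threading state_dict down to matches() is what the new required_keys
mechanism (added in supported_models_base.py below) hangs off: an SD1.5
InstructPix2Pix checkpoint and a Stable Zero123 checkpoint both present an
SD1-style UNet with 8 input channels, so the UNet hyperparameters alone
cannot tell them apart, but the presence of the cc_projection.* keys in
the state dict can. A simplified sketch of the resulting selection rule
(pick_model_config is an illustrative name, not the actual API):

    # First candidate wins whose unet_config entries all match the detected
    # config and whose required_keys are all present in the checkpoint.
    def pick_model_config(candidates, unet_config, state_dict=None):
        for cfg in candidates:
            if not all(k in unet_config and unet_config[k] == v
                       for k, v in cfg.unet_config.items()):
                continue
            if state_dict is not None and any(k not in state_dict
                                              for k in cfg.required_keys):
                continue
            return cfg
        return None

Ordering in the models list still matters: SD15_instructpix2pix is placed
ahead of SD15 (and SDXL_instructpix2pix ahead of SDXL) so the more
specific 8-channel configs are tried first.
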
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 5b2eb73f..9bfe3ea1 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -334,6 +334,11 @@ class Stable_Zero123(supported_models_base.BASE):
         "num_head_channels": -1,
     }
 
+    required_keys = {
+        "cc_projection.weight": None,
+        "cc_projection.bias": None,
+    }
+
     clip_vision_prefix = "cond_stage_model.model.visual."
 
     latent_format = latent_formats.SD15
@@ -439,6 +444,33 @@ class Stable_Cascade_B(Stable_Cascade_C):
         out = model_base.StableCascade_B(self, device=device)
         return out
 
+class SD15_instructpix2pix(SD15):
+    unet_config = {
+        "context_dim": 768,
+        "model_channels": 320,
+        "use_linear_in_transformer": False,
+        "adm_in_channels": None,
+        "use_temporal_attention": False,
+        "in_channels": 8,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SD15_instructpix2pix(self, device=device)
+
+class SDXL_instructpix2pix(SDXL):
+    unet_config = {
+        "model_channels": 320,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [0, 0, 2, 2, 10, 10],
+        "context_dim": 2048,
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
+        "in_channels": 8,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXL_instructpix2pix(self, device=device)
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
 
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
 models += [SVD_img2vid]
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 4d7e2593..6196daab 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -16,6 +16,8 @@ class BASE:
         "num_head_channels": 64,
     }
 
+    required_keys = {}
+
     clip_prefix = []
     clip_vision_prefix = None
     noise_aug_config = None
@@ -28,10 +30,14 @@ class BASE:
     manual_cast_dtype = None
 
     @classmethod
-    def matches(s, unet_config):
+    def matches(s, unet_config, state_dict=None):
         for k in s.unet_config:
             if k not in unet_config or s.unet_config[k] != unet_config[k]:
                 return False
+        if state_dict is not None:
+            for k in s.required_keys:
+                if k not in state_dict:
+                    return False
         return True
 
     def model_type(self, state_dict, prefix=""):
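
A closing note on the 'in_channels': 8 that recurs in every new config:
the c_concat latent produced by extra_conds is concatenated with the noise
latent along the channel dimension before entering the UNet (4 noise
channels plus 4 edit-image latent channels), which is also why a plain SD1
or SDXL config with 4 input channels cannot match these checkpoints. The
arithmetic, as a quick illustration:

    import torch

    noise_latent = torch.randn(1, 4, 64, 64)       # standard SD latent
    image_latent = torch.zeros_like(noise_latent)  # zeros in text-only use
    unet_in = torch.cat([noise_latent, image_latent], dim=1)
    print(unet_in.shape)  # torch.Size([1, 8, 64, 64]) -> 'in_channels': 8
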