From a7874d1a8b88f9e5cc3d37fdba9b763004b6357d Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 3 Jan 2024 03:30:39 -0500
Subject: [PATCH] Add support for the stable diffusion x4 upscaling model.

This is an old model. Load the checkpoint like a regular one and use the
new SD_4XUpscale_Conditioning node.
---
 comfy/latent_formats.py         |  4 +++
 comfy/model_base.py             | 21 +++++++++++++++
 comfy/sd.py                     |  5 ++++
 comfy/supported_models.py       | 28 +++++++++++++++++++-
 comfy_extras/nodes_sdupscale.py | 45 +++++++++++++++++++++++++++++++++
 nodes.py                        |  1 +
 6 files changed, 103 insertions(+), 1 deletion(-)
 create mode 100644 comfy_extras/nodes_sdupscale.py

diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index c209087e..2252a075 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -33,3 +33,7 @@ class SDXL(LatentFormat):
             [-0.3112, -0.2359, -0.2076]
         ]
         self.taesd_decoder_name = "taesdxl_decoder"
+
+class SD_X4(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 0.08333
diff --git a/comfy/model_base.py b/comfy/model_base.py
index b3a1fcd5..64a380ff 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -364,3 +364,24 @@ class Stable_Zero123(BaseModel):
         cross_attn = self.cc_projection(cross_attn)
         out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
         return out
+
+class SD_X4Upscaler(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
+        super().__init__(model_config, model_type, device=device)
+
+    def extra_conds(self, **kwargs):
+        out = {}
+
+        image = kwargs.get("concat_image", None)
+        noise = kwargs.get("noise", None)
+
+        if image is None:
+            image = torch.zeros_like(noise)[:,:3]
+
+        if image.shape[1:] != noise.shape[1:]:
+            image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        out['c_concat'] = comfy.conds.CONDNoiseShape(image)
+        return out
diff --git a/comfy/sd.py b/comfy/sd.py
index 10a6715a..1ff25bec 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -174,6 +174,11 @@ class VAE:
             else:
                 #default SD1.x/SD2.x VAE parameters
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+
+                if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
+                    ddconfig['ch_mult'] = [1, 2, 4]
+                    self.downscale_ratio = 4
+
                 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
         else:
             self.first_stage_model = AutoencoderKL(**(config['params']))
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 251bf6ac..e7a6cc17 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -278,6 +278,32 @@ class Stable_Zero123(supported_models_base.BASE):
     def clip_target(self):
         return None
 
+class SD_X4Upscaler(SD20):
+    unet_config = {
+        "context_dim": 1024,
+        "model_channels": 256,
+        'in_channels': 7,
+        "use_linear_in_transformer": True,
+        "adm_in_channels": None,
+        "use_temporal_attention": False,
+    }
 
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega]
+    unet_extra_config = {
+        "disable_self_attentions": [True, True, True, False],
+        "num_heads": 8,
+        "num_head_channels": -1,
+    }
+
+    latent_format = latent_formats.SD_X4
+
+    sampling_settings = {
+        "linear_start": 0.0001,
+        "linear_end": 0.02,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.SD_X4Upscaler(self, device=device)
+        return out
+
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
 
 models += [SVD_img2vid]
diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py
new file mode 100644
index 00000000..38a027e0
--- /dev/null
+++ b/comfy_extras/nodes_sdupscale.py
@@ -0,0 +1,45 @@
+import torch
+import nodes
+import comfy.utils
+
+class SD_4XUpscale_Conditioning:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "images": ("IMAGE",),
+                              "positive": ("CONDITIONING",),
+                              "negative": ("CONDITIONING",),
+                              "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              # "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), #TODO
+                             }}
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/upscale_diffusion"
+
+    def encode(self, images, positive, negative, scale_ratio):
+        width = max(1, round(images.shape[-2] * scale_ratio))
+        height = max(1, round(images.shape[-3] * scale_ratio))
+
+        pixels = comfy.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
+
+        out_cp = []
+        out_cn = []
+
+        for t in positive:
+            n = [t[0], t[1].copy()]
+            n[1]['concat_image'] = pixels
+            out_cp.append(n)
+
+        for t in negative:
+            n = [t[0], t[1].copy()]
+            n[1]['concat_image'] = pixels
+            out_cn.append(n)
+
+        latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
+        return (out_cp, out_cn, {"samples":latent})
+
+NODE_CLASS_MAPPINGS = {
+    "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
+}
diff --git a/nodes.py b/nodes.py
index 8e3ec947..82244cf7 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1880,6 +1880,7 @@ def init_custom_nodes():
         "nodes_sag.py",
         "nodes_perpneg.py",
         "nodes_stable3d.py",
+        "nodes_sdupscale.py",
     ]
 
     for node_file in extras_files:
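
Not part of the patch itself, but as a usage note: the sketch below drives the
new node's encode() method directly from Python, which is a quick sanity check
of the data flow. It assumes it is run from a ComfyUI checkout root so the
imports resolve, and the image and conditioning tensors are dummy placeholders
rather than real VAE or CLIP outputs.

    # Minimal sketch, assuming a ComfyUI checkout root; tensors are dummies.
    import torch
    from comfy_extras.nodes_sdupscale import SD_4XUpscale_Conditioning

    # ComfyUI images are [batch, height, width, channels] floats in [0, 1].
    images = torch.rand([1, 128, 128, 3])

    # Conditioning is a list of [cross_attn_tensor, options_dict] pairs;
    # SD2.x, which this model config inherits from, uses a 1024-dim context.
    positive = [[torch.zeros([1, 77, 1024]), {}]]
    negative = [[torch.zeros([1, 77, 1024]), {}]]

    node = SD_4XUpscale_Conditioning()
    pos, neg, latent = node.encode(images, positive, negative, scale_ratio=4.0)

    # encode() rescales the image to [-1, 1] and stores it under 'concat_image'
    # so SD_X4Upscaler.extra_conds() can hand it to the UNet as 'c_concat'; the
    # empty latent is (height * scale_ratio) / 4 by (width * scale_ratio) / 4.
    print(latent["samples"].shape)          # torch.Size([1, 4, 128, 128])
    print(pos[0][1]["concat_image"].shape)  # torch.Size([1, 3, 128, 128])

These shapes also explain the 'in_channels': 7 in the new unet_config: the 4
noise latent channels are concatenated with the 3-channel low-resolution image.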