From 3f75419e2e656575309f391769c66813a9ec078d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 7 Mar 2024 01:48:31 -0500 Subject: [PATCH] Add a node to use the super resolution controlnet. --- comfy_extras/nodes_stable_cascade.py | 31 ++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index b795d008..27596fc6 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -102,8 +102,39 @@ class StableCascade_StageB_Conditioning: c.append(n) return (c, ) +class StableCascade_SuperResolutionControlnet: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE",), + "vae": ("VAE", ), + }} + RETURN_TYPES = ("IMAGE", "LATENT", "LATENT") + RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "_for_testing/stable_cascade" + + def generate(self, image, vae): + width = image.shape[-2] + height = image.shape[-3] + batch_size = image.shape[0] + controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1) + + c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) + b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) + return (controlnet_input, { + "samples": c_latent, + }, { + "samples": b_latent, + }) + NODE_CLASS_MAPPINGS = { "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, + "StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet, }