diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index efe3586d..b795d008 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -18,6 +18,7 @@ import torch import nodes +import comfy.utils class StableCascade_EmptyLatentImage: @@ -47,6 +48,39 @@ class StableCascade_EmptyLatentImage: "samples": b_latent, }) +class StableCascade_StageC_VAEEncode: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE",), + "vae": ("VAE", ), + "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), + }} + RETURN_TYPES = ("LATENT", "LATENT") + RETURN_NAMES = ("stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "_for_testing/stable_cascade" + + def generate(self, image, vae, compression): + width = image.shape[-2] + height = image.shape[-3] + out_width = (width // compression) * vae.downscale_ratio + out_height = (height // compression) * vae.downscale_ratio + + s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1) + + c_latent = vae.encode(s[:,:,:,:3]) + b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4]) + return ({ + "samples": c_latent, + }, { + "samples": b_latent, + }) + class StableCascade_StageB_Conditioning: @classmethod def INPUT_TYPES(s): @@ -71,4 +105,5 @@ class StableCascade_StageB_Conditioning: NODE_CLASS_MAPPINGS = { "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, + "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, }