From d6a2137fc30d1a52c8b404b23ece0c5b6ef64867 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 14 Jun 2025 18:37:07 -0700 Subject: [PATCH] Support Cosmos predict2 image to video models. (#8535) Use the CosmosPredict2ImageToVideoLatent node. --- comfy/model_base.py | 21 ++++++++++++++++ comfy/model_detection.py | 15 ++++++++---- comfy_extras/nodes_cosmos.py | 46 ++++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 5 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index be72ddd1..cb7689e8 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1014,9 +1014,30 @@ class CosmosPredict2(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None)) return out + def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): + if denoise_mask is None: + return timestep + condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True) + c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1 + out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4]) + return out + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)) + sigma_noise_augmentation = 0 #TODO + if sigma_noise_augmentation != 0: + latent_image = latent_image + noise + latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) + sigma = (sigma / (sigma + 1)) + return latent_image / (1.0 - sigma) + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index b8fef17a..4aa90d3b 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -441,11 +441,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["rope_h_extrapolation_ratio"] = 4.0 dit_config["rope_w_extrapolation_ratio"] = 4.0 dit_config["rope_t_extrapolation_ratio"] = 1.0 - elif dit_config["in_channels"] == 17: - dit_config["extra_per_block_abs_pos_emb"] = False - dit_config["rope_h_extrapolation_ratio"] = 3.0 - dit_config["rope_w_extrapolation_ratio"] = 3.0 - dit_config["rope_t_extrapolation_ratio"] = 1.0 + elif dit_config["in_channels"] == 17: # img to video + if dit_config["model_channels"] == 2048: + dit_config["extra_per_block_abs_pos_emb"] = False + dit_config["rope_h_extrapolation_ratio"] = 3.0 + dit_config["rope_w_extrapolation_ratio"] = 3.0 + dit_config["rope_t_extrapolation_ratio"] = 1.0 + elif dit_config["model_channels"] == 5120: + dit_config["rope_h_extrapolation_ratio"] = 2.0 + dit_config["rope_w_extrapolation_ratio"] = 2.0 + dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334 dit_config["extra_h_extrapolation_ratio"] = 1.0 dit_config["extra_w_extrapolation_ratio"] = 1.0 diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py index bd35ddb0..4f496055 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -2,6 +2,7 @@ import nodes import torch import comfy.model_management import comfy.utils +import comfy.latent_formats class EmptyCosmosLatentVideo: @@ -75,8 +76,53 @@ class CosmosImageToVideoLatent: out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) return (out_latent,) +class CosmosPredict2ImageToVideoLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": {"vae": ("VAE", ), + "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + + RETURN_TYPES = ("LATENT",) + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is None and end_image is None: + out_latent = {} + out_latent["samples"] = latent + return (out_latent,) + + mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1) + latent[:, :, :latent_temp.shape[-3]] = latent_temp + mask[:, :, :latent_temp.shape[-3]] *= 0.0 + + if end_image is not None: + latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0) + latent[:, :, -latent_temp.shape[-3]:] = latent_temp + mask[:, :, -latent_temp.shape[-3]:] *= 0.0 + + out_latent = {} + latent_format = comfy.latent_formats.Wan21() + latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) + out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) + out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) + return (out_latent,) NODE_CLASS_MAPPINGS = { "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo, "CosmosImageToVideoLatent": CosmosImageToVideoLatent, + "CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent, }