mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-18 18:33:30 +00:00
Implement Cosmos Image/Video to World (Video) diffusion models.
Use CosmosImageToVideoLatent to set the input image/video.
This commit is contained in:
parent
1f1c7b7b56
commit
3aaabb12d4
@ -189,9 +189,10 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
if denoise_mask is not None:
|
||||
if len(denoise_mask.shape) == len(noise.shape):
|
||||
denoise_mask = denoise_mask[:,:1]
|
||||
denoise_mask = denoise_mask[:, :1]
|
||||
|
||||
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
||||
num_dim = noise.ndim - 2
|
||||
denoise_mask = denoise_mask.reshape((-1, 1) + tuple(denoise_mask.shape[-num_dim:]))
|
||||
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
||||
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
||||
@ -201,12 +202,16 @@ class BaseModel(torch.nn.Module):
|
||||
if ck == "mask":
|
||||
cond_concat.append(denoise_mask.to(device))
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||
cond_concat.append(concat_latent_image.to(device)) # NOTE: the latent_image should be masked by the mask in pixel space
|
||||
elif ck == "mask_inverted":
|
||||
cond_concat.append(1.0 - denoise_mask.to(device))
|
||||
else:
|
||||
if ck == "mask":
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
cond_concat.append(torch.ones_like(noise)[:, :1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
elif ck == "mask_inverted":
|
||||
cond_concat.append(torch.zeros_like(noise)[:, :1])
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
return data
|
||||
return None
|
||||
@ -294,6 +299,9 @@ class BaseModel(torch.nn.Module):
|
||||
return blank_image
|
||||
self.blank_inpaint_image_like = blank_inpaint_image_like
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
dtype = self.get_dtype()
|
||||
@ -859,8 +867,11 @@ class HunyuanVideo(BaseModel):
|
||||
return out
|
||||
|
||||
class CosmosVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EDM, device=None):
|
||||
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
|
||||
self.image_to_video = image_to_video
|
||||
if self.image_to_video:
|
||||
self.concat_keys = ("mask_inverted",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@ -873,3 +884,11 @@ class CosmosVideo(BaseModel):
|
||||
|
||||
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
|
||||
return out
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
|
||||
sigma_noise_augmentation = 0 #TODO
|
||||
if sigma_noise_augmentation != 0:
|
||||
latent_image = latent_image + noise
|
||||
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
|
||||
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
|
||||
|
@ -245,13 +245,14 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["max_img_h"] = 240
|
||||
dit_config["max_img_w"] = 240
|
||||
dit_config["max_frames"] = 128
|
||||
dit_config["in_channels"] = 16
|
||||
concat_padding_mask = True
|
||||
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["patch_spatial"] = 2
|
||||
dit_config["patch_temporal"] = 1
|
||||
dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["block_config"] = "FA-CA-MLP"
|
||||
dit_config["concat_padding_mask"] = True
|
||||
dit_config["concat_padding_mask"] = concat_padding_mask
|
||||
dit_config["pos_emb_cls"] = "rope3d"
|
||||
dit_config["pos_emb_learnable"] = False
|
||||
dit_config["pos_emb_interpolation"] = "crop"
|
||||
|
@ -376,7 +376,7 @@ class KSamplerX0Inpaint:
|
||||
if "denoise_mask_function" in model_options:
|
||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
|
||||
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
|
||||
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
||||
if denoise_mask is not None:
|
||||
out = out * denoise_mask + self.latent_image * latent_mask
|
||||
|
@ -534,7 +534,7 @@ class VAE:
|
||||
def encode(self, pixel_samples):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if self.latent_dim == 3:
|
||||
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
|
@ -824,9 +824,10 @@ class HunyuanVideo(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
|
||||
|
||||
class Cosmos(supported_models_base.BASE):
|
||||
class CosmosT2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "cosmos",
|
||||
"in_channels": 16,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
@ -854,7 +855,16 @@ class Cosmos(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
|
||||
|
||||
class CosmosI2V(CosmosT2V):
|
||||
unet_config = {
|
||||
"image_model": "cosmos",
|
||||
"in_channels": 17,
|
||||
}
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, Cosmos]
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
@ -1,6 +1,8 @@
|
||||
import nodes
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
|
||||
|
||||
class EmptyCosmosLatentVideo:
|
||||
@classmethod
|
||||
@ -16,8 +18,48 @@ class EmptyCosmosLatentVideo:
|
||||
|
||||
def generate(self, width, height, length, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples":latent}, )
|
||||
return ({"samples": latent}, )
|
||||
|
||||
|
||||
class CosmosImageToVideoLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"vae": ("VAE", ),
|
||||
"image": ("IMAGE", ),
|
||||
"width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/inpaint"
|
||||
|
||||
def encode(self, vae, image, width, height, length, batch_size):
|
||||
pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
pixel_len = min(pixels.shape[0], length)
|
||||
padded_length = min(length, (((pixel_len - 1) // 8) + 2) * 8 - 7)
|
||||
padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5
|
||||
padded_pixels[:pixel_len] = pixels[:pixel_len]
|
||||
|
||||
latent_temp = vae.encode(padded_pixels)
|
||||
|
||||
latent = torch.zeros([1, latent_temp.shape[1], ((length - 1) // 8) + 1, latent_temp.shape[-2], latent_temp.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
latent_len = ((pixel_len - 1) // 8) + 1
|
||||
latent[:, :, :latent_len] = latent_temp[:, :, :latent_len]
|
||||
|
||||
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
mask[:, :, :latent_len] *= 0.0
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
out_latent["noise_mask"] = mask
|
||||
return (out_latent,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
|
||||
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user