diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 4f2315ac..6533039f 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
 import comfy.ldm.common_dit
+import comfy.model_management

 def sinusoidal_embedding_1d(dim, position):
     # preprocess
@@ -37,7 +38,7 @@ class WanRMSNorm(nn.Module):
         Args:
             x(Tensor): Shape [B, L, C]
         """
-        return self._norm(x.float()).type_as(x) * self.weight
+        return self._norm(x.float()).type_as(x) * comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device)

     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
@@ -125,7 +126,7 @@ class WanI2VCrossAttention(WanSelfAttention):
                  window_size=(-1, -1),
                  qk_norm=True,
                  eps=1e-6, operation_settings={}):
-        super().__init__(dim, num_heads, window_size, qk_norm, eps)
+        super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)

         self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
@@ -218,7 +219,7 @@ class WanAttentionBlock(nn.Module):
         """
         # assert e.dtype == torch.float32

-        e = (self.modulation + e).chunk(6, dim=1)
+        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
         # assert e[0].dtype == torch.float32

         # self-attention
@@ -263,7 +264,7 @@ class Head(nn.Module):
             e(Tensor): Shape [B, C]
         """
         # assert e.dtype == torch.float32
-        e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
         x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
         return x

@@ -401,7 +402,6 @@ class WanModel(torch.nn.Module):
         t,
         context,
         clip_fea=None,
-        y=None,
         freqs=None,
     ):
         r"""
@@ -425,12 +425,6 @@
             List[Tensor]:
                 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
         """
-        if self.model_type == 'i2v':
-            assert clip_fea is not None and y is not None
-
-        if y is not None:
-            x = torch.cat([x, y], dim=0)
-
         # embeddings
         x = self.patch_embedding(x)
         grid_sizes = x.shape[2:]
@@ -465,7 +459,7 @@
         return x
         # return [u.float() for u in x]

-    def forward(self, x, timestep, context, y=None, image=None, **kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, **kwargs):
         bs, c, t, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
         patch_size = self.patch_size
@@ -479,7 +473,7 @@
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
         freqs = self.rope_embedder(img_ids).movedim(1, 2)

-        return self.forward_orig(x, timestep, context, clip_fea=y, y=image, freqs=freqs)[:, :, :t, :h, :w]
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]

     def unpatchify(self, x, grid_sizes):
         r"""
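The wan/model.py changes above all follow one pattern: parameters that may be kept offloaded or in a storage dtype (the RMSNorm weight, the modulation tables) are cast to the activation's dtype and device at the point of use instead of being assumed to match. A minimal sketch of that pattern, assuming comfy.model_management.cast_to behaves roughly like a per-call weight.to(dtype, device); the TinyRMSNorm class below is illustration only, not ComfyUI code:

import torch

def cast_to(weight, dtype=None, device=None):
    # Assumed behaviour of comfy.model_management.cast_to for this sketch.
    return weight.to(device=device, dtype=dtype)

class TinyRMSNorm(torch.nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize in float32 for stability, then cast the result back to x's dtype.
        norm = x.float() * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # Cast the stored weight to match the activations instead of assuming it already does.
        return norm.type_as(x) * cast_to(self.weight, dtype=x.dtype, device=x.device)
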
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 76fd52c7..1885049e 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -929,13 +929,45 @@ class Lumina2(BaseModel):
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out

-class WAN21_T2V(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+class WAN21(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
+        self.image_to_video = image_to_video
+
+    def concat_cond(self, **kwargs):
+        if not self.image_to_video:
+            return None
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+        image = self.process_latent_in(image)
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if mask is None:
+            mask = torch.zeros_like(noise)[:, :4]
+        else:
+            mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
+            mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            if mask.shape[-3] < noise.shape[-3]:
+                mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
+            mask = mask.view(mask.shape[0], -1, 4, mask.shape[-2], mask.shape[-1]).transpose(1, 2)
+            mask = utils.resize_to_batch_size(mask, noise.shape[0])
+        return torch.cat((mask, image), dim=1)

     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        clip_vision_output = kwargs.get("clip_vision_output", None)
+        if clip_vision_output is not None:
+            out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
         return out
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 3dbd9dad..f149a4bf 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -313,6 +313,7 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["qk_norm"] = True
         dit_config["cross_attn_norm"] = True
         dit_config["eps"] = 1e-6
+        dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
         if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
             dit_config["model_type"] = "i2v"
         else:
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index d39a03f2..e28bd138 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -917,7 +917,7 @@ class WAN21_T2V(supported_models_base.BASE):
     text_encoder_key_prefix = ["text_encoders."]

     def get_model(self, state_dict, prefix="", device=None):
-        out = model_base.WAN21_T2V(self, device=device)
+        out = model_base.WAN21(self, device=device)
         return out

     def clip_target(self, state_dict={}):
@@ -925,6 +925,16 @@
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))

-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V]
+class WAN21_I2V(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "i2v",
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21(self, image_to_video=True, device=device)
+        return out
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
 models += [SVD_img2vid]
diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
new file mode 100644
index 00000000..52851d1f
--- /dev/null
+++ b/comfy_extras/nodes_wan.py
@@ -0,0 +1,61 @@
+import nodes
+import node_helpers
+import torch
+import comfy.model_management
+import comfy.utils
+
+
+def masked_images(num_images):
+    rem = 4 - (num_images % 4)
+    if rem == 4:
+        return num_images
+    return rem + num_images
+
+
+class WanImageToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 720, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                },
+                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             "start_image": ("IMAGE", ),
+                }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
+            image[:start_image.shape[0]] = start_image
+
+            concat_latent_image = vae.encode(image[:, :, :, :3])
+            mask = torch.ones((1, 1, latent.shape[2] * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+            mask[:, :, :masked_images(start_image.shape[0])] = 0.0
+
+            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
+
+NODE_CLASS_MAPPINGS = {
+    "WanImageToVideo": WanImageToVideo,
+}
"nodes_cosmos.py", "nodes_video.py", "nodes_lumina2.py", + "nodes_wan.py", ] import_failed = []