Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-02-28 14:40:27 +00:00)

Commit 4ced06b879 (parent cb06e9669b): WIP support for Wan I2V model.
@@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
 import comfy.ldm.common_dit
+import comfy.model_management

 def sinusoidal_embedding_1d(dim, position):
     # preprocess
@@ -37,7 +38,7 @@ class WanRMSNorm(nn.Module):
         Args:
             x(Tensor): Shape [B, L, C]
         """
-        return self._norm(x.float()).type_as(x) * self.weight
+        return self._norm(x.float()).type_as(x) * comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device)

     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
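The cast_to change above (and the matching modulation changes further down) all follow the same pattern: a parameter may be stored in a different dtype or on a different device than the activations, so it is cast at the point of use instead of once at load time. A minimal self-contained sketch of that pattern, where this cast_to is a simplified stand-in for comfy.model_management.cast_to rather than the real implementation:

import torch

def cast_to(weight, dtype=None, device=None):
    # Simplified stand-in: move/cast the stored parameter to match the activations.
    return weight.to(device=device, dtype=dtype)

class RMSNormSketch(torch.nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        n = x.float() * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # Cast the learned scale to the activation dtype/device at call time.
        return n.type_as(x) * cast_to(self.weight, dtype=x.dtype, device=x.device)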
@@ -125,7 +126,7 @@ class WanI2VCrossAttention(WanSelfAttention):
                  window_size=(-1, -1),
                  qk_norm=True,
                  eps=1e-6, operation_settings={}):
-        super().__init__(dim, num_heads, window_size, qk_norm, eps)
+        super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)

         self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
@@ -218,7 +219,7 @@ class WanAttentionBlock(nn.Module):
         """
         # assert e.dtype == torch.float32

-        e = (self.modulation + e).chunk(6, dim=1)
+        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
         # assert e[0].dtype == torch.float32

         # self-attention
@@ -263,7 +264,7 @@ class Head(nn.Module):
             e(Tensor): Shape [B, C]
         """
         # assert e.dtype == torch.float32
-        e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
         x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
         return x

@@ -401,7 +402,6 @@ class WanModel(torch.nn.Module):
         t,
         context,
         clip_fea=None,
-        y=None,
         freqs=None,
     ):
         r"""
@@ -425,12 +425,6 @@ class WanModel(torch.nn.Module):
             List[Tensor]:
                 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
         """
-        if self.model_type == 'i2v':
-            assert clip_fea is not None and y is not None
-
-        if y is not None:
-            x = torch.cat([x, y], dim=0)
-
         # embeddings
         x = self.patch_embedding(x)
         grid_sizes = x.shape[2:]
@@ -465,7 +459,7 @@ class WanModel(torch.nn.Module):
         return x
         # return [u.float() for u in x]

-    def forward(self, x, timestep, context, y=None, image=None, **kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, **kwargs):
         bs, c, t, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
         patch_size = self.patch_size
@@ -479,7 +473,7 @@ class WanModel(torch.nn.Module):
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

         freqs = self.rope_embedder(img_ids).movedim(1, 2)
-        return self.forward_orig(x, timestep, context, clip_fea=y, y=image, freqs=freqs)[:, :, :t, :h, :w]
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]

     def unpatchify(self, x, grid_sizes):
         r"""
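A rough sketch, not the actual ComfyUI wrapper code, of how the reworked signature is expected to be fed: the output of the concat_cond() method added below is concatenated to the noisy latent along the channel axis before the model is called, and extra conds such as "clip_fea" arrive as keyword arguments, which is why forward() no longer needs the y/image parameters.

import torch

def apply_model_sketch(diffusion_model, x, t, context, c_concat=None, **extra):
    # c_concat is assumed to be what WAN21.concat_cond() returns (mask + image latent);
    # for I2V this is expected to turn the 16-channel noise into a 36-channel input.
    xc = x if c_concat is None else torch.cat([x, c_concat], dim=1)
    # extra may carry clip_fea, produced by WAN21.extra_conds below.
    return diffusion_model(xc, t, context=context, **extra)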
@@ -929,13 +929,45 @@ class Lumina2(BaseModel):
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out

-class WAN21_T2V(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+class WAN21(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
+        self.image_to_video = image_to_video
+
+    def concat_cond(self, **kwargs):
+        if not self.image_to_video:
+            return None
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+        image = self.process_latent_in(image)
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if mask is None:
+            mask = torch.zeros_like(noise)[:, :4]
+        else:
+            mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
+            mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            if mask.shape[-3] < noise.shape[-3]:
+                mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
+            mask = mask.view(mask.shape[0], -1, 4, mask.shape[-2], mask.shape[-1]).transpose(1, 2)
+            mask = utils.resize_to_batch_size(mask, noise.shape[0])
+
+        return torch.cat((mask, image), dim=1)

     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        clip_vision_output = kwargs.get("clip_vision_output", None)
+        if clip_vision_output is not None:
+            out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
         return out
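A small illustrative check of the mask packing in concat_cond above; the shapes are made up, the real tensors come from the VAE and the sampler. The single-channel mask covers four pixel frames per latent frame, so it is repacked into 4 channels and concatenated in front of the 16-channel image latent, giving 20 extra input channels:

import torch

b, c_lat, f, h, w = 1, 16, 21, 60, 104            # example latent dimensions (assumed)
image = torch.zeros(b, c_lat, f, h, w)            # encoded (or zero) reference latent
mask = torch.zeros(b, 1, f * 4, h, w)             # one mask entry per pixel frame

mask = mask.view(b, -1, 4, h, w).transpose(1, 2)  # -> [b, 4, f, h, w]
cond = torch.cat((mask, image), dim=1)            # -> [b, 20, f, h, w]
print(cond.shape)                                 # torch.Size([1, 20, 21, 60, 104])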
@@ -313,6 +313,7 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["qk_norm"] = True
         dit_config["cross_attn_norm"] = True
         dit_config["eps"] = 1e-6
+        dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
         if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
             dit_config["model_type"] = "i2v"
         else:
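Why patch_embedding.weight.shape[1] recovers in_dim in the detection hunk above: the Wan patch embedding is a 3D convolution over (t, h, w) patches, so its weight has shape [out_dim, in_dim, pt, ph, pw]. The concrete sizes below are assumptions for the sketch; the expectation is 36 input channels for the I2V checkpoint (16 noise + 4 mask + 16 image latent) versus 16 for T2V.

import torch

# Hypothetical I2V-sized patch embedding: 36 input channels, model dim 5120, patch (1, 2, 2).
patch_embedding = torch.nn.Conv3d(36, 5120, kernel_size=(1, 2, 2), stride=(1, 2, 2))
print(patch_embedding.weight.shape)     # torch.Size([5120, 36, 1, 2, 2])
print(patch_embedding.weight.shape[1])  # 36 -> dit_config["in_dim"]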
@@ -917,7 +917,7 @@ class WAN21_T2V(supported_models_base.BASE):
     text_encoder_key_prefix = ["text_encoders."]

     def get_model(self, state_dict, prefix="", device=None):
-        out = model_base.WAN21_T2V(self, device=device)
+        out = model_base.WAN21(self, device=device)
         return out

     def clip_target(self, state_dict={}):
@@ -925,6 +925,16 @@ class WAN21_T2V(supported_models_base.BASE):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))

-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V]
+class WAN21_I2V(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "i2v",
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21(self, image_to_video=True, device=device)
+        return out
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]

 models += [SVD_img2vid]
comfy_extras/nodes_wan.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+import nodes
+import node_helpers
+import torch
+import comfy.model_management
+import comfy.utils
+
+
+def masked_images(num_images):
+    rem = 4 - (num_images % 4)
+    if rem == 4:
+        return num_images
+    return rem + num_images
+
+
+class WanImageToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 720, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                },
+                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             "start_image": ("IMAGE", ),
+                }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
+            image[:start_image.shape[0]] = start_image
+
+            concat_latent_image = vae.encode(image[:, :, :, :3])
+            mask = torch.ones((1, 1, latent.shape[2] * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+            mask[:, :, :masked_images(start_image.shape[0])] = 0.0
+
+            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
+
+NODE_CLASS_MAPPINGS = {
+    "WanImageToVideo": WanImageToVideo,
+}
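A quick illustrative check of the temporal bookkeeping in WanImageToVideo: the latent allocation ((length - 1) // 4) + 1 suggests the Wan VAE keeps the first frame and folds each following group of 4 pixel frames into one latent frame, and masked_images rounds the conditioned frame count up to a multiple of 4 so the zeroed region always covers whole latent frames.

def latent_frames(length):
    # Pixel frames -> latent frames, matching the allocation in WanImageToVideo.encode.
    return ((length - 1) // 4) + 1

def masked_images(num_images):
    # Same helper as in the new node: round up to a multiple of 4.
    rem = 4 - (num_images % 4)
    if rem == 4:
        return num_images
    return rem + num_images

for length in (1, 5, 81, 121):
    print(length, "->", latent_frames(length))  # 1 -> 1, 5 -> 2, 81 -> 21, 121 -> 31
print(masked_images(1))                         # 4: one start image masks the first latent frame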