mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-07 12:02:10 +08:00
Support for WAN VACE preview model. (#7711)
* Support for WAN VACE preview model. * Remove print.
This commit is contained in:
parent
b6fd3ffd10
commit
ce22f687cc
@ -220,6 +220,34 @@ class WanAttentionBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VaceWanAttentionBlock(WanAttentionBlock):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cross_attn_type,
|
||||||
|
dim,
|
||||||
|
ffn_dim,
|
||||||
|
num_heads,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
cross_attn_norm=False,
|
||||||
|
eps=1e-6,
|
||||||
|
block_id=0,
|
||||||
|
operation_settings={}
|
||||||
|
):
|
||||||
|
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||||
|
self.block_id = block_id
|
||||||
|
if block_id == 0:
|
||||||
|
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
|
def forward(self, c, x, **kwargs):
|
||||||
|
if self.block_id == 0:
|
||||||
|
c = self.before_proj(c) + x
|
||||||
|
c = super().forward(c, **kwargs)
|
||||||
|
c_skip = self.after_proj(c)
|
||||||
|
return c_skip, c
|
||||||
|
|
||||||
|
|
||||||
class Head(nn.Module):
|
class Head(nn.Module):
|
||||||
|
|
||||||
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
|
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
|
||||||
@ -395,6 +423,7 @@ class WanModel(torch.nn.Module):
|
|||||||
clip_fea=None,
|
clip_fea=None,
|
||||||
freqs=None,
|
freqs=None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Forward pass through the diffusion model
|
Forward pass through the diffusion model
|
||||||
@ -457,7 +486,7 @@ class WanModel(torch.nn.Module):
|
|||||||
x = self.unpatchify(x, grid_sizes)
|
x = self.unpatchify(x, grid_sizes)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
|
def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
@ -471,7 +500,7 @@ class WanModel(torch.nn.Module):
|
|||||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||||
|
|
||||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
|
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||||
|
|
||||||
def unpatchify(self, x, grid_sizes):
|
def unpatchify(self, x, grid_sizes):
|
||||||
r"""
|
r"""
|
||||||
@ -496,3 +525,114 @@ class WanModel(torch.nn.Module):
|
|||||||
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
|
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
|
||||||
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
||||||
return u
|
return u
|
||||||
|
|
||||||
|
|
||||||
|
class VaceWanModel(WanModel):
|
||||||
|
r"""
|
||||||
|
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_type='vace',
|
||||||
|
patch_size=(1, 2, 2),
|
||||||
|
text_len=512,
|
||||||
|
in_dim=16,
|
||||||
|
dim=2048,
|
||||||
|
ffn_dim=8192,
|
||||||
|
freq_dim=256,
|
||||||
|
text_dim=4096,
|
||||||
|
out_dim=16,
|
||||||
|
num_heads=16,
|
||||||
|
num_layers=32,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
cross_attn_norm=True,
|
||||||
|
eps=1e-6,
|
||||||
|
flf_pos_embed_token_number=None,
|
||||||
|
image_model=None,
|
||||||
|
vace_layers=None,
|
||||||
|
vace_in_dim=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||||
|
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
# Vace
|
||||||
|
if vace_layers is not None:
|
||||||
|
self.vace_layers = vace_layers
|
||||||
|
self.vace_in_dim = vace_in_dim
|
||||||
|
# vace blocks
|
||||||
|
self.vace_blocks = nn.ModuleList([
|
||||||
|
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings)
|
||||||
|
for i in range(self.vace_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))}
|
||||||
|
# vace patch embeddings
|
||||||
|
self.vace_patch_embedding = operations.Conv3d(
|
||||||
|
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_orig(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
context,
|
||||||
|
vace_context,
|
||||||
|
clip_fea=None,
|
||||||
|
freqs=None,
|
||||||
|
transformer_options={},
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# embeddings
|
||||||
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
|
grid_sizes = x.shape[2:]
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
# time embeddings
|
||||||
|
e = self.time_embedding(
|
||||||
|
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
|
||||||
|
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||||
|
|
||||||
|
# context
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
context_img_len = None
|
||||||
|
if clip_fea is not None:
|
||||||
|
if self.img_emb is not None:
|
||||||
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
|
context = torch.concat([context_clip, context], dim=1)
|
||||||
|
context_img_len = clip_fea.shape[-2]
|
||||||
|
|
||||||
|
c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
|
||||||
|
c = c.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
# arguments
|
||||||
|
x_orig = x
|
||||||
|
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||||
|
return out
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||||
|
|
||||||
|
ii = self.vace_layers_mapping.get(i, None)
|
||||||
|
if ii is not None:
|
||||||
|
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||||
|
x += c_skip
|
||||||
|
# head
|
||||||
|
x = self.head(x, e)
|
||||||
|
|
||||||
|
# unpatchify
|
||||||
|
x = self.unpatchify(x, grid_sizes)
|
||||||
|
return x
|
||||||
|
@ -1043,6 +1043,34 @@ class WAN21(BaseModel):
|
|||||||
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
|
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class WAN21_Vace(WAN21):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
|
||||||
|
self.image_to_video = image_to_video
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
noise = kwargs.get("noise", None)
|
||||||
|
noise_shape = list(noise.shape)
|
||||||
|
vace_frames = kwargs.get("vace_frames", None)
|
||||||
|
if vace_frames is None:
|
||||||
|
noise_shape[1] = 32
|
||||||
|
vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
|
||||||
|
|
||||||
|
for i in range(0, vace_frames.shape[1], 16):
|
||||||
|
vace_frames = vace_frames.clone()
|
||||||
|
vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
|
||||||
|
|
||||||
|
mask = kwargs.get("vace_mask", None)
|
||||||
|
if mask is None:
|
||||||
|
noise_shape[1] = 64
|
||||||
|
mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
|
||||||
|
|
||||||
|
out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Hunyuan3Dv2(BaseModel):
|
class Hunyuan3Dv2(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||||
|
@ -317,10 +317,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["cross_attn_norm"] = True
|
dit_config["cross_attn_norm"] = True
|
||||||
dit_config["eps"] = 1e-6
|
dit_config["eps"] = 1e-6
|
||||||
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
|
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "i2v"
|
dit_config["model_type"] = "vace"
|
||||||
|
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||||
|
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
||||||
else:
|
else:
|
||||||
dit_config["model_type"] = "t2v"
|
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["model_type"] = "i2v"
|
||||||
|
else:
|
||||||
|
dit_config["model_type"] = "t2v"
|
||||||
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
|
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
|
||||||
if flf_weight is not None:
|
if flf_weight is not None:
|
||||||
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
|
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
|
||||||
|
@ -987,6 +987,16 @@ class WAN21_FunControl2V(WAN21_T2V):
|
|||||||
out = model_base.WAN21(self, image_to_video=False, device=device)
|
out = model_base.WAN21(self, image_to_video=False, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class WAN21_Vace(WAN21_T2V):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "wan2.1",
|
||||||
|
"model_type": "vace",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "hunyuan3d2",
|
"image_model": "hunyuan3d2",
|
||||||
@ -1055,6 +1065,6 @@ class HiDream(supported_models_base.BASE):
|
|||||||
return None # TODO
|
return None # TODO
|
||||||
|
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
@ -193,9 +193,115 @@ class WanFunInpaintToVideo:
|
|||||||
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
|
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
|
||||||
|
|
||||||
|
|
||||||
|
class WanVaceToVideo:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"vae": ("VAE", ),
|
||||||
|
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||||
|
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||||
|
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||||
|
},
|
||||||
|
"optional": {"control_video": ("IMAGE", ),
|
||||||
|
"control_masks": ("MASK", ),
|
||||||
|
"reference_image": ("IMAGE", ),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
|
||||||
|
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
|
||||||
|
FUNCTION = "encode"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning/video_models"
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def encode(self, positive, negative, vae, width, height, length, batch_size, control_video=None, control_masks=None, reference_image=None):
|
||||||
|
latent_length = ((length - 1) // 4) + 1
|
||||||
|
if control_video is not None:
|
||||||
|
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
if control_video.shape[0] < length:
|
||||||
|
control_video = torch.nn.functional.pad(control_video, (0, 0, 0, 0, 0, 0, 0, length - control_video.shape[0]), value=0.5)
|
||||||
|
else:
|
||||||
|
control_video = torch.ones((length, height, width, 3)) * 0.5
|
||||||
|
|
||||||
|
if reference_image is not None:
|
||||||
|
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
reference_image = vae.encode(reference_image[:, :, :, :3])
|
||||||
|
reference_image = torch.cat([reference_image, comfy.latent_formats.Wan21().process_out(torch.zeros_like(reference_image))], dim=1)
|
||||||
|
|
||||||
|
if control_masks is None:
|
||||||
|
mask = torch.ones((length, height, width, 1))
|
||||||
|
else:
|
||||||
|
mask = control_masks
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
mask = comfy.utils.common_upscale(mask[:length], width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
if mask.shape[0] < length:
|
||||||
|
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, 0, 0, length - mask.shape[0]), value=1.0)
|
||||||
|
|
||||||
|
control_video = control_video - 0.5
|
||||||
|
inactive = (control_video * (1 - mask)) + 0.5
|
||||||
|
reactive = (control_video * mask) + 0.5
|
||||||
|
|
||||||
|
inactive = vae.encode(inactive[:, :, :, :3])
|
||||||
|
reactive = vae.encode(reactive[:, :, :, :3])
|
||||||
|
control_video_latent = torch.cat((inactive, reactive), dim=1)
|
||||||
|
if reference_image is not None:
|
||||||
|
control_video_latent = torch.cat((reference_image, control_video_latent), dim=2)
|
||||||
|
|
||||||
|
vae_stride = 8
|
||||||
|
height_mask = height // vae_stride
|
||||||
|
width_mask = width // vae_stride
|
||||||
|
mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride)
|
||||||
|
mask = mask.permute(2, 4, 0, 1, 3)
|
||||||
|
mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask)
|
||||||
|
mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0)
|
||||||
|
|
||||||
|
trim_latent = 0
|
||||||
|
if reference_image is not None:
|
||||||
|
mask_pad = torch.zeros_like(mask[:, :reference_image.shape[2], :, :])
|
||||||
|
mask = torch.cat((mask_pad, mask), dim=1)
|
||||||
|
latent_length += reference_image.shape[2]
|
||||||
|
trim_latent = reference_image.shape[2]
|
||||||
|
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask})
|
||||||
|
|
||||||
|
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = latent
|
||||||
|
return (positive, negative, out_latent, trim_latent)
|
||||||
|
|
||||||
|
class TrimVideoLatent:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples": ("LATENT",),
|
||||||
|
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "op"
|
||||||
|
|
||||||
|
CATEGORY = "latent/video"
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def op(self, samples, trim_amount):
|
||||||
|
samples_out = samples.copy()
|
||||||
|
|
||||||
|
s1 = samples["samples"]
|
||||||
|
samples_out["samples"] = s1[:, :, trim_amount:]
|
||||||
|
return (samples_out,)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"WanImageToVideo": WanImageToVideo,
|
"WanImageToVideo": WanImageToVideo,
|
||||||
"WanFunControlToVideo": WanFunControlToVideo,
|
"WanFunControlToVideo": WanFunControlToVideo,
|
||||||
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
||||||
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
|
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
|
||||||
|
"WanVaceToVideo": WanVaceToVideo,
|
||||||
|
"TrimVideoLatent": TrimVideoLatent,
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user