Slightly better wan i2v mask implementation.

This commit is contained in:
comfyanonymous 2025-02-26 03:49:50 -05:00
parent 4ced06b879
commit 0844998db3
2 changed files with 4 additions and 10 deletions

View File

@ -957,8 +957,9 @@ class WAN21(BaseModel):
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]: if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = mask.view(mask.shape[0], -1, 4, mask.shape[-2], mask.shape[-1]).transpose(1, 2) mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0]) mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((mask, image), dim=1) return torch.cat((mask, image), dim=1)
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):

View File

@ -5,13 +5,6 @@ import comfy.model_management
import comfy.utils import comfy.utils
def masked_images(num_images):
rem = 4 - (num_images % 4)
if rem == 4:
return num_images
return rem + num_images
class WanImageToVideo: class WanImageToVideo:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -41,8 +34,8 @@ class WanImageToVideo:
image[:start_image.shape[0]] = start_image image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3]) concat_latent_image = vae.encode(image[:, :, :, :3])
mask = torch.ones((1, 1, latent.shape[2] * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :masked_images(start_image.shape[0])] = 0.0 mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})