diff --git a/comfy/model_base.py b/comfy/model_base.py
index 1885049e..66cd0ded 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -957,8 +957,9 @@ class WAN21(BaseModel):
             mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
             if mask.shape[-3] < noise.shape[-3]:
                 mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
-            mask = mask.view(mask.shape[0], -1, 4, mask.shape[-2], mask.shape[-1]).transpose(1, 2)
+            mask = mask.repeat(1, 4, 1, 1, 1)
             mask = utils.resize_to_batch_size(mask, noise.shape[0])
+
         return torch.cat((mask, image), dim=1)
 
     def extra_conds(self, **kwargs):
diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index 52851d1f..f3a8ec66 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -5,13 +5,6 @@ import comfy.model_management
 import comfy.utils
 
 
-def masked_images(num_images):
-    rem = 4 - (num_images % 4)
-    if rem == 4:
-        return num_images
-    return rem + num_images
-
-
 class WanImageToVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -41,8 +34,8 @@ class WanImageToVideo:
             image[:start_image.shape[0]] = start_image
 
         concat_latent_image = vae.encode(image[:, :, :, :3])
-        mask = torch.ones((1, 1, latent.shape[2] * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
-        mask[:, :, :masked_images(start_image.shape[0])] = 0.0
+        mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+        mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
 
         positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
         negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
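
Note on the mask arithmetic (not part of the diff): assuming, as the diff implies, that the Wan VAE compresses time 4x so that n pixel frames occupy ((n - 1) // 4) + 1 latent frames, the node now builds the mask directly at latent-frame length and zeroes exactly the latent frames covered by the start images, rather than building a 4x-length pixel-frame mask and rounding the zeroed count up with the removed masked_images() helper. Correspondingly, concat_cond() expands the single-channel latent-frame mask to 4 channels with repeat() instead of folding groups of 4 pixel-frame entries into the channel dimension with view()/transpose(). A minimal sketch of the mapping (latent_frames is a hypothetical helper for illustration, not part of the codebase):

    # Hypothetical helper mirroring the diff's ((n - 1) // 4) + 1 expression.
    def latent_frames(num_pixel_frames: int) -> int:
        # 1-4 pixel frames -> 1 latent frame, 5-8 -> 2, 9-12 -> 3, ...
        return ((num_pixel_frames - 1) // 4) + 1

    for n in (1, 2, 4, 5, 8, 9):
        print(n, "->", latent_frames(n))  # 1->1, 2->1, 4->1, 5->2, 8->2, 9->3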