Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-02-28 22:51:45 +00:00)
Slightly better wan i2v mask implementation.
parent 4ced06b879
commit 0844998db3
@@ -957,8 +957,9 @@ class WAN21(BaseModel):
             mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
             if mask.shape[-3] < noise.shape[-3]:
                 mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
-            mask = mask.view(mask.shape[0], -1, 4, mask.shape[-2], mask.shape[-1]).transpose(1, 2)
+            mask = mask.repeat(1, 4, 1, 1, 1)
             mask = utils.resize_to_batch_size(mask, noise.shape[0])

         return torch.cat((mask, image), dim=1)

     def extra_conds(self, **kwargs):
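The hunk above is in the WAN21 model class (comfy/model_base.py). It changes how the single-channel i2v mask is expanded to the 4 channels that get concatenated with the image latent: the old view(...).transpose(1, 2) folded groups of 4 temporal entries into the channel dimension, which only works when the mask is kept at pixel-frame resolution (always a multiple of 4 entries, with the zeroed region rounded up to whole groups of 4 by the masked_images helper removed below), while repeat(1, 4, 1, 1, 1) simply copies one per-latent-frame value into every channel. A minimal shape sketch with made-up sizes, skipping the upscaling and batch resizing done around these lines:

import torch

# Toy sizes, not taken from the commit: 2 latent frames, 4x4 latent pixels,
# one start image supplied for frame 0.
latent_frames, h, w = 2, 4, 4

# Old node-side mask: one temporal entry per *pixel* frame (4 per latent frame),
# with masked_images(1) == 4 leading entries zeroed.
old_mask = torch.ones(1, 1, latent_frames * 4, h, w)
old_mask[:, :, :4] = 0.0
# Old model-side reshape: fold each group of 4 pixel frames into the 4 channels.
old_final = old_mask.view(old_mask.shape[0], -1, 4, h, w).transpose(1, 2)

# New node-side mask: one temporal entry per *latent* frame,
# with ((1 - 1) // 4) + 1 == 1 leading entry zeroed.
new_mask = torch.ones(1, 1, latent_frames, h, w)
new_mask[:, :, :1] = 0.0
# New model-side reshape: copy the per-frame mask into all 4 channels.
new_final = new_mask.repeat(1, 4, 1, 1, 1)

print(old_final.shape, new_final.shape)   # both torch.Size([1, 4, 2, 4, 4])
print(torch.equal(old_final, new_final))  # True: same conditioning mask either way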
@@ -5,13 +5,6 @@ import comfy.model_management
 import comfy.utils


-def masked_images(num_images):
-    rem = 4 - (num_images % 4)
-    if rem == 4:
-        return num_images
-    return rem + num_images
-
-
 class WanImageToVideo:
     @classmethod
     def INPUT_TYPES(s):
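The masked_images helper removed above rounded a pixel-frame count up to the next multiple of 4, so the zeroed region of the old pixel-frame-resolution mask always covered whole groups of 4 frames. The new code (next hunk) sizes the mask in latent frames directly, so the same quantity becomes ((n - 1) // 4) + 1, i.e. ceil(n / 4). A quick sanity check of that equivalence, assuming the 4-pixel-frames-per-latent-frame packing used throughout this code:

def masked_images(num_images):
    # The deleted helper, copied from the diff: round up to a multiple of 4.
    rem = 4 - (num_images % 4)
    if rem == 4:
        return num_images
    return rem + num_images

for n in range(1, 33):
    latent_frames_masked = ((n - 1) // 4) + 1   # new inline expression: ceil(n / 4)
    assert masked_images(n) == 4 * latent_frames_masked
print("masked_images(n) == 4 * (((n - 1) // 4) + 1) for n in 1..32")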
@@ -41,8 +34,8 @@ class WanImageToVideo:
             image[:start_image.shape[0]] = start_image

             concat_latent_image = vae.encode(image[:, :, :, :3])
-            mask = torch.ones((1, 1, latent.shape[2] * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
-            mask[:, :, :masked_images(start_image.shape[0])] = 0.0
+            mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+            mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0

             positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
             negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
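This last hunk is in the WanImageToVideo node (comfy_extras/nodes_wan.py). The conditioning mask it attaches via concat_mask is now built at latent temporal resolution: latent.shape[2] entries instead of latent.shape[2] * 4, with ceil(start_frames / 4) of them zeroed to mark the latent frames covered by the start image; the model side (first hunk) then expands it to 4 channels with repeat. A standalone sketch with illustrative numbers only (the 81-frame 832x480 video and single start image are assumptions, not values from the commit), using Wan's 8x spatial and 4x temporal compression:

import torch

length, height, width = 81, 480, 832   # assumed video size, not from the commit
start_frames = 1                        # a single start image

latent_frames = ((length - 1) // 4) + 1        # 21 latent frames
latent_h, latent_w = height // 8, width // 8   # 60 x 104 latent pixels

# New mask layout: one temporal entry per latent frame, at latent resolution.
mask = torch.ones((1, 1, latent_frames, latent_h, latent_w))
mask[:, :, :((start_frames - 1) // 4) + 1] = 0.0   # zeroes mark the given frames

print(mask.shape)           # torch.Size([1, 1, 21, 60, 104])
print(mask[0, 0, :, 0, 0])  # first entry 0., the remaining 20 entries 1.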