diff --git a/comfy/model_base.py b/comfy/model_base.py index bf4ebefa..976702b6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -973,11 +973,11 @@ class WAN21(BaseModel): self.image_to_video = image_to_video def concat_cond(self, **kwargs): - if not self.image_to_video: + noise = kwargs.get("noise", None) + if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]: return None image = kwargs.get("concat_latent_image", None) - noise = kwargs.get("noise", None) device = kwargs["device"] if image is None: @@ -987,6 +987,9 @@ class WAN21(BaseModel): image = self.process_latent_in(image) image = utils.resize_to_batch_size(image, noise.shape[0]) + if not self.image_to_video: + return image + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) if mask is None: mask = torch.zeros_like(noise)[:, :4]