From cb7c3a2921cfc0805be0229b4634e1143d60e6fe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 29 Feb 2024 13:09:43 -0500 Subject: [PATCH] Allow image_only_indicator to be None. --- .../modules/diffusionmodules/openaimodel.py | 3 +-- comfy/ldm/modules/diffusionmodules/util.py | 22 ++++++++++--------- comfy/model_base.py | 1 - 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index c5477025..cf89ae01 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -484,7 +484,6 @@ class UNetModel(nn.Module): self.predict_codebook_ids = n_embed is not None self.default_num_video_frames = None - self.default_image_only_indicator = None time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( @@ -830,7 +829,7 @@ class UNetModel(nn.Module): transformer_patches = transformer_options.get("patches", {}) num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) - image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + image_only_indicator = kwargs.get("image_only_indicator", None) time_context = kwargs.get("time_context", None) assert (y is not None) == ( diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 5a6aa7d7..ce14ad5e 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -46,23 +46,25 @@ class AlphaBlender(nn.Module): else: raise ValueError(f"unknown merge strategy {self.merge_strategy}") - def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor: # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t) if self.merge_strategy == "fixed": # make shape compatible # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) - alpha = self.mix_factor.to(image_only_indicator.device) + alpha = self.mix_factor.to(device) elif self.merge_strategy == "learned": - alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device)) + alpha = torch.sigmoid(self.mix_factor.to(device)) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) elif self.merge_strategy == "learned_with_images": - assert image_only_indicator is not None, "need image_only_indicator ..." - alpha = torch.where( - image_only_indicator.bool(), - torch.ones(1, 1, device=image_only_indicator.device), - rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"), - ) + if image_only_indicator is None: + alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1") + else: + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"), + ) alpha = rearrange(alpha, self.rearrange_pattern) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) @@ -76,7 +78,7 @@ class AlphaBlender(nn.Module): x_temporal, image_only_indicator=None, ) -> torch.Tensor: - alpha = self.get_alpha(image_only_indicator) + alpha = self.get_alpha(image_only_indicator, x_spatial.device) x = ( alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal diff --git a/comfy/model_base.py b/comfy/model_base.py index 170b1fd4..69117dd0 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -372,7 +372,6 @@ class SVD_img2vid(BaseModel): if "time_conditioning" in kwargs: out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) - out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device)) out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) return out