Allow image_only_indicator to be None.

This commit is contained in:
comfyanonymous 2024-02-29 13:09:43 -05:00
parent b3e97fc714
commit cb7c3a2921
3 changed files with 13 additions and 13 deletions

View File

@ -484,7 +484,6 @@ class UNetModel(nn.Module):
self.predict_codebook_ids = n_embed is not None
self.default_num_video_frames = None
self.default_image_only_indicator = None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
@ -830,7 +829,7 @@ class UNetModel(nn.Module):
transformer_patches = transformer_options.get("patches", {})
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
image_only_indicator = kwargs.get("image_only_indicator", None)
time_context = kwargs.get("time_context", None)
assert (y is not None) == (

View File

@ -46,18 +46,20 @@ class AlphaBlender(nn.Module):
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor:
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
if self.merge_strategy == "fixed":
# make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor.to(image_only_indicator.device)
alpha = self.mix_factor.to(device)
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
alpha = torch.sigmoid(self.mix_factor.to(device))
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
if image_only_indicator is None:
alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1")
else:
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
@ -76,7 +78,7 @@ class AlphaBlender(nn.Module):
x_temporal,
image_only_indicator=None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
alpha = self.get_alpha(image_only_indicator, x_spatial.device)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal

View File

@ -372,7 +372,6 @@ class SVD_img2vid(BaseModel):
if "time_conditioning" in kwargs:
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
return out