mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-14 13:17:32 +00:00
Allow image_only_indicator to be None.
This commit is contained in:
parent
b3e97fc714
commit
cb7c3a2921
@ -484,7 +484,6 @@ class UNetModel(nn.Module):
|
|||||||
self.predict_codebook_ids = n_embed is not None
|
self.predict_codebook_ids = n_embed is not None
|
||||||
|
|
||||||
self.default_num_video_frames = None
|
self.default_num_video_frames = None
|
||||||
self.default_image_only_indicator = None
|
|
||||||
|
|
||||||
time_embed_dim = model_channels * 4
|
time_embed_dim = model_channels * 4
|
||||||
self.time_embed = nn.Sequential(
|
self.time_embed = nn.Sequential(
|
||||||
@ -830,7 +829,7 @@ class UNetModel(nn.Module):
|
|||||||
transformer_patches = transformer_options.get("patches", {})
|
transformer_patches = transformer_options.get("patches", {})
|
||||||
|
|
||||||
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
||||||
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
|
image_only_indicator = kwargs.get("image_only_indicator", None)
|
||||||
time_context = kwargs.get("time_context", None)
|
time_context = kwargs.get("time_context", None)
|
||||||
|
|
||||||
assert (y is not None) == (
|
assert (y is not None) == (
|
||||||
|
@ -46,18 +46,20 @@ class AlphaBlender(nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||||
|
|
||||||
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
|
def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor:
|
||||||
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
|
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
|
||||||
if self.merge_strategy == "fixed":
|
if self.merge_strategy == "fixed":
|
||||||
# make shape compatible
|
# make shape compatible
|
||||||
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
|
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
|
||||||
alpha = self.mix_factor.to(image_only_indicator.device)
|
alpha = self.mix_factor.to(device)
|
||||||
elif self.merge_strategy == "learned":
|
elif self.merge_strategy == "learned":
|
||||||
alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
|
alpha = torch.sigmoid(self.mix_factor.to(device))
|
||||||
# make shape compatible
|
# make shape compatible
|
||||||
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
||||||
elif self.merge_strategy == "learned_with_images":
|
elif self.merge_strategy == "learned_with_images":
|
||||||
assert image_only_indicator is not None, "need image_only_indicator ..."
|
if image_only_indicator is None:
|
||||||
|
alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1")
|
||||||
|
else:
|
||||||
alpha = torch.where(
|
alpha = torch.where(
|
||||||
image_only_indicator.bool(),
|
image_only_indicator.bool(),
|
||||||
torch.ones(1, 1, device=image_only_indicator.device),
|
torch.ones(1, 1, device=image_only_indicator.device),
|
||||||
@ -76,7 +78,7 @@ class AlphaBlender(nn.Module):
|
|||||||
x_temporal,
|
x_temporal,
|
||||||
image_only_indicator=None,
|
image_only_indicator=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
alpha = self.get_alpha(image_only_indicator)
|
alpha = self.get_alpha(image_only_indicator, x_spatial.device)
|
||||||
x = (
|
x = (
|
||||||
alpha.to(x_spatial.dtype) * x_spatial
|
alpha.to(x_spatial.dtype) * x_spatial
|
||||||
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
|
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
|
||||||
|
@ -372,7 +372,6 @@ class SVD_img2vid(BaseModel):
|
|||||||
if "time_conditioning" in kwargs:
|
if "time_conditioning" in kwargs:
|
||||||
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
|
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
|
||||||
|
|
||||||
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
|
|
||||||
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
|
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user