Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-20 11:23:29 +00:00)
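Fix SVD lowvram mode: move AlphaBlender's mix_factor and the resulting blend alpha onto the input tensor's device, and build AE3DConv from ops.Conv2d / ops.Conv3d instead of torch.nn.Conv2d / torch.nn.Conv3d.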
This commit is contained in:
parent
a252963f95
commit
d0165d819a
@ -51,9 +51,9 @@ class AlphaBlender(nn.Module):
|
|||||||
if self.merge_strategy == "fixed":
|
if self.merge_strategy == "fixed":
|
||||||
# make shape compatible
|
# make shape compatible
|
||||||
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
|
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
|
||||||
alpha = self.mix_factor
|
alpha = self.mix_factor.to(image_only_indicator.device)
|
||||||
elif self.merge_strategy == "learned":
|
elif self.merge_strategy == "learned":
|
||||||
alpha = torch.sigmoid(self.mix_factor)
|
alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
|
||||||
# make shape compatible
|
# make shape compatible
|
||||||
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
||||||
elif self.merge_strategy == "learned_with_images":
|
elif self.merge_strategy == "learned_with_images":
|
||||||
@ -61,7 +61,7 @@ class AlphaBlender(nn.Module):
|
|||||||
alpha = torch.where(
|
alpha = torch.where(
|
||||||
image_only_indicator.bool(),
|
image_only_indicator.bool(),
|
||||||
torch.ones(1, 1, device=image_only_indicator.device),
|
torch.ones(1, 1, device=image_only_indicator.device),
|
||||||
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
|
rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
|
||||||
)
|
)
|
||||||
alpha = rearrange(alpha, self.rearrange_pattern)
|
alpha = rearrange(alpha, self.rearrange_pattern)
|
||||||
# make shape compatible
|
# make shape compatible
|
||||||
|
@ -82,14 +82,14 @@ class VideoResBlock(ResnetBlock):
|
|||||||
|
|
||||||
x = self.time_stack(x, temb)
|
x = self.time_stack(x, temb)
|
||||||
|
|
||||||
alpha = self.get_alpha(bs=b // timesteps)
|
alpha = self.get_alpha(bs=b // timesteps).to(x.device)
|
||||||
x = alpha * x + (1.0 - alpha) * x_mix
|
x = alpha * x + (1.0 - alpha) * x_mix
|
||||||
|
|
||||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class AE3DConv(torch.nn.Conv2d):
|
class AE3DConv(ops.Conv2d):
|
||||||
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
|
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
|
||||||
super().__init__(in_channels, out_channels, *args, **kwargs)
|
super().__init__(in_channels, out_channels, *args, **kwargs)
|
||||||
if isinstance(video_kernel_size, Iterable):
|
if isinstance(video_kernel_size, Iterable):
|
||||||
@ -97,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d):
|
|||||||
else:
|
else:
|
||||||
padding = int(video_kernel_size // 2)
|
padding = int(video_kernel_size // 2)
|
||||||
|
|
||||||
self.time_mix_conv = torch.nn.Conv3d(
|
self.time_mix_conv = ops.Conv3d(
|
||||||
in_channels=out_channels,
|
in_channels=out_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
kernel_size=video_kernel_size,
|
kernel_size=video_kernel_size,
|
||||||
@ -167,7 +167,7 @@ class AttnVideoBlock(AttnBlock):
|
|||||||
emb = emb[:, None, :]
|
emb = emb[:, None, :]
|
||||||
x_mix = x_mix + emb
|
x_mix = x_mix + emb
|
||||||
|
|
||||||
alpha = self.get_alpha()
|
alpha = self.get_alpha().to(x.device)
|
||||||
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
|
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
|
||||||
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
|
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user