Fix SVD lowvram mode.

comfyanonymous 2023-12-24 07:06:59 -05:00
parent a252963f95
commit d0165d819a
2 changed files with 7 additions and 7 deletions
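For context, the failure mode this commit targets: in lowvram mode the model's weights can stay on the CPU while activations run on the GPU, so small parameters such as mix_factor may not be on the same device as the tensors they blend with. Below is a minimal, hypothetical sketch (not the actual ComfyUI classes) of the pattern the diff applies, casting the parameter to the activation's device right before use:

import torch
import torch.nn as nn

class TinyAlphaBlender(nn.Module):
    """Hypothetical stand-in for AlphaBlender, reduced to the blend step."""
    def __init__(self):
        super().__init__()
        # Under lowvram offloading this parameter may remain on the CPU
        # while the activations it blends are already on the GPU.
        self.mix_factor = nn.Parameter(torch.tensor([0.0]))

    def forward(self, x_spatial, x_temporal):
        # Casting the parameter to the activation's device avoids the
        # "Expected all tensors to be on the same device" RuntimeError
        # that a plain torch.sigmoid(self.mix_factor) would hit here.
        alpha = torch.sigmoid(self.mix_factor.to(x_spatial.device))
        return alpha * x_spatial + (1.0 - alpha) * x_temporal

The ops.Conv2d / ops.Conv3d swap in the second file follows the same reasoning: routing the temporal convolutions through ComfyUI's ops wrappers, presumably so the lowvram weight handling covers them as well, instead of leaving them as plain torch.nn modules.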

View File

@@ -51,9 +51,9 @@ class AlphaBlender(nn.Module):
         if self.merge_strategy == "fixed":
             # make shape compatible
             # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
-            alpha = self.mix_factor
+            alpha = self.mix_factor.to(image_only_indicator.device)
         elif self.merge_strategy == "learned":
-            alpha = torch.sigmoid(self.mix_factor)
+            alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
             # make shape compatible
             # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
         elif self.merge_strategy == "learned_with_images":
@@ -61,7 +61,7 @@ class AlphaBlender(nn.Module):
             alpha = torch.where(
                 image_only_indicator.bool(),
                 torch.ones(1, 1, device=image_only_indicator.device),
-                rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
+                rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
             )
             alpha = rearrange(alpha, self.rearrange_pattern)
             # make shape compatible

View File

@@ -82,14 +82,14 @@ class VideoResBlock(ResnetBlock):
 
             x = self.time_stack(x, temb)
 
-            alpha = self.get_alpha(bs=b // timesteps)
+            alpha = self.get_alpha(bs=b // timesteps).to(x.device)
             x = alpha * x + (1.0 - alpha) * x_mix
 
             x = rearrange(x, "b c t h w -> (b t) c h w")
         return x
 
 
-class AE3DConv(torch.nn.Conv2d):
+class AE3DConv(ops.Conv2d):
     def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
         super().__init__(in_channels, out_channels, *args, **kwargs)
         if isinstance(video_kernel_size, Iterable):
@@ -97,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d):
         else:
             padding = int(video_kernel_size // 2)
 
-        self.time_mix_conv = torch.nn.Conv3d(
+        self.time_mix_conv = ops.Conv3d(
             in_channels=out_channels,
             out_channels=out_channels,
             kernel_size=video_kernel_size,
@@ -167,7 +167,7 @@ class AttnVideoBlock(AttnBlock):
             emb = emb[:, None, :]
             x_mix = x_mix + emb
 
-            alpha = self.get_alpha()
+            alpha = self.get_alpha().to(x.device)
             x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
             x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge