Lower cosmos VAE memory usage by a bit.

This commit is contained in:
comfyanonymous 2025-01-15 22:57:52 -05:00
parent 008761166f
commit 4758fb64b9
2 changed files with 25 additions and 5 deletions

View File

@ -864,18 +864,16 @@ class EncoderFactorized(nn.Module):
x = self.patcher3d(x) x = self.patcher3d(x)
# downsampling # downsampling
hs = [self.conv_in(x)] h = self.conv_in(x)
for i_level in range(self.num_resolutions): for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks): for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1]) h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0: if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h) h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1])) h = self.down[i_level].downsample(h)
# middle # middle
h = hs[-1]
h = self.mid.block_1(h) h = self.mid.block_1(h)
h = self.mid.attn_1(h) h = self.mid.attn_1(h)
h = self.mid.block_2(h) h = self.mid.block_2(h)

View File

@ -281,54 +281,76 @@ class UnPatcher3D(UnPatcher):
hh = hh.to(dtype=dtype) hh = hh.to(dtype=dtype)
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
del x
# Height height transposed convolutions. # Height height transposed convolutions.
xll = F.conv_transpose3d( xll = F.conv_transpose3d(
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
) )
del xlll
xll += F.conv_transpose3d( xll += F.conv_transpose3d(
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
) )
del xllh
xlh = F.conv_transpose3d( xlh = F.conv_transpose3d(
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
) )
del xlhl
xlh += F.conv_transpose3d( xlh += F.conv_transpose3d(
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
) )
del xlhh
xhl = F.conv_transpose3d( xhl = F.conv_transpose3d(
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
) )
del xhll
xhl += F.conv_transpose3d( xhl += F.conv_transpose3d(
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
) )
del xhlh
xhh = F.conv_transpose3d( xhh = F.conv_transpose3d(
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
) )
del xhhl
xhh += F.conv_transpose3d( xhh += F.conv_transpose3d(
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
) )
del xhhh
# Handles width transposed convolutions. # Handles width transposed convolutions.
xl = F.conv_transpose3d( xl = F.conv_transpose3d(
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
) )
del xll
xl += F.conv_transpose3d( xl += F.conv_transpose3d(
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
) )
del xlh
xh = F.conv_transpose3d( xh = F.conv_transpose3d(
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
) )
del xhl
xh += F.conv_transpose3d( xh += F.conv_transpose3d(
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
) )
del xhh
# Handles time axis transposed convolutions. # Handles time axis transposed convolutions.
x = F.conv_transpose3d( x = F.conv_transpose3d(
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1) xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
) )
del xl
x += F.conv_transpose3d( x += F.conv_transpose3d(
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1) xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
) )