mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-02-28 22:51:45 +00:00
Lower cosmos VAE memory usage by a bit.
This commit is contained in:
parent
008761166f
commit
4758fb64b9
@ -864,18 +864,16 @@ class EncoderFactorized(nn.Module):
|
|||||||
x = self.patcher3d(x)
|
x = self.patcher3d(x)
|
||||||
|
|
||||||
# downsampling
|
# downsampling
|
||||||
hs = [self.conv_in(x)]
|
h = self.conv_in(x)
|
||||||
for i_level in range(self.num_resolutions):
|
for i_level in range(self.num_resolutions):
|
||||||
for i_block in range(self.num_res_blocks):
|
for i_block in range(self.num_res_blocks):
|
||||||
h = self.down[i_level].block[i_block](hs[-1])
|
h = self.down[i_level].block[i_block](h)
|
||||||
if len(self.down[i_level].attn) > 0:
|
if len(self.down[i_level].attn) > 0:
|
||||||
h = self.down[i_level].attn[i_block](h)
|
h = self.down[i_level].attn[i_block](h)
|
||||||
hs.append(h)
|
|
||||||
if i_level != self.num_resolutions - 1:
|
if i_level != self.num_resolutions - 1:
|
||||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
h = self.down[i_level].downsample(h)
|
||||||
|
|
||||||
# middle
|
# middle
|
||||||
h = hs[-1]
|
|
||||||
h = self.mid.block_1(h)
|
h = self.mid.block_1(h)
|
||||||
h = self.mid.attn_1(h)
|
h = self.mid.attn_1(h)
|
||||||
h = self.mid.block_2(h)
|
h = self.mid.block_2(h)
|
||||||
|
@ -281,54 +281,76 @@ class UnPatcher3D(UnPatcher):
|
|||||||
hh = hh.to(dtype=dtype)
|
hh = hh.to(dtype=dtype)
|
||||||
|
|
||||||
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
|
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
|
||||||
|
del x
|
||||||
|
|
||||||
# Handles height transposed convolutions.
|
# Handles height transposed convolutions.
|
||||||
xll = F.conv_transpose3d(
|
xll = F.conv_transpose3d(
|
||||||
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
)
|
)
|
||||||
|
del xlll
|
||||||
|
|
||||||
xll += F.conv_transpose3d(
|
xll += F.conv_transpose3d(
|
||||||
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
)
|
)
|
||||||
|
del xllh
|
||||||
|
|
||||||
xlh = F.conv_transpose3d(
|
xlh = F.conv_transpose3d(
|
||||||
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
)
|
)
|
||||||
|
del xlhl
|
||||||
|
|
||||||
xlh += F.conv_transpose3d(
|
xlh += F.conv_transpose3d(
|
||||||
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
)
|
)
|
||||||
|
del xlhh
|
||||||
|
|
||||||
xhl = F.conv_transpose3d(
|
xhl = F.conv_transpose3d(
|
||||||
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
)
|
)
|
||||||
|
del xhll
|
||||||
|
|
||||||
xhl += F.conv_transpose3d(
|
xhl += F.conv_transpose3d(
|
||||||
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
)
|
)
|
||||||
|
del xhlh
|
||||||
|
|
||||||
xhh = F.conv_transpose3d(
|
xhh = F.conv_transpose3d(
|
||||||
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
)
|
)
|
||||||
|
del xhhl
|
||||||
|
|
||||||
xhh += F.conv_transpose3d(
|
xhh += F.conv_transpose3d(
|
||||||
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
|
||||||
)
|
)
|
||||||
|
del xhhh
|
||||||
|
|
||||||
# Handles width transposed convolutions.
|
# Handles width transposed convolutions.
|
||||||
xl = F.conv_transpose3d(
|
xl = F.conv_transpose3d(
|
||||||
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||||
)
|
)
|
||||||
|
del xll
|
||||||
|
|
||||||
xl += F.conv_transpose3d(
|
xl += F.conv_transpose3d(
|
||||||
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||||
)
|
)
|
||||||
|
del xlh
|
||||||
|
|
||||||
xh = F.conv_transpose3d(
|
xh = F.conv_transpose3d(
|
||||||
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||||
)
|
)
|
||||||
|
del xhl
|
||||||
|
|
||||||
xh += F.conv_transpose3d(
|
xh += F.conv_transpose3d(
|
||||||
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
|
||||||
)
|
)
|
||||||
|
del xhh
|
||||||
|
|
||||||
# Handles time axis transposed convolutions.
|
# Handles time axis transposed convolutions.
|
||||||
x = F.conv_transpose3d(
|
x = F.conv_transpose3d(
|
||||||
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
|
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
|
||||||
)
|
)
|
||||||
|
del xl
|
||||||
|
|
||||||
x += F.conv_transpose3d(
|
x += F.conv_transpose3d(
|
||||||
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
|
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user