diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py index 7d864a75..9a3ebed6 100644 --- a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py +++ b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py @@ -864,18 +864,16 @@ class EncoderFactorized(nn.Module): x = self.patcher3d(x) # downsampling - hs = [self.conv_in(x)] + h = self.conv_in(x) for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1]) + h = self.down[i_level].block[i_block](h) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) - hs.append(h) if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) + h = self.down[i_level].downsample(h) # middle - h = hs[-1] h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/patching.py b/comfy/ldm/cosmos/cosmos_tokenizer/patching.py index 793f0da8..87a53a1d 100644 --- a/comfy/ldm/cosmos/cosmos_tokenizer/patching.py +++ b/comfy/ldm/cosmos/cosmos_tokenizer/patching.py @@ -281,54 +281,76 @@ class UnPatcher3D(UnPatcher): hh = hh.to(dtype=dtype) xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) + del x # Height height transposed convolutions. xll = F.conv_transpose3d( xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xlll + xll += F.conv_transpose3d( xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xllh xlh = F.conv_transpose3d( xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xlhl + xlh += F.conv_transpose3d( xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xlhh xhl = F.conv_transpose3d( xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xhll + xhl += F.conv_transpose3d( xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xhlh xhh = F.conv_transpose3d( xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xhhl + xhh += F.conv_transpose3d( xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xhhh # Handles width transposed convolutions. xl = F.conv_transpose3d( xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) ) + del xll + xl += F.conv_transpose3d( xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) ) + del xlh + xh = F.conv_transpose3d( xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) ) + del xhl + xh += F.conv_transpose3d( xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) ) + del xhh # Handles time axis transposed convolutions. x = F.conv_transpose3d( xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1) ) + del xl + x += F.conv_transpose3d( xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1) )