diff --git a/comfy/ldm/cosmos/vae.py b/comfy/ldm/cosmos/vae.py index c8db6861..d64f292d 100644 --- a/comfy/ldm/cosmos/vae.py +++ b/comfy/ldm/cosmos/vae.py @@ -18,6 +18,7 @@ import logging import torch from torch import nn from enum import Enum +import math from .cosmos_tokenizer.layers3d import ( EncoderFactorized, @@ -105,17 +106,23 @@ class CausalContinuousVideoTokenizer(nn.Module): z, posteriors = self.distribution(moments) latent_ch = z.shape[1] latent_t = z.shape[2] - dtype = z.dtype - mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device) - std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device) + in_dtype = z.dtype + mean = self.latent_mean.view(latent_ch, -1) + std = self.latent_std.view(latent_ch, -1) + + mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) + std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) return ((z - mean) / std) * self.sigma_data def decode(self, z): in_dtype = z.dtype latent_ch = z.shape[1] latent_t = z.shape[2] - mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) - std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) + mean = self.latent_mean.view(latent_ch, -1) + std = self.latent_std.view(latent_ch, -1) + + mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) + std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) z = z / self.sigma_data z = z * std + mean