mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-02-28 22:51:45 +00:00
Fix cosmos VAE failing with videos longer than 121 frames.
This commit is contained in:
parent
619b8cde74
commit
cca96a85ae
@ -18,6 +18,7 @@ import logging
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
import math
|
||||||
|
|
||||||
from .cosmos_tokenizer.layers3d import (
|
from .cosmos_tokenizer.layers3d import (
|
||||||
EncoderFactorized,
|
EncoderFactorized,
|
||||||
@ -105,17 +106,23 @@ class CausalContinuousVideoTokenizer(nn.Module):
|
|||||||
z, posteriors = self.distribution(moments)
|
z, posteriors = self.distribution(moments)
|
||||||
latent_ch = z.shape[1]
|
latent_ch = z.shape[1]
|
||||||
latent_t = z.shape[2]
|
latent_t = z.shape[2]
|
||||||
dtype = z.dtype
|
in_dtype = z.dtype
|
||||||
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
|
mean = self.latent_mean.view(latent_ch, -1)
|
||||||
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
|
std = self.latent_std.view(latent_ch, -1)
|
||||||
|
|
||||||
|
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
|
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
return ((z - mean) / std) * self.sigma_data
|
return ((z - mean) / std) * self.sigma_data
|
||||||
|
|
||||||
def decode(self, z):
|
def decode(self, z):
|
||||||
in_dtype = z.dtype
|
in_dtype = z.dtype
|
||||||
latent_ch = z.shape[1]
|
latent_ch = z.shape[1]
|
||||||
latent_t = z.shape[2]
|
latent_t = z.shape[2]
|
||||||
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
mean = self.latent_mean.view(latent_ch, -1)
|
||||||
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
std = self.latent_std.view(latent_ch, -1)
|
||||||
|
|
||||||
|
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
|
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
|
|
||||||
z = z / self.sigma_data
|
z = z / self.sigma_data
|
||||||
z = z * std + mean
|
z = z * std + mean
|
||||||
|
Loading…
Reference in New Issue
Block a user